Clean up the "custom protocol" code (#832)

* No longer use a channel for custom protocols

* Fix warnings

* Add graceful shutdown and comments

* Minor cleanups

* Add warning if queue goes over threshold

* Increase threshold

* Address grumble
This commit is contained in:
Pierre Krieger
2018-09-29 13:01:23 +02:00
committed by Gav Wood
parent c54350661d
commit 79d538ee66
2 changed files with 230 additions and 170 deletions
+201 -141
View File
@@ -15,22 +15,22 @@
// along with Substrate. If not, see <http://www.gnu.org/licenses/>.
use bytes::{Bytes, BytesMut};
use ProtocolId;
use libp2p::core::{Multiaddr, ConnectionUpgrade, Endpoint};
use PacketId;
use libp2p::tokio_codec::Framed;
use std::collections::VecDeque;
use std::io::Error as IoError;
use std::vec::IntoIter as VecIntoIter;
use futures::{future, Future, stream, Stream, Sink};
use futures::sync::mpsc;
use futures::{prelude::*, future, stream, task};
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec::UviBytes;
use ProtocolId;
/// Connection upgrade for a single protocol.
///
/// Note that "a single protocol" here refers to `par` for example. However
/// each protocol can have multiple different versions for networking purposes.
#[derive(Clone)]
pub struct RegisteredProtocol<T> {
pub struct RegisteredProtocol<TUserData> {
/// Id of the protocol for API purposes.
id: ProtocolId,
/// Base name of the protocol as advertised on the network.
@@ -41,67 +41,202 @@ pub struct RegisteredProtocol<T> {
/// The packet count is used to filter out invalid messages.
supported_versions: Vec<(u8, u8)>,
/// Custom data.
custom_data: T,
custom_data: TUserData,
}
/// Output of a `RegisteredProtocol` upgrade.
pub struct RegisteredProtocolOutput<T> {
/// Data passed to `RegisteredProtocol::new`.
pub custom_data: T,
/// Id of the protocol.
pub protocol_id: ProtocolId,
/// Endpoint of the connection.
pub endpoint: Endpoint,
/// Version of the protocol that was negotiated.
pub protocol_version: u8,
/// Channel to sender outgoing messages to.
// TODO: consider assembling packet_id here
pub outgoing: mpsc::UnboundedSender<Bytes>,
/// Stream where incoming messages are received. The stream ends whenever
/// either side is closed.
pub incoming: Box<Stream<Item = (PacketId, Bytes), Error = IoError> + Send>,
}
impl<T> RegisteredProtocol<T> {
impl<TUserData> RegisteredProtocol<TUserData> {
/// Creates a new `RegisteredProtocol`. The `custom_data` parameter will be
/// passed inside the `RegisteredProtocolOutput`.
pub fn new(custom_data: T, protocol: ProtocolId, versions: &[(u8, u8)])
pub fn new(custom_data: TUserData, protocol: ProtocolId, versions: &[(u8, u8)])
-> Self {
let mut proto_name = Bytes::from_static(b"/substrate/");
proto_name.extend_from_slice(&protocol);
proto_name.extend_from_slice(b"/");
let mut base_name = Bytes::from_static(b"/substrate/");
base_name.extend_from_slice(&protocol);
base_name.extend_from_slice(b"/");
RegisteredProtocol {
base_name: proto_name,
base_name,
id: protocol,
supported_versions: {
let mut tmp: Vec<_> = versions.iter().rev().cloned().collect();
tmp.sort_unstable_by(|a, b| b.1.cmp(&a.1));
tmp
},
custom_data: custom_data,
custom_data,
}
}
/// Returns the ID of the protocol.
#[inline]
pub fn id(&self) -> ProtocolId {
self.id
}
/// Returns the custom data that was passed to `new`.
pub fn custom_data(&self) -> &T {
#[inline]
pub fn custom_data(&self) -> &TUserData {
&self.custom_data
}
}
// `Maf` is short for `MultiaddressFuture`
impl<T, C> ConnectionUpgrade<C> for RegisteredProtocol<T>
where C: AsyncRead + AsyncWrite + Send + 'static, // TODO: 'static :-/
/// Output of a `RegisteredProtocol` upgrade.
pub struct RegisteredProtocolSubstream<TSubstream> {
/// If true, we are in the process of closing the sink.
is_closing: bool,
/// Buffer of packets to send.
send_queue: VecDeque<Bytes>,
/// If true, we should call `poll_complete` on the inner sink.
requires_poll_complete: bool,
/// The underlying substream.
inner: stream::Fuse<Framed<TSubstream, UviBytes<Bytes>>>,
/// Maximum packet id.
packet_count: u8,
/// Id of the protocol.
protocol_id: ProtocolId,
/// Version of the protocol that was negotiated.
protocol_version: u8,
/// Task to notify when something is changed and we need to be polled.
to_notify: Option<task::Task>,
}
/// Packet of data that can be sent or received.
#[derive(Debug, Clone)]
pub struct Packet {
/// Identifier of the packet.
pub id: u8,
/// The raw data.
pub data: Bytes,
}
impl<TSubstream> RegisteredProtocolSubstream<TSubstream> {
/// Returns the protocol id.
#[inline]
pub fn protocol_id(&self) -> ProtocolId {
self.protocol_id
}
/// Returns the version of the protocol that was negotiated.
#[inline]
pub fn protocol_version(&self) -> u8 {
self.protocol_version
}
/// Starts a graceful shutdown process on this substream.
///
/// Note that "graceful" means that we sent a closing message. We don't wait for any
/// confirmation from the remote.
///
/// After calling this, the stream is guaranteed to finish soon-ish.
pub fn shutdown(&mut self) {
self.is_closing = true;
if let Some(task) = self.to_notify.take() {
task.notify();
}
}
/// Sends a message to the substream.
pub fn send_message(&mut self, Packet { id: packet_id, data }: Packet) {
if packet_id >= self.packet_count {
error!(target: "sub-libp2p", "Tried to send a packet with an invalid ID {}", packet_id);
return;
}
let mut message = Bytes::with_capacity(1 + data.len());
message.extend_from_slice(&[packet_id]);
message.extend_from_slice(&data);
self.send_queue.push_back(message);
// If the length of the queue goes over a certain arbitrary threshold, we print a warning.
// TODO: figure out a good threshold
if self.send_queue.len() >= 2048 {
warn!(target: "sub-libp2p", "Queue of packets to send over substream is pretty \
large: {}", self.send_queue.len());
}
if let Some(task) = self.to_notify.take() {
task.notify();
}
}
/// Turns raw data into a packet and checks whether it is valid.
fn data_to_packet(&self, mut data: BytesMut) -> Result<Packet, ()> {
// The `data` should be prefixed by the packet ID, therefore an empty packet is invalid.
if data.is_empty() {
debug!(target: "sub-libp2p", "ignoring incoming packet because it was empty");
return Err(());
}
let packet = {
let id = data[0];
let data = data.split_off(1);
Packet { id, data: data.freeze() }
};
if packet.id >= self.packet_count {
debug!(target: "sub-libp2p", "ignoring incoming packet because packet_id {} is \
too large", packet.id);
return Err(())
}
Ok(packet)
}
}
impl<TSubstream> Stream for RegisteredProtocolSubstream<TSubstream>
where TSubstream: AsyncRead + AsyncWrite,
{
type Item = Packet;
type Error = IoError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
// If we are closing, close as soon as the Sink is closed.
if self.is_closing {
return Ok(self.inner.close()?.map(|()| None));
}
// Flushing the local queue.
while let Some(packet) = self.send_queue.pop_front() {
match self.inner.start_send(packet)? {
AsyncSink::NotReady(packet) => {
self.send_queue.push_front(packet);
break;
},
AsyncSink::Ready => self.requires_poll_complete = true,
}
}
// Flushing if necessary.
if self.requires_poll_complete {
if let Async::Ready(()) = self.inner.poll_complete()? {
self.requires_poll_complete = false;
}
}
// Receiving incoming packets.
// Note that `inner` is wrapped in a `Fuse`, therefore we can poll it forever.
loop {
match self.inner.poll()? {
Async::Ready(Some(data)) =>
if let Ok(packet) = self.data_to_packet(data) {
return Ok(Async::Ready(Some(packet)))
},
Async::Ready(None) =>
if !self.requires_poll_complete && self.send_queue.is_empty() {
return Ok(Async::Ready(None))
} else {
break
},
Async::NotReady => break,
}
}
self.to_notify = Some(task::current());
Ok(Async::NotReady)
}
}
impl<TSubstream, TUserData> ConnectionUpgrade<TSubstream> for RegisteredProtocol<TUserData>
where TSubstream: AsyncRead + AsyncWrite,
TUserData: Clone,
{
type NamesIter = VecIntoIter<(Bytes, Self::UpgradeIdentifier)>;
type UpgradeIdentifier = u8; // Protocol version
@@ -117,15 +252,15 @@ where C: AsyncRead + AsyncWrite + Send + 'static, // TODO: 'static :-/
}).collect::<Vec<_>>().into_iter()
}
type Output = RegisteredProtocolOutput<T>;
type Output = RegisteredProtocolSubstream<TSubstream>;
type Future = future::FutureResult<Self::Output, IoError>;
#[allow(deprecated)]
fn upgrade(
self,
socket: C,
socket: TSubstream,
protocol_version: Self::UpgradeIdentifier,
endpoint: Endpoint,
_: Endpoint,
_: &Multiaddr
) -> Self::Future {
let packet_count = self.supported_versions
@@ -134,103 +269,27 @@ where C: AsyncRead + AsyncWrite + Send + 'static, // TODO: 'static :-/
.expect("negotiated protocol version that wasn't advertised ; \
programmer error")
.1;
let framed = Framed::new(socket, UviBytes::default());
// This function is called whenever we successfully negotiated a
// protocol with a remote (both if initiated by us or by the remote)
// This channel is used to send outgoing packets to the custom_data
// for this open substream.
let (msg_tx, msg_rx) = mpsc::unbounded();
// Build the sink for outgoing network bytes, and the stream for
// incoming instructions. `stream` implements `Stream<Item = Message>`.
enum Message {
/// Received data from the network.
RecvSocket(BytesMut),
/// Data to send to the network.
/// The packet_id must already be inside the `Bytes`.
SendReq(Bytes),
/// The socket has been closed.
Finished,
}
let (sink, stream) = {
let framed = AsyncRead::framed(socket, UviBytes::default());
let msg_rx = msg_rx.map(Message::SendReq)
.map_err(|()| unreachable!("mpsc::UnboundedReceiver never errors"));
let (sink, stream) = framed.split();
let stream = stream.map(Message::RecvSocket)
.chain(stream::once(Ok(Message::Finished)));
(sink, msg_rx.select(stream))
};
let incoming = stream::unfold((sink, stream, false), move |(sink, stream, finished)| {
if finished {
return None
}
Some(stream
.into_future()
.map_err(|(err, _)| err)
.and_then(move |(message, stream)|
match message {
Some(Message::RecvSocket(mut data)) => {
// The `data` should be prefixed by the packet ID,
// therefore an empty packet is invalid.
if data.is_empty() {
debug!(target: "sub-libp2p", "ignoring incoming \
packet because it was empty");
let f = future::ok((None, (sink, stream, false)));
return future::Either::A(f)
}
let packet_id = data[0];
let data = data.split_off(1);
if packet_id >= packet_count {
debug!(target: "sub-libp2p", "ignoring incoming packet \
because packet_id {} is too large", packet_id);
let f = future::ok((None, (sink, stream, false)));
future::Either::A(f)
} else {
let out = Some((packet_id, data.freeze()));
let f = future::ok((out, (sink, stream, false)));
future::Either::A(f)
}
},
Some(Message::SendReq(data)) => {
let fut = sink.send(data)
.map(move |sink| (None, (sink, stream, false)));
future::Either::B(fut)
},
Some(Message::Finished) | None => {
let f = future::ok((None, (sink, stream, true)));
future::Either::A(f)
},
}
))
}).filter_map(|v| v);
let out = RegisteredProtocolOutput {
custom_data: self.custom_data,
future::ok(RegisteredProtocolSubstream {
is_closing: false,
send_queue: VecDeque::new(),
requires_poll_complete: false,
inner: framed.fuse(),
packet_count,
protocol_id: self.id,
endpoint,
protocol_version: protocol_version,
outgoing: msg_tx,
incoming: Box::new(incoming),
};
future::ok(out)
protocol_version,
to_notify: None,
})
}
}
// Connection upgrade for all the protocols contained in it.
#[derive(Clone)]
pub struct RegisteredProtocols<T>(pub Vec<RegisteredProtocol<T>>);
pub struct RegisteredProtocols<TUserData>(pub Vec<RegisteredProtocol<TUserData>>);
impl<T> RegisteredProtocols<T> {
impl<TUserData> RegisteredProtocols<TUserData> {
/// Returns the number of protocols.
#[inline]
pub fn len(&self) -> usize {
@@ -239,7 +298,7 @@ impl<T> RegisteredProtocols<T> {
/// Finds a protocol in the list by its id.
pub fn find_protocol(&self, protocol: ProtocolId)
-> Option<&RegisteredProtocol<T>> {
-> Option<&RegisteredProtocol<TUserData>> {
self.0.iter().find(|p| p.id == protocol)
}
@@ -249,35 +308,36 @@ impl<T> RegisteredProtocols<T> {
}
}
impl<T> Default for RegisteredProtocols<T> {
impl<TUserData> Default for RegisteredProtocols<TUserData> {
fn default() -> Self {
RegisteredProtocols(Vec::new())
}
}
impl<T, C> ConnectionUpgrade<C> for RegisteredProtocols<T>
where C: AsyncRead + AsyncWrite + Send + 'static, // TODO: 'static :-/
impl<TSubstream, TUserData> ConnectionUpgrade<TSubstream> for RegisteredProtocols<TUserData>
where TSubstream: AsyncRead + AsyncWrite,
TUserData: Clone,
{
type NamesIter = VecIntoIter<(Bytes, Self::UpgradeIdentifier)>;
type UpgradeIdentifier = (usize,
<RegisteredProtocol<T> as ConnectionUpgrade<C>>::UpgradeIdentifier);
<RegisteredProtocol<TUserData> as ConnectionUpgrade<TSubstream>>::UpgradeIdentifier);
fn protocol_names(&self) -> Self::NamesIter {
// We concat the lists of `RegisteredProtocol::protocol_names` for
// each protocol.
self.0.iter().enumerate().flat_map(|(n, proto)|
ConnectionUpgrade::<C>::protocol_names(proto)
ConnectionUpgrade::<TSubstream>::protocol_names(proto)
.map(move |(name, id)| (name, (n, id)))
).collect::<Vec<_>>().into_iter()
}
type Output = <RegisteredProtocol<T> as ConnectionUpgrade<C>>::Output;
type Future = <RegisteredProtocol<T> as ConnectionUpgrade<C>>::Future;
type Output = <RegisteredProtocol<TUserData> as ConnectionUpgrade<TSubstream>>::Output;
type Future = <RegisteredProtocol<TUserData> as ConnectionUpgrade<TSubstream>>::Future;
#[inline]
fn upgrade(
self,
socket: C,
socket: TSubstream,
upgrade_identifier: Self::UpgradeIdentifier,
endpoint: Endpoint,
remote_addr: &Multiaddr
@@ -15,7 +15,7 @@
// along with Substrate. If not, see <http://www.gnu.org/licenses/>.
use bytes::Bytes;
use custom_proto::{RegisteredProtocols, RegisteredProtocolOutput};
use custom_proto::{Packet, RegisteredProtocols, RegisteredProtocolSubstream};
use futures::{prelude::*, task};
use libp2p::core::{ConnectionUpgrade, Endpoint, PeerId, PublicKey, upgrade};
use libp2p::core::nodes::handled_node::{NodeHandler, NodeHandlerEndpoint, NodeHandlerEvent};
@@ -52,7 +52,7 @@ pub struct SubstrateNodeHandler<TSubstream, TUserData> {
/// List of registered custom protocols.
registered_custom: Arc<RegisteredProtocols<TUserData>>,
/// Substreams open for "custom" protocols (eg. dot).
custom_protocols_substreams: Vec<RegisteredProtocolOutput<TUserData>>,
custom_protocols_substreams: Vec<RegisteredProtocolSubstream<TSubstream>>,
/// Address of the node.
address: Multiaddr,
@@ -81,10 +81,10 @@ pub struct SubstrateNodeHandler<TSubstream, TUserData> {
next_identify: Interval,
/// Substreams being upgraded on the listening side.
upgrades_in_progress_listen: Vec<Box<Future<Item = FinalUpgrade<TSubstream, TUserData>, Error = IoError> + Send>>,
upgrades_in_progress_listen: Vec<Box<Future<Item = FinalUpgrade<TSubstream>, Error = IoError> + Send>>,
/// Substreams being upgraded on the dialing side. Contrary to `upgrades_in_progress_listen`,
/// these have a known purpose.
upgrades_in_progress_dial: Vec<(UpgradePurpose, Box<Future<Item = FinalUpgrade<TSubstream, TUserData>, Error = IoError> + Send>)>,
upgrades_in_progress_dial: Vec<(UpgradePurpose, Box<Future<Item = FinalUpgrade<TSubstream>, Error = IoError> + Send>)>,
/// The substreams we want to open.
queued_dial_upgrades: Vec<UpgradePurpose>,
/// Number of outbound substreams the outside should open for us.
@@ -396,6 +396,11 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static,
fn shutdown(&mut self) {
// TODO: close gracefully
self.is_shutting_down = true;
for custom_proto in &mut self.custom_protocols_substreams {
custom_proto.shutdown();
}
if let Some(to_notify) = self.to_notify.take() {
to_notify.notify();
}
@@ -403,6 +408,7 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static,
fn poll(&mut self) -> Poll<Option<NodeHandlerEvent<Self::OutboundOpenInfo, Self::OutEvent>>, IoError> {
if self.is_shutting_down {
// TODO: finish only when everything is closed
return Ok(Async::Ready(None));
}
@@ -456,7 +462,7 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static,
) {
debug_assert!(self.registered_custom.has_protocol(protocol),
"invalid protocol id requested in the API of the libp2p networking");
let proto = match self.custom_protocols_substreams.iter().find(|p| p.protocol_id == protocol) {
let proto = match self.custom_protocols_substreams.iter_mut().find(|p| p.protocol_id() == protocol) {
Some(proto) => proto,
None => {
// We are processing a message event before we could report to the outside that
@@ -465,13 +471,7 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static,
},
};
let mut message = Bytes::with_capacity(1 + data.len());
message.extend_from_slice(&[packet_id]);
message.extend_from_slice(&data);
if let Err(_) = proto.outgoing.unbounded_send(message) {
error!(target: "sub-libp2p", "Error while sending custom message to channel");
}
proto.send_message(Packet { id: packet_id, data: data.into() });
}
/// The node will try to open a Kademlia substream and produce a `KadOpen` event containing the
@@ -521,7 +521,7 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static,
/// Optionally produces an event to dispatch.
fn inject_fully_negotiated(
&mut self,
upgrade: FinalUpgrade<TSubstream, TUserData>
upgrade: FinalUpgrade<TSubstream>
) -> Option<SubstrateOutEvent<TSubstream>> {
match upgrade {
FinalUpgrade::IdentifyListener(sender) =>
@@ -561,15 +561,15 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static,
}
},
FinalUpgrade::Custom(proto) => {
self.cancel_dial_upgrade(&UpgradePurpose::Custom(proto.protocol_id));
if self.custom_protocols_substreams.iter().any(|p| p.protocol_id == proto.protocol_id) {
self.cancel_dial_upgrade(&UpgradePurpose::Custom(proto.protocol_id()));
if self.custom_protocols_substreams.iter().any(|p| p.protocol_id() == proto.protocol_id()) {
// Skipping protocol that's already open.
return None;
}
let event = SubstrateOutEvent::CustomProtocolOpen {
protocol_id: proto.protocol_id,
version: proto.protocol_version,
protocol_id: proto.protocol_id(),
version: proto.protocol_version(),
};
self.custom_protocols_substreams.push(proto);
@@ -686,32 +686,32 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static,
// Poll for messages on the custom protocol stream.
for n in (0 .. self.custom_protocols_substreams.len()).rev() {
let mut custom_proto = self.custom_protocols_substreams.swap_remove(n);
match custom_proto.incoming.poll() {
match custom_proto.poll() {
Ok(Async::NotReady) => self.custom_protocols_substreams.push(custom_proto),
Ok(Async::Ready(Some((packet_id, data)))) => {
let protocol_id = custom_proto.protocol_id;
Ok(Async::Ready(Some(Packet { id, data }))) => {
let protocol_id = custom_proto.protocol_id();
self.custom_protocols_substreams.push(custom_proto);
return Ok(Async::Ready(Some(SubstrateOutEvent::CustomMessage {
protocol_id,
packet_id,
packet_id: id,
data,
})));
},
Ok(Async::Ready(None)) => {
// Trying to reopen the protocol.
self.queued_dial_upgrades.push(UpgradePurpose::Custom(custom_proto.protocol_id));
self.queued_dial_upgrades.push(UpgradePurpose::Custom(custom_proto.protocol_id()));
self.num_out_user_must_open += 1;
return Ok(Async::Ready(Some(SubstrateOutEvent::CustomProtocolClosed {
protocol_id: custom_proto.protocol_id,
protocol_id: custom_proto.protocol_id(),
result: Ok(()),
})))
},
Err(err) => {
// Trying to reopen the protocol.
self.queued_dial_upgrades.push(UpgradePurpose::Custom(custom_proto.protocol_id));
self.queued_dial_upgrades.push(UpgradePurpose::Custom(custom_proto.protocol_id()));
self.num_out_user_must_open += 1;
return Ok(Async::Ready(Some(SubstrateOutEvent::CustomProtocolClosed {
protocol_id: custom_proto.protocol_id,
protocol_id: custom_proto.protocol_id(),
result: Err(err),
})))
},
@@ -851,16 +851,16 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static,
}
/// Enum of all the possible protocols our service handles.
enum FinalUpgrade<TSubstream, TUserData> {
enum FinalUpgrade<TSubstream> {
Kad(KadConnecController, Box<Stream<Item = KadIncomingRequest, Error = IoError> + Send>),
IdentifyListener(identify::IdentifySender<TSubstream>),
IdentifyDialer(identify::IdentifyInfo, Multiaddr),
PingDialer(ping::PingDialer<TSubstream, Instant>),
PingListener(ping::PingListener<TSubstream>),
Custom(RegisteredProtocolOutput<TUserData>),
Custom(RegisteredProtocolSubstream<TSubstream>),
}
impl<TSubstream, TUserData> From<ping::PingOutput<TSubstream, Instant>> for FinalUpgrade<TSubstream, TUserData> {
impl<TSubstream> From<ping::PingOutput<TSubstream, Instant>> for FinalUpgrade<TSubstream> {
fn from(out: ping::PingOutput<TSubstream, Instant>) -> Self {
match out {
ping::PingOutput::Ponger(ponger) => FinalUpgrade::PingListener(ponger),
@@ -869,7 +869,7 @@ impl<TSubstream, TUserData> From<ping::PingOutput<TSubstream, Instant>> for Fina
}
}
impl<TSubstream, TUserData> From<identify::IdentifyOutput<TSubstream>> for FinalUpgrade<TSubstream, TUserData> {
impl<TSubstream> From<identify::IdentifyOutput<TSubstream>> for FinalUpgrade<TSubstream> {
fn from(out: identify::IdentifyOutput<TSubstream>) -> Self {
match out {
identify::IdentifyOutput::RemoteInfo { info, observed_addr } =>