diff --git a/substrate/core/network-libp2p/src/custom_proto.rs b/substrate/core/network-libp2p/src/custom_proto.rs
index 9aa3d03e9e..6d0a5d5b1e 100644
--- a/substrate/core/network-libp2p/src/custom_proto.rs
+++ b/substrate/core/network-libp2p/src/custom_proto.rs
@@ -15,22 +15,22 @@
// along with Substrate. If not, see .
use bytes::{Bytes, BytesMut};
-use ProtocolId;
use libp2p::core::{Multiaddr, ConnectionUpgrade, Endpoint};
-use PacketId;
+use libp2p::tokio_codec::Framed;
+use std::collections::VecDeque;
use std::io::Error as IoError;
use std::vec::IntoIter as VecIntoIter;
-use futures::{future, Future, stream, Stream, Sink};
-use futures::sync::mpsc;
+use futures::{prelude::*, future, stream, task};
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec::UviBytes;
+use ProtocolId;
/// Connection upgrade for a single protocol.
///
/// Note that "a single protocol" here refers to `par` for example. However
/// each protocol can have multiple different versions for networking purposes.
#[derive(Clone)]
-pub struct RegisteredProtocol {
+pub struct RegisteredProtocol {
/// Id of the protocol for API purposes.
id: ProtocolId,
/// Base name of the protocol as advertised on the network.
@@ -41,67 +41,202 @@ pub struct RegisteredProtocol {
/// The packet count is used to filter out invalid messages.
supported_versions: Vec<(u8, u8)>,
/// Custom data.
- custom_data: T,
+ custom_data: TUserData,
}
-/// Output of a `RegisteredProtocol` upgrade.
-pub struct RegisteredProtocolOutput {
- /// Data passed to `RegisteredProtocol::new`.
- pub custom_data: T,
-
- /// Id of the protocol.
- pub protocol_id: ProtocolId,
-
- /// Endpoint of the connection.
- pub endpoint: Endpoint,
-
- /// Version of the protocol that was negotiated.
- pub protocol_version: u8,
-
- /// Channel to sender outgoing messages to.
- // TODO: consider assembling packet_id here
- pub outgoing: mpsc::UnboundedSender,
-
- /// Stream where incoming messages are received. The stream ends whenever
- /// either side is closed.
- pub incoming: Box + Send>,
-}
-
-impl RegisteredProtocol {
+impl RegisteredProtocol {
/// Creates a new `RegisteredProtocol`. The `custom_data` parameter will be
/// passed inside the `RegisteredProtocolOutput`.
- pub fn new(custom_data: T, protocol: ProtocolId, versions: &[(u8, u8)])
+ pub fn new(custom_data: TUserData, protocol: ProtocolId, versions: &[(u8, u8)])
-> Self {
- let mut proto_name = Bytes::from_static(b"/substrate/");
- proto_name.extend_from_slice(&protocol);
- proto_name.extend_from_slice(b"/");
+ let mut base_name = Bytes::from_static(b"/substrate/");
+ base_name.extend_from_slice(&protocol);
+ base_name.extend_from_slice(b"/");
RegisteredProtocol {
- base_name: proto_name,
+ base_name,
id: protocol,
supported_versions: {
let mut tmp: Vec<_> = versions.iter().rev().cloned().collect();
tmp.sort_unstable_by(|a, b| b.1.cmp(&a.1));
tmp
},
- custom_data: custom_data,
+ custom_data,
}
}
/// Returns the ID of the protocol.
+ #[inline]
pub fn id(&self) -> ProtocolId {
self.id
}
/// Returns the custom data that was passed to `new`.
- pub fn custom_data(&self) -> &T {
+ #[inline]
+ pub fn custom_data(&self) -> &TUserData {
&self.custom_data
}
}
-// `Maf` is short for `MultiaddressFuture`
-impl ConnectionUpgrade for RegisteredProtocol
-where C: AsyncRead + AsyncWrite + Send + 'static, // TODO: 'static :-/
+/// Output of a `RegisteredProtocol` upgrade.
+pub struct RegisteredProtocolSubstream {
+ /// If true, we are in the process of closing the sink.
+ is_closing: bool,
+ /// Buffer of packets to send.
+ send_queue: VecDeque,
+ /// If true, we should call `poll_complete` on the inner sink.
+ requires_poll_complete: bool,
+ /// The underlying substream.
+ inner: stream::Fuse>>,
+ /// Maximum packet id.
+ packet_count: u8,
+ /// Id of the protocol.
+ protocol_id: ProtocolId,
+ /// Version of the protocol that was negotiated.
+ protocol_version: u8,
+ /// Task to notify when something is changed and we need to be polled.
+ to_notify: Option,
+}
+
+/// Packet of data that can be sent or received.
+#[derive(Debug, Clone)]
+pub struct Packet {
+ /// Identifier of the packet.
+ pub id: u8,
+ /// The raw data.
+ pub data: Bytes,
+}
+
+impl RegisteredProtocolSubstream {
+ /// Returns the protocol id.
+ #[inline]
+ pub fn protocol_id(&self) -> ProtocolId {
+ self.protocol_id
+ }
+
+ /// Returns the version of the protocol that was negotiated.
+ #[inline]
+ pub fn protocol_version(&self) -> u8 {
+ self.protocol_version
+ }
+
+ /// Starts a graceful shutdown process on this substream.
+ ///
+ /// Note that "graceful" means that we sent a closing message. We don't wait for any
+ /// confirmation from the remote.
+ ///
+ /// After calling this, the stream is guaranteed to finish soon-ish.
+ pub fn shutdown(&mut self) {
+ self.is_closing = true;
+ if let Some(task) = self.to_notify.take() {
+ task.notify();
+ }
+ }
+
+ /// Sends a message to the substream.
+ pub fn send_message(&mut self, Packet { id: packet_id, data }: Packet) {
+ if packet_id >= self.packet_count {
+ error!(target: "sub-libp2p", "Tried to send a packet with an invalid ID {}", packet_id);
+ return;
+ }
+
+ let mut message = Bytes::with_capacity(1 + data.len());
+ message.extend_from_slice(&[packet_id]);
+ message.extend_from_slice(&data);
+ self.send_queue.push_back(message);
+
+ // If the length of the queue goes over a certain arbitrary threshold, we print a warning.
+ // TODO: figure out a good threshold
+ if self.send_queue.len() >= 2048 {
+ warn!(target: "sub-libp2p", "Queue of packets to send over substream is pretty \
+ large: {}", self.send_queue.len());
+ }
+
+ if let Some(task) = self.to_notify.take() {
+ task.notify();
+ }
+ }
+
+ /// Turns raw data into a packet and checks whether it is valid.
+ fn data_to_packet(&self, mut data: BytesMut) -> Result {
+ // The `data` should be prefixed by the packet ID, therefore an empty packet is invalid.
+ if data.is_empty() {
+ debug!(target: "sub-libp2p", "ignoring incoming packet because it was empty");
+ return Err(());
+ }
+
+ let packet = {
+ let id = data[0];
+ let data = data.split_off(1);
+ Packet { id, data: data.freeze() }
+ };
+
+ if packet.id >= self.packet_count {
+ debug!(target: "sub-libp2p", "ignoring incoming packet because packet_id {} is \
+ too large", packet.id);
+ return Err(())
+ }
+
+ Ok(packet)
+ }
+}
+
+impl Stream for RegisteredProtocolSubstream
+where TSubstream: AsyncRead + AsyncWrite,
+{
+ type Item = Packet;
+ type Error = IoError;
+
+ fn poll(&mut self) -> Poll