Send Status message on all newly-opened legacy substreams (#6593)

* Send Status message on all newly-opened legacy substreams

* Fix tests
This commit is contained in:
Pierre Krieger
2020-07-08 13:44:51 +02:00
committed by GitHub
parent 6eb2eb81c5
commit 8ca05879e8
5 changed files with 119 additions and 114 deletions
@@ -21,7 +21,8 @@ use bytes::BytesMut;
use futures::prelude::*;
use futures_codec::Framed;
use libp2p::core::{Endpoint, UpgradeInfo, InboundUpgrade, OutboundUpgrade, upgrade::ProtocolName};
use std::{collections::VecDeque, io, pin::Pin, vec::IntoIter as VecIntoIter};
use parking_lot::RwLock;
use std::{collections::VecDeque, io, pin::Pin, sync::Arc, vec::IntoIter as VecIntoIter};
use std::task::{Context, Poll};
use unsigned_varint::codec::UviBytes;
@@ -38,12 +39,13 @@ pub struct RegisteredProtocol {
/// List of protocol versions that we support.
/// Ordered in descending order so that the best comes first.
supported_versions: Vec<u8>,
/// Handshake to send after the substream is open.
handshake_message: Arc<RwLock<Vec<u8>>>,
}
impl RegisteredProtocol {
/// Creates a new `RegisteredProtocol`. The `custom_data` parameter will be
/// passed inside the `RegisteredProtocolOutput`.
pub fn new(protocol: impl Into<ProtocolId>, versions: &[u8])
/// Creates a new `RegisteredProtocol`.
pub fn new(protocol: impl Into<ProtocolId>, versions: &[u8], handshake_message: Arc<RwLock<Vec<u8>>>)
-> Self {
let protocol = protocol.into();
let mut base_name = b"/substrate/".to_vec();
@@ -58,8 +60,14 @@ impl RegisteredProtocol {
tmp.sort_unstable_by(|a, b| b.cmp(&a));
tmp
},
handshake_message,
}
}
/// Returns the `Arc` to the handshake message that was passed at initialization.
pub fn handshake_message(&self) -> &Arc<RwLock<Vec<u8>>> {
&self.handshake_message
}
}
impl Clone for RegisteredProtocol {
@@ -68,6 +76,7 @@ impl Clone for RegisteredProtocol {
id: self.id.clone(),
base_name: self.base_name.clone(),
supported_versions: self.supported_versions.clone(),
handshake_message: self.handshake_message.clone(),
}
}
}
@@ -244,10 +253,10 @@ impl ProtocolName for RegisteredProtocolName {
}
impl<TSubstream> InboundUpgrade<TSubstream> for RegisteredProtocol
where TSubstream: AsyncRead + AsyncWrite + Unpin,
where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = RegisteredProtocolSubstream<TSubstream>;
type Future = future::Ready<Result<Self::Output, io::Error>>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, io::Error>> + Send>>;
type Error = io::Error;
fn upgrade_inbound(
@@ -255,26 +264,31 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin,
socket: TSubstream,
info: Self::Info,
) -> Self::Future {
let framed = {
let mut codec = UviBytes::default();
codec.set_max_len(16 * 1024 * 1024); // 16 MiB hard limit for packets.
Framed::new(socket, codec)
};
Box::pin(async move {
let mut framed = {
let mut codec = UviBytes::default();
codec.set_max_len(16 * 1024 * 1024); // 16 MiB hard limit for packets.
Framed::new(socket, codec)
};
future::ok(RegisteredProtocolSubstream {
is_closing: false,
endpoint: Endpoint::Listener,
send_queue: VecDeque::new(),
requires_poll_flush: false,
inner: framed.fuse(),
protocol_version: info.version,
clogged_fuse: false,
let handshake = BytesMut::from(&self.handshake_message.read()[..]);
framed.send(handshake).await?;
Ok(RegisteredProtocolSubstream {
is_closing: false,
endpoint: Endpoint::Listener,
send_queue: VecDeque::new(),
requires_poll_flush: false,
inner: framed.fuse(),
protocol_version: info.version,
clogged_fuse: false,
})
})
}
}
impl<TSubstream> OutboundUpgrade<TSubstream> for RegisteredProtocol
where TSubstream: AsyncRead + AsyncWrite + Unpin,
where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = <Self as InboundUpgrade<TSubstream>>::Output;
type Future = <Self as InboundUpgrade<TSubstream>>::Future;
@@ -285,16 +299,25 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin,
socket: TSubstream,
info: Self::Info,
) -> Self::Future {
let framed = Framed::new(socket, UviBytes::default());
Box::pin(async move {
let mut framed = {
let mut codec = UviBytes::default();
codec.set_max_len(16 * 1024 * 1024); // 16 MiB hard limit for packets.
Framed::new(socket, codec)
};
future::ok(RegisteredProtocolSubstream {
is_closing: false,
endpoint: Endpoint::Dialer,
send_queue: VecDeque::new(),
requires_poll_flush: false,
inner: framed.fuse(),
protocol_version: info.version,
clogged_fuse: false,
let handshake = BytesMut::from(&self.handshake_message.read()[..]);
framed.send(handshake).await?;
Ok(RegisteredProtocolSubstream {
is_closing: false,
endpoint: Endpoint::Dialer,
send_queue: VecDeque::new(),
requires_poll_flush: false,
inner: framed.fuse(),
protocol_version: info.version,
clogged_fuse: false,
})
})
}
}