From cfcde3d3b0cceda92e96c1e83ede25db7a283dd2 Mon Sep 17 00:00:00 2001 From: Robert Klotzner Date: Thu, 4 Feb 2021 12:31:36 +0100 Subject: [PATCH] End multiplexer stream once one of its inputs end. (#2380) * End multiplexer stream once one of its inputs end. Also add test, that we don't panic once a stream is exhausted. * Don't fuse already fused stream. --- polkadot/node/network/bridge/src/lib.rs | 2 +- .../node/network/bridge/src/multiplexer.rs | 54 ++++++++++++++++--- 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/polkadot/node/network/bridge/src/lib.rs b/polkadot/node/network/bridge/src/lib.rs index 030a3c5b8d..4e1c179e00 100644 --- a/polkadot/node/network/bridge/src/lib.rs +++ b/polkadot/node/network/bridge/src/lib.rs @@ -169,7 +169,7 @@ where let action = { let subsystem_next = ctx.recv().fuse(); let mut net_event_next = event_stream.next().fuse(); - let mut req_res_event_next = bridge.request_multiplexer.next().fuse(); + let mut req_res_event_next = bridge.request_multiplexer.next(); futures::pin_mut!(subsystem_next); futures::select! { diff --git a/polkadot/node/network/bridge/src/multiplexer.rs b/polkadot/node/network/bridge/src/multiplexer.rs index e9a312834a..d9475d0ea4 100644 --- a/polkadot/node/network/bridge/src/multiplexer.rs +++ b/polkadot/node/network/bridge/src/multiplexer.rs @@ -17,7 +17,7 @@ use std::pin::Pin; use futures::channel::mpsc; -use futures::stream::Stream; +use futures::stream::{FusedStream, Stream}; use futures::task::{Context, Poll}; use strum::IntoEnumIterator; @@ -35,12 +35,15 @@ use polkadot_subsystem::messages::AllMessages; /// /// This multiplexer consumes all request streams and makes them a `Stream` of a single message /// type, useful for the network bridge to send them via the `Overseer` to other subsystems. +/// +/// The resulting stream will end once any of its input ends. pub struct RequestMultiplexer { receivers: Vec<(Protocol, mpsc::Receiver)>, next_poll: usize, } /// Multiplexing can fail in case of invalid messages. +#[derive(Debug, PartialEq, Eq)] pub struct RequestMultiplexError { /// The peer that sent the invalid message. pub peer: PeerId, @@ -85,15 +88,17 @@ impl Stream for RequestMultiplexer { // % safe, because count initialized to len, loop would not be entered if 0, also // length of receivers is fixed. let (p, rx): &mut (_, _) = &mut self.receivers[i % len]; + // Avoid panic: + if rx.is_terminated() { + // Early return, we don't want to update next_poll. + return Poll::Ready(None); + } i += 1; count -= 1; match Pin::new(rx).poll_next(cx) { - // If at least one stream is pending, then we are not done yet (No - // Ready(None)). Poll::Pending => result = Poll::Pending, - // Receiver is a fused stream, which allows for this simple handling of - // exhausted ones. - Poll::Ready(None) => {} + // We are done, once a single receiver is done. + Poll::Ready(None) => return Poll::Ready(None), Poll::Ready(Some(v)) => { result = Poll::Ready(Some(multiplex_single(*p, v))); break; @@ -105,6 +110,17 @@ impl Stream for RequestMultiplexer { } } +impl FusedStream for RequestMultiplexer { + fn is_terminated(&self) -> bool { + let len = self.receivers.len(); + if len == 0 { + return true; + } + let (_, rx) = &self.receivers[self.next_poll % len]; + rx.is_terminated() + } +} + /// Convert a single raw incoming request into a `MultiplexMessage`. fn multiplex_single( p: Protocol, @@ -130,3 +146,29 @@ fn decode_with_peer( ) -> Result { Req::decode(&mut payload.as_ref()).map_err(|error| RequestMultiplexError { peer, error }) } + +#[cfg(test)] +mod tests { + use futures::prelude::*; + use futures::stream::FusedStream; + + use super::RequestMultiplexer; + #[test] + fn check_exhaustion_safety() { + // Create and end streams: + fn drop_configs() -> RequestMultiplexer { + let (multiplexer, _) = RequestMultiplexer::new(); + multiplexer + } + let multiplexer = drop_configs(); + futures::executor::block_on(async move { + let mut f = multiplexer; + assert!(f.next().await.is_none()); + assert!(f.is_terminated()); + assert!(f.next().await.is_none()); + assert!(f.is_terminated()); + assert!(f.next().await.is_none()); + assert!(f.is_terminated()); + }); + } +}