diff --git a/polkadot/node/subsystem-util/src/validator_discovery.rs b/polkadot/node/subsystem-util/src/validator_discovery.rs index ac5bf1d470..fb0cf1539b 100644 --- a/polkadot/node/subsystem-util/src/validator_discovery.rs +++ b/polkadot/node/subsystem-util/src/validator_discovery.rs @@ -24,6 +24,7 @@ use futures::{ task::{Poll, self}, stream, }; +use streamunordered::{StreamUnordered, StreamYield}; use thiserror::Error; use polkadot_node_subsystem::{ @@ -104,6 +105,73 @@ async fn connect_to_authorities( Ok((connected_rx, revoke_tx)) } +/// A struct that assists performing multiple concurrent connection requests. +/// +/// This allows concurrent connections to validator sets at different `relay_parents` +/// and multiplexes their results into a single `Stream`. +#[derive(Default)] +pub struct ConnectionRequests { + // added connection requests relay_parent -> StreamUnordered token + id_map: HashMap, + + // Connection requests themselves. + requests: StreamUnordered, +} + +impl ConnectionRequests { + /// Insert a new connection request. + /// + /// If a `ConnectionRequest` under a given `relay_parent` already exists it will + /// be revoked and substituted with a new one. + pub fn put(&mut self, relay_parent: Hash, request: ConnectionRequest) { + self.remove(&relay_parent); + let token = self.requests.push(request); + + self.id_map.insert(relay_parent, token); + } + + /// Remove a connection request by a given `relay_parent`. + pub fn remove(&mut self, relay_parent: &Hash) { + if let Some(token) = self.id_map.remove(relay_parent) { + Pin::new(&mut self.requests).remove(token); + } + } +} + +impl stream::Stream for ConnectionRequests { + /// (relay_parent, validator_id, peer_id). + type Item = (Hash, ValidatorId, PeerId); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + // If there are currently no requests going on, pend instead of + // polling `StreamUnordered` which would lead to it terminating + // and returning `Poll::Ready(None)`. + if self.requests.is_empty() { + return Poll::Pending; + } + + match Pin::new(&mut self.requests).poll_next(cx) { + Poll::Ready(Some((yielded, token))) => { + match yielded { + StreamYield::Item(item) => { + if let Some((relay_parent, _)) = self.id_map.iter() + .find(|(_, &val)| val == token) + { + return Poll::Ready(Some((*relay_parent, item.0, item.1))); + } + } + StreamYield::Finished(_) => { + // `ConnectionRequest` is fullfilled, but not revoked + } + } + }, + _ => {}, + } + + Poll::Pending + } +} + /// A pending connection request to validators. /// This struct implements `Stream` to allow for asynchronous /// discovery of validator addresses. @@ -156,3 +224,188 @@ impl ConnectionRequest { } } } + +#[cfg(test)] +mod tests { + use super::*; + use polkadot_primitives::v1::ValidatorPair; + use sp_core::{Pair, Public}; + + use futures::{executor, poll, channel::{mpsc, oneshot}, StreamExt, SinkExt}; + + #[test] + fn adding_a_connection_request_works() { + let mut connection_requests = ConnectionRequests::default(); + + executor::block_on(async move { + assert_eq!(poll!(Pin::new(&mut connection_requests).next()), Poll::Pending); + + let validator_1 = ValidatorPair::generate().0.public(); + let validator_2 = ValidatorPair::generate().0.public(); + + let auth_1 = AuthorityDiscoveryId::from_slice(&[1; 32]); + let auth_2 = AuthorityDiscoveryId::from_slice(&[2; 32]); + + let mut validator_map = HashMap::new(); + validator_map.insert(auth_1.clone(), validator_1.clone()); + validator_map.insert(auth_2.clone(), validator_2.clone()); + + let (mut rq1_tx, rq1_rx) = mpsc::channel(8); + let (revoke_1_tx, _revoke_1_rx) = oneshot::channel(); + + let peer_id_1 = PeerId::random(); + let peer_id_2 = PeerId::random(); + + let connection_request_1 = ConnectionRequest { + validator_map, + connections: rq1_rx, + revoke: revoke_1_tx, + }; + + let relay_parent_1 = Hash::repeat_byte(1); + + connection_requests.put(relay_parent_1.clone(), connection_request_1); + + rq1_tx.send((auth_1, peer_id_1.clone())).await.unwrap(); + rq1_tx.send((auth_2, peer_id_2.clone())).await.unwrap(); + + let res = Pin::new(&mut connection_requests).next().await.unwrap(); + assert_eq!(res, (relay_parent_1, validator_1, peer_id_1)); + + let res = Pin::new(&mut connection_requests).next().await.unwrap(); + assert_eq!(res, (relay_parent_1, validator_2, peer_id_2)); + + assert_eq!( + poll!(Pin::new(&mut connection_requests).next()), + Poll::Pending, + ); + }); + } + + #[test] + fn adding_two_connection_requests_works() { + let mut connection_requests = ConnectionRequests::default(); + + executor::block_on(async move { + assert_eq!(poll!(Pin::new(&mut connection_requests).next()), Poll::Pending); + + let validator_1 = ValidatorPair::generate().0.public(); + let validator_2 = ValidatorPair::generate().0.public(); + + let auth_1 = AuthorityDiscoveryId::from_slice(&[1; 32]); + let auth_2 = AuthorityDiscoveryId::from_slice(&[2; 32]); + + let mut validator_map_1 = HashMap::new(); + let mut validator_map_2 = HashMap::new(); + + validator_map_1.insert(auth_1.clone(), validator_1.clone()); + validator_map_2.insert(auth_2.clone(), validator_2.clone()); + + let (mut rq1_tx, rq1_rx) = mpsc::channel(8); + let (revoke_1_tx, _revoke_1_rx) = oneshot::channel(); + + let (mut rq2_tx, rq2_rx) = mpsc::channel(8); + let (revoke_2_tx, _revoke_2_rx) = oneshot::channel(); + + let peer_id_1 = PeerId::random(); + let peer_id_2 = PeerId::random(); + + let connection_request_1 = ConnectionRequest { + validator_map: validator_map_1, + connections: rq1_rx, + revoke: revoke_1_tx, + }; + + let connection_request_2 = ConnectionRequest { + validator_map: validator_map_2, + connections: rq2_rx, + revoke: revoke_2_tx, + }; + + let relay_parent_1 = Hash::repeat_byte(1); + let relay_parent_2 = Hash::repeat_byte(2); + + connection_requests.put(relay_parent_1.clone(), connection_request_1); + connection_requests.put(relay_parent_2.clone(), connection_request_2); + + rq1_tx.send((auth_1, peer_id_1.clone())).await.unwrap(); + rq2_tx.send((auth_2, peer_id_2.clone())).await.unwrap(); + + let res = Pin::new(&mut connection_requests).next().await.unwrap(); + assert_eq!(res, (relay_parent_1, validator_1, peer_id_1)); + + let res = Pin::new(&mut connection_requests).next().await.unwrap(); + assert_eq!(res, (relay_parent_2, validator_2, peer_id_2)); + + assert_eq!( + poll!(Pin::new(&mut connection_requests).next()), + Poll::Pending, + ); + }); + } + + #[test] + fn replacing_a_connection_request_works() { + let mut connection_requests = ConnectionRequests::default(); + + executor::block_on(async move { + assert_eq!(poll!(Pin::new(&mut connection_requests).next()), Poll::Pending); + + let validator_1 = ValidatorPair::generate().0.public(); + let validator_2 = ValidatorPair::generate().0.public(); + + let auth_1 = AuthorityDiscoveryId::from_slice(&[1; 32]); + let auth_2 = AuthorityDiscoveryId::from_slice(&[2; 32]); + + let mut validator_map_1 = HashMap::new(); + let mut validator_map_2 = HashMap::new(); + + validator_map_1.insert(auth_1.clone(), validator_1.clone()); + validator_map_2.insert(auth_2.clone(), validator_2.clone()); + + let (mut rq1_tx, rq1_rx) = mpsc::channel(8); + let (revoke_1_tx, _revoke_1_rx) = oneshot::channel(); + + let (mut rq2_tx, rq2_rx) = mpsc::channel(8); + let (revoke_2_tx, _revoke_2_rx) = oneshot::channel(); + + let peer_id_1 = PeerId::random(); + let peer_id_2 = PeerId::random(); + + let connection_request_1 = ConnectionRequest { + validator_map: validator_map_1, + connections: rq1_rx, + revoke: revoke_1_tx, + }; + + let connection_request_2 = ConnectionRequest { + validator_map: validator_map_2, + connections: rq2_rx, + revoke: revoke_2_tx, + }; + + let relay_parent = Hash::repeat_byte(3); + + connection_requests.put(relay_parent.clone(), connection_request_1); + + rq1_tx.send((auth_1.clone(), peer_id_1.clone())).await.unwrap(); + + let res = Pin::new(&mut connection_requests).next().await.unwrap(); + assert_eq!(res, (relay_parent, validator_1, peer_id_1.clone())); + + connection_requests.put(relay_parent.clone(), connection_request_2); + + assert!(rq1_tx.send((auth_1, peer_id_1.clone())).await.is_err()); + + rq2_tx.send((auth_2, peer_id_2.clone())).await.unwrap(); + + let res = Pin::new(&mut connection_requests).next().await.unwrap(); + assert_eq!(res, (relay_parent, validator_2, peer_id_2)); + + assert_eq!( + poll!(Pin::new(&mut connection_requests).next()), + Poll::Pending, + ); + }); + } +}