diff --git a/polkadot/node/network/bridge/src/validator_discovery.rs b/polkadot/node/network/bridge/src/validator_discovery.rs index 762dc7d9ea..59626d54e0 100644 --- a/polkadot/node/network/bridge/src/validator_discovery.rs +++ b/polkadot/node/network/bridge/src/validator_discovery.rs @@ -23,13 +23,12 @@ use std::sync::Arc; use async_trait::async_trait; use futures::channel::mpsc; -use strum::IntoEnumIterator as _; use sc_network::multiaddr::{Multiaddr, Protocol}; use sc_authority_discovery::Service as AuthorityDiscoveryService; use polkadot_node_network_protocol::PeerId; use polkadot_primitives::v1::{AuthorityDiscoveryId, Block, Hash}; -use polkadot_node_network_protocol::peer_set::PeerSet; +use polkadot_node_network_protocol::peer_set::{PeerSet, PerPeerSet}; const LOG_TARGET: &str = "validator_discovery"; @@ -140,9 +139,9 @@ fn peer_id_from_multiaddr(addr: &Multiaddr) -> Option { }) } + pub(super) struct Service { - // indexed by PeerSet as usize - state: Vec, + state: PerPeerSet, // PhantomData used to make the struct generic instead of having generic methods _phantom: PhantomData<(N, AD)>, } @@ -162,15 +161,11 @@ struct StatePerPeerSet { impl Service { pub fn new() -> Self { Self { - state: PeerSet::iter().map(|_| Default::default()).collect(), + state: PerPeerSet::default(), _phantom: PhantomData, } } - fn state_mut(&mut self, peer_set: PeerSet) -> &mut StatePerPeerSet { - &mut self.state[peer_set as usize] - } - /// Find connected validators using the given `validator_ids`. /// /// Returns a [`HashMap`] that contains the found [`AuthorityDiscoveryId`]'s and their associated [`PeerId`]'s. @@ -182,7 +177,7 @@ impl Service { authority_discovery_service: &mut AD, ) -> HashMap { let mut result = HashMap::new(); - let state = self.state_mut(peer_set); + let state = &mut self.state[peer_set]; for id in validator_ids { // First check if we already cached the validator @@ -238,7 +233,7 @@ impl Service { &mut authority_discovery_service, ).await; - let state = self.state_mut(peer_set); + let state = &mut self.state[peer_set]; // Increment the counter of how many times the validators were requested. validator_ids.iter().for_each(|id| *state.requested_validators.entry(id.clone()).or_default() += 1); @@ -336,7 +331,7 @@ impl Service { peer_set: PeerSet, authority_discovery_service: &mut AD, ) { - let state = self.state_mut(peer_set); + let state = &mut self.state[peer_set]; // check if it's an authority we've been waiting for let maybe_authority = authority_discovery_service.get_authority_id_by_peer_id(peer_id.clone()).await; if let Some(authority) = maybe_authority { @@ -352,8 +347,7 @@ impl Service { /// Should be called when a peer disconnected. pub fn on_peer_disconnected(&mut self, peer_id: &PeerId, peer_set: PeerSet) { - let state = self.state_mut(peer_set); - state.connected_peers.remove(peer_id); + self.state[peer_set].connected_peers.remove(peer_id); } } @@ -569,7 +563,7 @@ mod tests { let reply = receiver.next().await.unwrap(); assert_eq!(reply.0, authority_ids[1]); assert_eq!(reply.1, peer_ids[1]); - let state = service.state_mut(PeerSet::Validation); + let state = &service.state[PeerSet::Validation]; assert_eq!(state.non_revoked_discovery_requests.len(), 1); }); } @@ -613,7 +607,7 @@ mod tests { ).await; let _ = receiver.next().await.unwrap(); - let state = service.state_mut(PeerSet::Validation); + let state = &service.state[PeerSet::Validation]; assert_eq!(state.non_revoked_discovery_requests.len(), 1); assert_eq!(ns.peers_set.len(), 2); @@ -631,7 +625,7 @@ mod tests { ).await; let _ = receiver.next().await.unwrap(); - let state = service.state_mut(PeerSet::Validation); + let state = &service.state[PeerSet::Validation]; assert_eq!(state.non_revoked_discovery_requests.len(), 1); assert_eq!(ns.peers_set.len(), 1); }); @@ -667,7 +661,7 @@ mod tests { ).await; assert_eq!((validator_id.clone(), validator_peer_id.clone()), receiver.next().await.unwrap()); - let state = service.state_mut(PeerSet::Validation); + let state = &service.state[PeerSet::Validation]; assert!( state.connected_peers .get(&validator_peer_id) diff --git a/polkadot/node/network/protocol/src/peer_set.rs b/polkadot/node/network/protocol/src/peer_set.rs index 01d14067b9..c8e3b68e36 100644 --- a/polkadot/node/network/protocol/src/peer_set.rs +++ b/polkadot/node/network/protocol/src/peer_set.rs @@ -17,18 +17,17 @@ //! All peersets and protocols used for parachains. use sc_network::config::{NonDefaultSetConfig, SetConfig}; -use std::borrow::Cow; +use std::{borrow::Cow, ops::{Index, IndexMut}}; use strum::{EnumIter, IntoEnumIterator}; /// The peer-sets and thus the protocols which are used for the network. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumIter)] -#[repr(usize)] pub enum PeerSet { /// The validation peer-set is responsible for all messages related to candidate validation and /// communication among validators. - Validation = 0, + Validation, /// The collation peer-set is used for validator<>collator communication. - Collation = 1, + Collation, } impl PeerSet { @@ -89,6 +88,32 @@ impl PeerSet { } } +/// A small and nifty collection that allows to store data pertaining to each peer set. +#[derive(Debug, Default)] +pub struct PerPeerSet { + validation: T, + collation: T, +} + +impl Index for PerPeerSet { + type Output = T; + fn index(&self, index: PeerSet) -> &T { + match index { + PeerSet::Validation => &self.validation, + PeerSet::Collation => &self.collation, + } + } +} + +impl IndexMut for PerPeerSet { + fn index_mut(&mut self, index: PeerSet) -> &mut T { + match index { + PeerSet::Validation => &mut self.validation, + PeerSet::Collation => &mut self.collation, + } + } +} + /// Get `NonDefaultSetConfig`s for all available peer sets. /// /// Should be used during network configuration (added to [`NetworkConfiguration::extra_sets`])