Give things unique ID types, not aliases, to prevent mixups

This commit is contained in:
James Wilson
2021-06-30 16:59:03 +01:00
parent 06bd660599
commit 4308359feb
14 changed files with 308 additions and 394 deletions
+11 -8
View File
@@ -5,12 +5,15 @@ use futures::{ future, Sink, SinkExt };
use super::inner_loop;
use crate::find_location::find_location;
use crate::state::NodeId;
use common::id_type;
use std::net::Ipv4Addr;
/// A unique Id is assigned per websocket connection (or more accurately,
/// per feed socket and per shard socket). This can be combined with the
/// [`LocalId`] of messages to give us a global ID.
type ConnId = u64;
id_type! {
/// A unique Id is assigned per websocket connection (or more accurately,
/// per feed socket and per shard socket). This can be combined with the
/// [`LocalId`] of messages to give us a global ID.
pub ConnId(u64)
}
#[derive(Clone)]
pub struct Aggregator(Arc<AggregatorInternal>);
@@ -64,13 +67,13 @@ impl Aggregator {
pub fn subscribe_shard(&self) -> impl Sink<inner_loop::FromShardWebsocket, Error = anyhow::Error> + Unpin {
// Assign a unique aggregator-local ID to each connection that subscribes, and pass
// that along with every message to the aggregator loop:
let shard_conn_id: ConnId = self.0.shard_conn_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let shard_conn_id = self.0.shard_conn_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let tx_to_aggregator = self.0.tx_to_aggregator.clone();
// Calling `send` on this Sink requires Unpin. There may be a nicer way than this,
// but pinning by boxing is the easy solution for now:
Box::pin(tx_to_aggregator.with(move |msg| async move {
Ok(inner_loop::ToAggregator::FromShardWebsocket(shard_conn_id, msg))
Ok(inner_loop::ToAggregator::FromShardWebsocket(shard_conn_id.into(), msg))
}))
}
@@ -78,13 +81,13 @@ impl Aggregator {
pub fn subscribe_feed(&self) -> impl Sink<inner_loop::FromFeedWebsocket, Error = anyhow::Error> + Unpin {
// Assign a unique aggregator-local ID to each connection that subscribes, and pass
// that along with every message to the aggregator loop:
let feed_conn_id: ConnId = self.0.feed_conn_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let feed_conn_id = self.0.feed_conn_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let tx_to_aggregator = self.0.tx_to_aggregator.clone();
// Calling `send` on this Sink requires Unpin. There may be a nicer way than this,
// but pinning by boxing is the easy solution for now:
Box::pin(tx_to_aggregator.with(move |msg| async move {
Ok(inner_loop::ToAggregator::FromFeedWebsocket(feed_conn_id, msg))
Ok(inner_loop::ToAggregator::FromFeedWebsocket(feed_conn_id.into(), msg))
}))
}
+36 -41
View File
@@ -1,7 +1,7 @@
use common::{
internal_messages::{
self,
LocalId,
ShardNodeId,
MuteReason
},
types::BlockHash,
@@ -16,11 +16,7 @@ use std::collections::{ HashMap, HashSet };
use crate::state::{ self, State, NodeId };
use crate::feed_message::{ self, FeedMessageSerializer };
use crate::find_location;
/// A unique Id is assigned per websocket connection (or more accurately,
/// per feed socket and per shard socket). This can be combined with the
/// [`LocalId`] of messages to give us a global ID.
type ConnId = u64;
use super::aggregator::ConnId;
/// Incoming messages come via subscriptions, and end up looking like this.
#[derive(Clone,Debug)]
@@ -40,19 +36,19 @@ pub enum FromShardWebsocket {
},
/// Tell the aggregator about a new node.
Add {
local_id: LocalId,
local_id: ShardNodeId,
ip: Option<std::net::IpAddr>,
node: common::types::NodeDetails,
genesis_hash: common::types::BlockHash
},
/// Update/pass through details about a node.
Update {
local_id: LocalId,
local_id: ShardNodeId,
payload: node::Payload
},
/// Tell the aggregator that a node has been removed when it disconnects.
Remove {
local_id: LocalId,
local_id: ShardNodeId,
},
/// The shard is disconnected.
Disconnected
@@ -63,7 +59,7 @@ pub enum FromShardWebsocket {
pub enum ToShardWebsocket {
/// Mute messages to the core by passing the shard-local ID of them.
Mute {
local_id: LocalId,
local_id: ShardNodeId,
reason: internal_messages::MuteReason
}
}
@@ -129,7 +125,7 @@ pub struct InnerLoop {
node_state: State,
/// We maintain a mapping between NodeId and ConnId+LocalId, so that we know
/// which messages are about which nodes.
node_ids: BiMap<NodeId, (ConnId, LocalId)>,
node_ids: BiMap<NodeId, (ConnId, ShardNodeId)>,
/// Keep track of how to send messages out to feeds.
feed_channels: HashMap<ConnId, mpsc::UnboundedSender<ToFeedWebsocket>>,
@@ -194,7 +190,7 @@ impl InnerLoop {
if let Some(loc) = location {
let mut feed_message_serializer = FeedMessageSerializer::new();
feed_message_serializer.push(feed_message::LocatedNode(
node_id,
node_id.get_chain_node_id().into(),
loc.latitude,
loc.longitude,
&loc.city
@@ -212,7 +208,8 @@ impl InnerLoop {
/// Handle messages coming from shards.
async fn handle_from_shard(&mut self, shard_conn_id: ConnId, msg: FromShardWebsocket) {
log::debug!("Message from shard ({}): {:?}", shard_conn_id, msg);
log::debug!("Message from shard ({:?}): {:?}", shard_conn_id, msg);
match msg {
FromShardWebsocket::Initialize { channel } => {
self.shard_channels.insert(shard_conn_id, channel);
@@ -249,7 +246,7 @@ impl InnerLoop {
// Tell chain subscribers about the node we've just added:
let mut feed_messages_for_chain = FeedMessageSerializer::new();
feed_messages_for_chain.push(feed_message::AddedNode(node_id, &details.node));
feed_messages_for_chain.push(feed_message::AddedNode(node_id.get_chain_node_id().into(), &details.node));
self.finalize_and_broadcast_to_chain_feeds(&genesis_hash, feed_messages_for_chain).await;
// Tell everybody about the new node count and potential rename:
@@ -272,7 +269,7 @@ impl InnerLoop {
let node_id = match self.node_ids.remove_by_right(&(shard_conn_id, local_id)) {
Some((node_id, _)) => node_id,
None => {
log::error!("Cannot find ID for node with shard/connectionId of {}/{}", shard_conn_id, local_id);
log::error!("Cannot find ID for node with shard/connectionId of {:?}/{:?}", shard_conn_id, local_id);
return
}
};
@@ -282,11 +279,22 @@ impl InnerLoop {
let node_id = match self.node_ids.get_by_right(&(shard_conn_id, local_id)) {
Some(id) => *id,
None => {
log::error!("Cannot find ID for node with shard/connectionId of {}/{}", shard_conn_id, local_id);
log::error!("Cannot find ID for node with shard/connectionId of {:?}/{:?}", shard_conn_id, local_id);
return
}
};
self.handle_from_shard_update(node_id, payload).await;
let mut feed_message_serializer = FeedMessageSerializer::new();
let broadcast_finality = self.node_state.update_node(node_id, payload, &mut feed_message_serializer);
if let Some(chain) = self.node_state.get_chain_by_node_id(node_id) {
let genesis_hash = *chain.genesis_hash();
if broadcast_finality {
self.finalize_and_broadcast_to_chain_finality_feeds(&genesis_hash, feed_message_serializer).await;
} else {
self.finalize_and_broadcast_to_chain_feeds(&genesis_hash, feed_message_serializer).await;
}
}
},
FromShardWebsocket::Disconnected => {
// Find all nodes associated with this shard connection ID:
@@ -302,24 +310,9 @@ impl InnerLoop {
}
}
async fn handle_from_shard_update(&mut self, node_id: NodeId, payload: node::Payload) {
let mut feed_message_serializer = FeedMessageSerializer::new();
let broadcast_finality = self.node_state.update_node(node_id, payload, &mut feed_message_serializer);
if let Some(chain) = self.node_state.get_chain_by_node_id(node_id) {
let genesis_hash = *chain.genesis_hash();
if broadcast_finality {
self.finalize_and_broadcast_to_chain_finality_feeds(&genesis_hash, feed_message_serializer).await;
} else {
self.finalize_and_broadcast_to_chain_feeds(&genesis_hash, feed_message_serializer).await;
}
}
}
/// Handle messages coming from feeds.
async fn handle_from_feed(&mut self, feed_conn_id: ConnId, msg: FromFeedWebsocket) {
log::debug!("Message from feed ({}): {:?}", feed_conn_id, msg);
log::debug!("Message from feed ({:?}): {:?}", feed_conn_id, msg);
match msg {
FromFeedWebsocket::Initialize { mut channel } => {
self.feed_channels.insert(feed_conn_id, channel.clone());
@@ -396,7 +389,9 @@ impl InnerLoop {
new_chain.finalized_block().height,
new_chain.finalized_block().hash
));
for (idx, (node_id, node)) in new_chain.iter_nodes().enumerate() {
for (idx, (chain_node_id, node)) in new_chain.iter_nodes().enumerate() {
let chain_node_id = chain_node_id.into();
// Send subscription confirmation and chain head before doing all the nodes,
// and continue sending batches of 32 nodes a time over the wire subsequently
if idx % 32 == 0 {
@@ -404,14 +399,14 @@ impl InnerLoop {
let _ = feed_channel.send(ToFeedWebsocket::Bytes(bytes)).await;
}
}
feed_serializer.push(feed_message::AddedNode(node_id, node));
feed_serializer.push(feed_message::AddedNode(chain_node_id, node));
feed_serializer.push(feed_message::FinalizedBlock(
node_id,
chain_node_id,
node.finalized().height,
node.finalized().hash,
));
if node.stale() {
feed_serializer.push(feed_message::StaleNode(node_id));
feed_serializer.push(feed_message::StaleNode(chain_node_id));
}
}
if let Some(bytes) = feed_serializer.into_finalized() {
@@ -480,9 +475,9 @@ impl InnerLoop {
self.node_ids.remove_by_left(&node_id);
let removed_details = match self.node_state.remove_node(node_id) {
Ok(remove_details) => remove_details,
Err(err) => {
log::error!("Error removing node {}: {}", node_id, err);
Some(remove_details) => remove_details,
None => {
log::error!("Could not find node {:?}", node_id);
return
}
};
@@ -504,7 +499,7 @@ impl InnerLoop {
// Assuming the chain hasn't gone away, tell chain subscribers about the node removal
if removed_details.chain_node_count != 0 {
feed_for_chain.push(
feed_message::RemovedNode(node_id)
feed_message::RemovedNode(node_id.get_chain_node_id().into())
);
}
}
+13 -10
View File
@@ -7,10 +7,13 @@ use std::mem;
use crate::state::Node;
use serde_json::to_writer;
use common::types::{
Address, BlockDetails, BlockHash, BlockNumber, NodeHardware, NodeIO, NodeId, NodeStats,
BlockDetails, BlockHash, BlockNumber, NodeHardware, NodeIO, NodeStats,
Timestamp
};
type Address = Box<str>;
type FeedNodeId = usize;
pub trait FeedMessage {
const ACTION: u8;
}
@@ -133,28 +136,28 @@ pub struct BestBlock(pub BlockNumber, pub Timestamp, pub Option<u64>);
#[derive(Serialize)]
pub struct BestFinalized(pub BlockNumber, pub BlockHash);
pub struct AddedNode<'a>(pub NodeId, pub &'a Node);
pub struct AddedNode<'a>(pub FeedNodeId, pub &'a Node);
#[derive(Serialize)]
pub struct RemovedNode(pub NodeId);
pub struct RemovedNode(pub FeedNodeId);
#[derive(Serialize)]
pub struct LocatedNode<'a>(pub NodeId, pub f32, pub f32, pub &'a str);
pub struct LocatedNode<'a>(pub FeedNodeId, pub f32, pub f32, pub &'a str);
#[derive(Serialize)]
pub struct ImportedBlock<'a>(pub NodeId, pub &'a BlockDetails);
pub struct ImportedBlock<'a>(pub FeedNodeId, pub &'a BlockDetails);
#[derive(Serialize)]
pub struct FinalizedBlock(pub NodeId, pub BlockNumber, pub BlockHash);
pub struct FinalizedBlock(pub FeedNodeId, pub BlockNumber, pub BlockHash);
#[derive(Serialize)]
pub struct NodeStatsUpdate<'a>(pub NodeId, pub &'a NodeStats);
pub struct NodeStatsUpdate<'a>(pub FeedNodeId, pub &'a NodeStats);
#[derive(Serialize)]
pub struct NodeIOUpdate<'a>(pub NodeId, pub &'a NodeIO);
pub struct NodeIOUpdate<'a>(pub FeedNodeId, pub &'a NodeIO);
#[derive(Serialize)]
pub struct Hardware<'a>(pub NodeId, pub &'a NodeHardware);
pub struct Hardware<'a>(pub FeedNodeId, pub &'a NodeHardware);
#[derive(Serialize)]
pub struct TimeSync(pub u64);
@@ -203,7 +206,7 @@ pub struct AfgAuthoritySet(
);
#[derive(Serialize)]
pub struct StaleNode(pub NodeId);
pub struct StaleNode(pub FeedNodeId);
impl FeedMessageWrite for AddedNode<'_> {
fn write_to_feed(&self, ser: &mut FeedMessageSerializer) {
+57 -37
View File
@@ -5,11 +5,17 @@ use common::util::{now, DenseMap, NumStats};
use common::most_seen::MostSeen;
use common::node::Payload;
use once_cell::sync::Lazy;
use common::id_type;
use crate::feed_message::{self, FeedMessageSerializer};
use crate::find_location;
use super::node::Node;
use super::NodeId;
id_type!{
/// A Node ID that is unique to the chain it's in.
pub ChainNodeId(usize)
}
pub type Label = Box<str>;
@@ -20,7 +26,7 @@ pub struct Chain {
/// the most commonly used label as nodes are added/removed.
labels: MostSeen<Label>,
/// Set of nodes that are in this chain
node_ids: HashSet<NodeId>,
nodes: DenseMap<ChainNodeId, Node>,
/// Best block
best: Block,
/// Finalized block
@@ -38,6 +44,7 @@ pub struct Chain {
pub enum AddNodeResult {
Overquota,
Added {
id: ChainNodeId,
chain_renamed: bool
}
}
@@ -65,7 +72,7 @@ impl Chain {
pub fn new(genesis_hash: BlockHash) -> Self {
Chain {
labels: MostSeen::default(),
node_ids: HashSet::new(),
nodes: DenseMap::new(),
best: Block::zero(),
finalized: Block::zero(),
block_times: NumStats::new(50),
@@ -79,29 +86,34 @@ impl Chain {
pub fn can_add_node(&self) -> bool {
// Dynamically determine the max nodes based on the most common
// label so far, in case it changes to something with a different limit.
self.node_ids.len() < max_nodes(self.labels.best())
self.nodes.len() < max_nodes(self.labels.best())
}
/// Assign a node to this chain. If the function returns false, it
/// means that the node could not be added as we're at quota.
pub fn add_node(&mut self, node_id: NodeId, chain_label: &Box<str>) -> AddNodeResult {
/// Assign a node to this chain.
pub fn add_node(&mut self, node: Node) -> AddNodeResult {
if !self.can_add_node() {
return AddNodeResult::Overquota
}
let label_result = self.labels.insert(chain_label);
self.node_ids.insert(node_id);
let node_chain_label = &node.details().chain;
let label_result = self.labels.insert(node_chain_label);
let node_id = self.nodes.add(node);
AddNodeResult::Added {
id: node_id,
chain_renamed: label_result.has_changed()
}
}
/// Remove a node from this chain. We expect the label it used for the chain so
/// that we can keep track of which label is most popular.
pub fn remove_node(&mut self, node_id: NodeId, chain_label: &Box<str>) -> RemoveNodeResult {
let label_result = self.labels.remove(&chain_label);
self.node_ids.remove(&node_id);
/// Remove a node from this chain.
pub fn remove_node(&mut self, node_id: ChainNodeId) -> RemoveNodeResult {
let node = match self.nodes.remove(node_id) {
Some(node) => node,
None => return RemoveNodeResult { chain_renamed: false }
};
let node_chain_label = &node.details().chain;
let label_result = self.labels.remove(node_chain_label);
RemoveNodeResult {
chain_renamed: label_result.has_changed()
@@ -110,25 +122,25 @@ impl Chain {
/// Attempt to update the best block seen in this chain.
/// Returns a boolean which denotes whether the output is for finalization feeds (true) or not (false).
pub fn update_node(&mut self, all_nodes: &mut DenseMap<Node>, nid: NodeId, payload: Payload, feed: &mut FeedMessageSerializer) -> bool {
pub fn update_node(&mut self, nid: ChainNodeId, payload: Payload, feed: &mut FeedMessageSerializer) -> bool {
if let Some(block) = payload.best_block() {
self.handle_block(all_nodes, block, nid, feed);
self.handle_block(block, nid, feed);
}
if let Some(node) = all_nodes.get_mut(nid) {
if let Some(node) = self.nodes.get_mut(nid) {
match payload {
Payload::SystemInterval(ref interval) => {
if node.update_hardware(interval) {
feed.push(feed_message::Hardware(nid, node.hardware()));
feed.push(feed_message::Hardware(nid.into(), node.hardware()));
}
if let Some(stats) = node.update_stats(interval) {
feed.push(feed_message::NodeStatsUpdate(nid, stats));
feed.push(feed_message::NodeStatsUpdate(nid.into(), stats));
}
if let Some(io) = node.update_io(interval) {
feed.push(feed_message::NodeIOUpdate(nid, io));
feed.push(feed_message::NodeIOUpdate(nid.into(), io));
}
}
Payload::AfgAuthoritySet(authority) => {
@@ -187,7 +199,7 @@ impl Chain {
if let Some(block) = payload.finalized_block() {
if let Some(finalized) = node.update_finalized(block) {
feed.push(feed_message::FinalizedBlock(
nid,
nid.into(),
finalized.height,
finalized.hash,
));
@@ -203,14 +215,14 @@ impl Chain {
false
}
fn handle_block(&mut self, all_nodes: &mut DenseMap<Node>, block: &Block, nid: NodeId, feed: &mut FeedMessageSerializer) {
fn handle_block(&mut self, block: &Block, nid: ChainNodeId, feed: &mut FeedMessageSerializer) {
let mut propagation_time = None;
let now = now();
let nodes_len = self.node_ids.len();
let nodes_len = self.nodes.len();
self.update_stale_nodes(all_nodes, now, feed);
self.update_stale_nodes(now, feed);
let node = match all_nodes.get_mut(nid) {
let node = match self.nodes.get_mut(nid) {
Some(node) => node,
None => return,
};
@@ -243,14 +255,14 @@ impl Chain {
}
if let Some(details) = node.update_details(now, propagation_time) {
feed.push(feed_message::ImportedBlock(nid, details));
feed.push(feed_message::ImportedBlock(nid.into(), details));
}
}
}
/// Check if the chain is stale (has not received a new best block in a while).
/// If so, find a new best block, ignoring any stale nodes and marking them as such.
fn update_stale_nodes(&mut self, all_nodes: &mut DenseMap<Node>, now: u64, feed: &mut FeedMessageSerializer) {
fn update_stale_nodes(&mut self, now: u64, feed: &mut FeedMessageSerializer) {
let threshold = now - STALE_TIMEOUT;
let timestamp = match self.timestamp {
@@ -267,11 +279,7 @@ impl Chain {
let mut finalized = Block::zero();
let mut timestamp = None;
for &nid in self.node_ids.iter() {
let node = match all_nodes.get_mut(nid) {
Some(node) => node,
None => continue
};
for (nid, node) in self.nodes.iter_mut() {
if !node.update_stale(threshold) {
if node.best().height > best.height {
best = *node.best();
@@ -282,7 +290,7 @@ impl Chain {
finalized = *node.finalized();
}
} else {
feed.push(feed_message::StaleNode(nid));
feed.push(feed_message::StaleNode(nid.into()));
}
}
@@ -301,14 +309,26 @@ impl Chain {
}
}
pub fn update_node_location(&mut self, node_id: ChainNodeId, location: find_location::Location) -> bool {
if let Some(node) = self.nodes.get_mut(node_id) {
node.update_location(location);
true
} else {
false
}
}
pub fn get_node(&self, id: ChainNodeId) -> Option<&Node> {
self.nodes.get(id)
}
pub fn iter_nodes(&self) -> impl Iterator<Item=(ChainNodeId, &Node)> {
self.nodes.iter()
}
pub fn label(&self) -> &str {
&self.labels.best()
}
pub fn node_ids(&self) -> impl Iterator<Item=NodeId> + '_ {
self.node_ids.iter().copied()
}
pub fn node_count(&self) -> usize {
self.node_ids.len()
self.nodes.len()
}
pub fn best_block(&self) -> &Block {
&self.best
+79 -88
View File
@@ -7,28 +7,39 @@ use common::node::Payload;
use std::iter::IntoIterator;
use crate::feed_message::FeedMessageSerializer;
use crate::find_location;
use common::id_type;
use super::chain::{ self, Chain };
use super::chain::{ self, Chain, ChainNodeId };
pub type NodeId = usize;
pub type ChainId = usize;
id_type!{
/// A globally unique Chain ID.
pub ChainId(usize)
}
/// A "global" Node ID is a composite of the ID of the chain it's
/// on, and it's chain local ID.
#[derive(Debug,Clone,Copy,Hash,PartialEq,Eq)]
pub struct NodeId(ChainId, ChainNodeId);
impl NodeId {
pub fn get_chain_node_id(&self) -> ChainNodeId {
self.1
}
}
/// Our state constains node and chain information
pub struct State {
// Store nodes and chains in a fairly compact format.
nodes: DenseMap<Node>,
chains: DenseMap<Chain>,
chains: DenseMap<ChainId, Chain>,
// Find the right chain given various details.
chains_by_genesis_hash: HashMap<BlockHash, ChainId>,
chains_by_label: HashMap<Box<str>, ChainId>,
chains_by_node: HashMap<NodeId, ChainId>,
/// Chain labels that we do not want to allow connecting.
denylist: HashSet<String>,
}
/// Adding a node to a chain leads to this result
/// Adding a node to a chain leads to this node_idult
pub enum AddNodeResult<'a> {
/// The chain is on the "deny list", so we can't add the node
ChainOnDenyList,
@@ -38,6 +49,16 @@ pub enum AddNodeResult<'a> {
NodeAddedToChain(NodeAddedToChain<'a>)
}
#[cfg(test)]
impl <'a> AddNodeResult<'a> {
pub fn unwrap_id(&self) -> NodeId {
match &self {
AddNodeResult::NodeAddedToChain(d) => d.id,
_ => panic!("Attempt to unwrap_id on AddNodeResult that did not succeed")
}
}
}
pub struct NodeAddedToChain<'a> {
/// The ID assigned to this node.
pub id: NodeId,
@@ -65,25 +86,12 @@ pub struct RemovedNode {
pub new_chain_label: Box<str>,
}
/// If removing a node goes wrong, we get this back
#[derive(Debug, thiserror::Error)]
pub enum RemoveNodeError {
/// The node that you tried to remove wasn't found
#[error("Node not found")]
NodeNotFound,
/// The chain associated to the node wasn't found
#[error("Node chain not found")]
NodeChainNotFound
}
impl State {
pub fn new<T: IntoIterator<Item=String>>(denylist: T) -> State {
State {
nodes: DenseMap::new(),
chains: DenseMap::new(),
chains_by_genesis_hash: HashMap::new(),
chains_by_label: HashMap::new(),
chains_by_node: HashMap::new(),
denylist: denylist.into_iter().collect(),
}
}
@@ -91,28 +99,27 @@ impl State {
pub fn iter_chains(&self) -> impl Iterator<Item=StateChain<'_>> {
self.chains
.iter()
.map(move |(_,chain)| StateChain { state: self, chain })
.map(move |(_,chain)| StateChain { chain })
}
pub fn get_chain_by_node_id(&self, node_id: NodeId) -> Option<StateChain<'_>> {
self.chains_by_node
.get(&node_id)
.and_then(|&chain_id| self.chains.get(chain_id))
.map(|chain| StateChain { state: self, chain })
self.chains
.get(node_id.0)
.map(|chain| StateChain { chain })
}
pub fn get_chain_by_genesis_hash(&self, genesis_hash: &BlockHash) -> Option<StateChain<'_>> {
self.chains_by_genesis_hash
.get(genesis_hash)
.and_then(|&chain_id| self.chains.get(chain_id))
.map(|chain| StateChain { state: self, chain })
.map(|chain| StateChain { chain })
}
pub fn get_chain_by_label(&self, label: &str) -> Option<StateChain<'_>> {
self.chains_by_label
.get(label)
.and_then(|&chain_id| self.chains.get(chain_id))
.map(|chain| StateChain { state: self, chain })
.map(|chain| StateChain { chain })
}
pub fn add_node(&mut self, genesis_hash: BlockHash, node_details: NodeDetails) -> AddNodeResult<'_> {
@@ -137,36 +144,26 @@ impl State {
let chain = self.chains.get_mut(chain_id)
.expect("should be known to exist after the above (unless chains_by_genesis_hash out of sync)");
// What ID will the node have when it's added? We don't actually want
// to add it until we know whether the chain will accept it, but we want
// an ID to give to the chain.
let node_id = self.nodes.next_id();
let node_chain_label = node_details.chain.clone();
let node = Node::new(node_details);
let old_chain_label = chain.label().into();
match chain.add_node(node_id, &node_chain_label) {
match chain.add_node(node) {
chain::AddNodeResult::Overquota => {
AddNodeResult::ChainOverQuota
},
chain::AddNodeResult::Added { chain_renamed } => {
chain::AddNodeResult::Added { id, chain_renamed } => {
let chain = &*chain;
// Actually add the node, and a reference to its chain,
// if the chain adds it successfully:
self.nodes.add(Node::new(node_details));
self.chains_by_node.insert(node_id, chain_id);
// Update the label we use to reference the chain if
// it changes (it'll always change first time a node's added):
if chain_renamed {
self.chains_by_label.remove(&old_chain_label);
self.chains_by_label.insert(chain.label().to_string().into_boxed_str(), chain_id);
self.chains_by_label.insert(chain.label().into(), chain_id);
}
let node = self.nodes.get(node_id).expect("node added above");
AddNodeResult::NodeAddedToChain(NodeAddedToChain {
id: node_id,
node: node,
id: NodeId(chain_id, id),
node: chain.get_node(id).expect("node added above"),
old_chain_label: old_chain_label,
new_chain_label: chain.label(),
chain_node_count: chain.node_count(),
@@ -177,24 +174,20 @@ impl State {
}
/// Remove a node
pub fn remove_node(&mut self, node_id: NodeId) -> Result<RemovedNode,RemoveNodeError> {
self.nodes.remove(node_id)
.ok_or(RemoveNodeError::NodeNotFound)?;
pub fn remove_node(&mut self, NodeId(chain_id, chain_node_id): NodeId) -> Option<RemovedNode> {
let chain = self.chains.get_mut(chain_id)?;
let old_chain_label = chain.label().into();
let chain_id = self.chains_by_node.remove(&node_id)
.ok_or(RemoveNodeError::NodeChainNotFound)?;
// Actually remove the node
let remove_result = chain.remove_node(chain_node_id);
let chain = self.chains.get_mut(chain_id)
.ok_or(RemoveNodeError::NodeChainNotFound)?;
let old_chain_label = chain.label().to_string().into_boxed_str();
let remove_result = chain.remove_node(node_id, &old_chain_label);
let new_chain_label = chain.label().to_string().into_boxed_str();
// Get updated chain details.
let new_chain_label: Box<str> = chain.label().into();
let chain_node_count = chain.node_count();
let genesis_hash = *chain.genesis_hash();
// Is the chain empty? Remove if so and clean up indexes to it
if chain_node_count == 0 {
let genesis_hash = *chain.genesis_hash();
self.chains_by_label.remove(&old_chain_label);
self.chains_by_genesis_hash.remove(&genesis_hash);
self.chains.remove(chain_id);
@@ -206,7 +199,7 @@ impl State {
self.chains_by_label.insert(new_chain_label.clone(), chain_id);
}
Ok(RemovedNode {
Some(RemovedNode {
old_chain_label,
new_chain_label,
chain_node_count: chain_node_count,
@@ -216,25 +209,19 @@ impl State {
/// Attempt to update the best block seen, given a node and block.
/// Returns a boolean which denotes whether the output is for finalization feeds (true) or not (false).
pub fn update_node(&mut self, node_id: NodeId, payload: Payload, feed: &mut FeedMessageSerializer) -> bool {
let chain_id = match self.chains_by_node.get(&node_id) {
Some(chain_id) => *chain_id,
None => { log::error!("Cannot find chain_id for node with ID {}", node_id); return false }
};
pub fn update_node(&mut self, NodeId(chain_id, chain_node_id): NodeId, payload: Payload, feed: &mut FeedMessageSerializer) -> bool {
let chain = match self.chains.get_mut(chain_id) {
Some(chain) => chain,
None => { log::error!("Cannot find chain for node with ID {}", node_id); return false }
None => { log::error!("Cannot find chain for node with ID {:?}", chain_id); return false }
};
chain.update_node(&mut self.nodes, node_id, payload, feed)
chain.update_node(chain_node_id, payload, feed)
}
/// Update the location for a node. Return `false` if the node was not found.
pub fn update_node_location(&mut self, node_id: NodeId, location: find_location::Location) -> bool {
if let Some(node) = self.nodes.get_mut(node_id) {
node.update_location(location);
true
pub fn update_node_location(&mut self, NodeId(chain_id, chain_node_id): NodeId, location: find_location::Location) -> bool {
if let Some(chain) = self.chains.get_mut(chain_id) {
chain.update_node_location(chain_node_id, location)
} else {
false
}
@@ -247,7 +234,6 @@ impl State {
/// aren't really intended for use outside of [`State`] methods. Any modification
/// of a chain needs to go through [`State`].
pub struct StateChain<'a> {
state: &'a State,
chain: &'a Chain
}
@@ -273,11 +259,8 @@ impl <'a> StateChain<'a> {
pub fn finalized_block(&self) -> &'a Block {
self.chain.finalized_block()
}
pub fn iter_nodes(&self) -> impl Iterator<Item=(NodeId, &'a Node)> + 'a {
let state = self.state;
self.chain.node_ids().filter_map(move |id| {
Some((id, state.nodes.get(id)?))
})
pub fn iter_nodes(&self) -> impl Iterator<Item=(ChainNodeId, &'a Node)> + 'a {
self.chain.iter_nodes()
}
}
@@ -314,7 +297,7 @@ mod test {
AddNodeResult::NodeAddedToChain(details) => details
};
assert_eq!(add_node_result.id, 0);
assert_eq!(add_node_result.id, NodeId(0.into(), 0.into()));
assert_eq!(&*add_node_result.old_chain_label, "");
assert_eq!(&*add_node_result.new_chain_label, "Chain One");
assert_eq!(add_node_result.chain_node_count, 1);
@@ -331,7 +314,7 @@ mod test {
AddNodeResult::NodeAddedToChain(details) => details
};
assert_eq!(add_node_result.id, 1);
assert_eq!(add_node_result.id, NodeId(0.into(), 1.into()));
assert_eq!(&*add_node_result.old_chain_label, "Chain One");
assert_eq!(&*add_node_result.new_chain_label, "Chain One");
assert_eq!(add_node_result.chain_node_count, 2);
@@ -343,32 +326,38 @@ mod test {
let mut state = State::new(None);
let chain1_genesis = BlockHash::from_low_u64_be(1);
state.add_node(chain1_genesis, node("A", "Chain One")); // 0
let node_id0 = state
.add_node(chain1_genesis, node("A", "Chain One")) // 0
.unwrap_id();
assert_eq!(state.get_chain_by_node_id(0).expect("Chain should exist").label(), "Chain One");
assert_eq!(state.get_chain_by_node_id(node_id0).expect("Chain should exist").label(), "Chain One");
assert!(state.get_chain_by_label("Chain One").is_some());
assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_some());
state.add_node(chain1_genesis, node("B", "Chain Two")); // 1
let node_id1 = state
.add_node(chain1_genesis, node("B", "Chain Two")) // 1
.unwrap_id();
// Chain name hasn't changed yet; "Chain One" as common as "Chain Two"..
assert_eq!(state.get_chain_by_node_id(0).expect("Chain should exist").label(), "Chain One");
assert_eq!(state.get_chain_by_node_id(node_id0).expect("Chain should exist").label(), "Chain One");
assert!(state.get_chain_by_label("Chain One").is_some());
assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_some());
state.add_node(chain1_genesis, node("B", "Chain Two")); // 2
let node_id2 = state
.add_node(chain1_genesis, node("B", "Chain Two"))
.unwrap_id(); // 2
// Chain name has changed; "Chain Two" the winner now..
assert_eq!(state.get_chain_by_node_id(0).expect("Chain should exist").label(), "Chain Two");
assert_eq!(state.get_chain_by_node_id(node_id0).expect("Chain should exist").label(), "Chain Two");
assert!(state.get_chain_by_label("Chain One").is_none());
assert!(state.get_chain_by_label("Chain Two").is_some());
assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_some());
state.remove_node(1).expect("Removal OK (id: 1)");
state.remove_node(2).expect("Removal OK (id: 2");
state.remove_node(node_id1).expect("Removal OK (id: 1)");
state.remove_node(node_id2).expect("Removal OK (id: 2)");
// Removed both "Chain Two" nodes; dominant name now "Chain One" again..
assert_eq!(state.get_chain_by_node_id(0).expect("Chain should exist").label(), "Chain One");
assert_eq!(state.get_chain_by_node_id(node_id0).expect("Chain should exist").label(), "Chain One");
assert!(state.get_chain_by_label("Chain One").is_some());
assert!(state.get_chain_by_label("Chain Two").is_none());
assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_some());
@@ -379,13 +368,15 @@ mod test {
let mut state = State::new(None);
let chain1_genesis = BlockHash::from_low_u64_be(1);
state.add_node(chain1_genesis, node("A", "Chain One")); // 0
let node_id = state
.add_node(chain1_genesis, node("A", "Chain One")) // 0
.unwrap_id();
assert!(state.get_chain_by_label("Chain One").is_some());
assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_some());
assert_eq!(state.iter_chains().count(), 1);
state.remove_node(0);
state.remove_node(node_id);
assert!(state.get_chain_by_label("Chain One").is_none());
assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_none());