diff --git a/backend/Cargo.lock b/backend/Cargo.lock index 2536d50..c9235fe 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -392,6 +392,19 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "flume" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e90cc80fad5bb391b38127896b0fa27d97e7fef74742797f4da518d67e1292f" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "pin-project", + "spinning_top", +] + [[package]] name = "fnv" version = "1.0.7" @@ -540,8 +553,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -827,6 +842,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "nanorand" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "729eb334247daa1803e0a094d0a5c55711b85571179f5ec6e53eccfdf7008958" +dependencies = [ + "getrandom", +] + [[package]] name = "native-tls" version = "0.2.7" @@ -977,6 +1001,26 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" +[[package]] +name = "pin-project" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "576bc800220cc65dac09e99e97b08b358cfab6e17078de8dc5fee223bd2d0c08" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e8fe8163d14ce7f0cdac2e040116f22eac817edabff0be91e8aff7e9accf389" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.7" @@ -1458,6 +1502,15 @@ dependencies = [ "sha-1", ] +[[package]] +name = "spinning_top" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75adad84ee84b521fb2cca2d4fd0f1dab1d8d026bda3c5bea4ca63b5f9f9293c" +dependencies = [ + "lock_api", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -1521,6 +1574,7 @@ dependencies = [ "bytes", "common", "criterion", + "flume", "futures", "hex", "http", diff --git a/backend/common/src/channel/metered_unbounded.rs b/backend/common/src/channel/metered_unbounded.rs deleted file mode 100644 index 81c8f00..0000000 --- a/backend/common/src/channel/metered_unbounded.rs +++ /dev/null @@ -1,202 +0,0 @@ -use futures::channel::mpsc::{ - unbounded, SendError, TrySendError, UnboundedReceiver, UnboundedSender, -}; -use futures::{Sink, SinkExt, Stream, StreamExt}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::task::Poll; - -/// Create an unbounded channel where we record the current length of the message queue. -pub fn metered_unbounded() -> (MeteredUnboundedSender, MeteredUnboundedReceiver) { - let (tx, rx) = unbounded(); - let len = Arc::new(AtomicUsize::new(0)); - let len2 = Arc::clone(&len); - - let tx = MeteredUnboundedSender { - inner: tx, - len: len, - }; - let rx = MeteredUnboundedReceiver { - inner: rx, - len: len2, - }; - - (tx, rx) -} - -/// This is similar to an `UnboundedSender`, except that we keep track -/// of the length of the internal message buffer. -#[derive(Debug, Clone)] -pub struct MeteredUnboundedSender { - inner: UnboundedSender, - len: Arc, -} - -impl MeteredUnboundedSender { - /// The current number of messages in the queue. - pub fn len(&self) -> usize { - self.len.load(Ordering::Relaxed) - } - - /// Send a message. - pub fn unbounded_send(&self, item: T) -> Result<(), TrySendError> { - self.len.fetch_add(1, Ordering::Relaxed); - self.inner.unbounded_send(item) - } -} - -impl Sink for MeteredUnboundedSender { - type Error = SendError; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.inner.poll_ready(cx) - } - - fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> { - self.unbounded_send(item).map_err(|e| e.into_send_error()) - } - - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.inner.poll_flush_unpin(cx) - } - - fn poll_close( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.inner.poll_close_unpin(cx) - } -} - -impl Stream for MeteredUnboundedReceiver { - type Item = T; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let res = self.inner.poll_next_unpin(cx); - if matches!(res, Poll::Ready(Some(..))) { - self.len.fetch_sub(1, Ordering::Relaxed); - } - res - } -} - -/// This is similar to an `UnboundedReceiver`, except that we keep track -/// of the length of the internal message buffer. -#[derive(Debug)] -pub struct MeteredUnboundedReceiver { - inner: UnboundedReceiver, - len: Arc, -} - -impl MeteredUnboundedReceiver { - /// The current number of messages in the queue. - pub fn len(&self) -> usize { - self.len.load(Ordering::Relaxed) - } -} - -#[cfg(test)] -mod test { - - use super::*; - - #[tokio::test] - async fn channel_len_consistent_with_msgs() { - let (tx, mut rx) = metered_unbounded(); - - assert_eq!(tx.len(), 0); - tx.unbounded_send(1).unwrap(); - assert_eq!(tx.len(), 1); - tx.unbounded_send(2).unwrap(); - assert_eq!(tx.len(), 2); - tx.unbounded_send(3).unwrap(); - assert_eq!(tx.len(), 3); - - rx.next().await.unwrap(); - assert_eq!(tx.len(), 2); - rx.next().await.unwrap(); - assert_eq!(tx.len(), 1); - rx.next().await.unwrap(); - assert_eq!(tx.len(), 0); - } - - #[tokio::test] - async fn channel_len_consistent_with_msgs_sink() { - let (mut tx, mut rx) = metered_unbounded::(); - - assert_eq!(tx.len(), 0); - tx.send(1).await.unwrap(); - assert_eq!(tx.len(), 1); - tx.send(2).await.unwrap(); - assert_eq!(tx.len(), 2); - tx.send(3).await.unwrap(); - assert_eq!(tx.len(), 3); - - rx.next().await.unwrap(); - assert_eq!(tx.len(), 2); - rx.next().await.unwrap(); - assert_eq!(tx.len(), 1); - rx.next().await.unwrap(); - assert_eq!(tx.len(), 0); - } - - #[tokio::test] - async fn channel_len_consistent_when_send_parallelised() { - let (tx, _rx) = metered_unbounded::(); - - // Send lots of messages on a bunch of real threads: - let mut join_handles = vec![]; - for _ in 0..50 { - let tx = tx.clone(); - let join_handle = std::thread::spawn(move || { - for i in 0..10000 { - tx.unbounded_send(i).unwrap(); - } - }); - join_handles.push(join_handle); - } - - // When they are done, our len should be accurate: - for handle in join_handles { - handle.join().unwrap(); - } - assert_eq!(tx.len(), 50 * 10_000); - } - - #[tokio::test] - async fn channel_len_consistent_when_send_and_recv_parallelised() { - let (tx, mut rx) = metered_unbounded::(); - - // Send lots of messages on a bunch of real threads: - let mut join_handles = vec![]; - for _ in 0..50 { - let tx = tx.clone(); - let join_handle = std::thread::spawn(move || { - for i in 0..10000 { - tx.unbounded_send(i).unwrap(); - } - }); - join_handles.push(join_handle); - } - - // While this is happenening, we are trying to receive that same number of msgs: - for _ in 0..500_000 { - rx.next().await.unwrap(); - } - - // When they are done, our len should be accurate: - for handle in join_handles { - handle.join().unwrap(); - } - assert_eq!(tx.len(), 0); - } -} diff --git a/backend/common/src/channel/mod.rs b/backend/common/src/channel/mod.rs deleted file mode 100644 index 3b1b209..0000000 --- a/backend/common/src/channel/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod metered_unbounded; - -pub use metered_unbounded::*; diff --git a/backend/common/src/lib.rs b/backend/common/src/lib.rs index 690f118..85e14fa 100644 --- a/backend/common/src/lib.rs +++ b/backend/common/src/lib.rs @@ -15,7 +15,6 @@ // along with this program. If not, see . pub mod byte_size; -pub mod channel; pub mod http_utils; pub mod id_type; pub mod internal_messages; diff --git a/backend/telemetry_core/Cargo.toml b/backend/telemetry_core/Cargo.toml index 35ca51d..2f1c46e 100644 --- a/backend/telemetry_core/Cargo.toml +++ b/backend/telemetry_core/Cargo.toml @@ -15,6 +15,7 @@ bimap = "0.6.1" bincode = "1.3.3" bytes = "1.0.1" common = { path = "../common" } +flume = "0.10.8" futures = "0.3.15" hex = "0.4.3" http = "0.2.4" diff --git a/backend/telemetry_core/src/aggregator/aggregator.rs b/backend/telemetry_core/src/aggregator/aggregator.rs index 7f644c2..98f4935 100644 --- a/backend/telemetry_core/src/aggregator/aggregator.rs +++ b/backend/telemetry_core/src/aggregator/aggregator.rs @@ -18,7 +18,6 @@ use super::inner_loop; use crate::find_location::find_location; use crate::state::NodeId; use common::id_type; -use futures::channel::mpsc; use futures::{future, Sink, SinkExt}; use std::net::Ipv4Addr; use std::sync::atomic::AtomicU64; @@ -54,17 +53,17 @@ struct AggregatorInternal { /// Send messages in to the aggregator from the outside via this. This is /// stored here so that anybody holding an `Aggregator` handle can /// make use of it. - tx_to_aggregator: mpsc::UnboundedSender, + tx_to_aggregator: flume::Sender, } impl Aggregator { /// Spawn a new Aggregator. This connects to the telemetry backend pub async fn spawn(opts: AggregatorOpts) -> anyhow::Result { - let (tx_to_aggregator, rx_from_external) = mpsc::unbounded(); + let (tx_to_aggregator, rx_from_external) = flume::unbounded(); // Kick off a locator task to locate nodes, which hands back a channel to make location requests - let tx_to_locator = find_location(tx_to_aggregator.clone().with(|(node_id, msg)| { - future::ok::<_, mpsc::SendError>(inner_loop::ToAggregator::FromFindLocation( + let tx_to_locator = find_location(tx_to_aggregator.clone().into_sink().with(|(node_id, msg)| { + future::ok::<_, flume::SendError<_>>(inner_loop::ToAggregator::FromFindLocation( node_id, msg, )) })); @@ -89,8 +88,8 @@ impl Aggregator { /// in to the aggregator. If nobody is tolding the tx side of the channel /// any more, this task will gracefully end. async fn handle_messages( - rx_from_external: mpsc::UnboundedReceiver, - tx_to_aggregator: mpsc::UnboundedSender<(NodeId, Ipv4Addr)>, + rx_from_external: flume::Receiver, + tx_to_aggregator: flume::Sender<(NodeId, Ipv4Addr)>, max_queue_len: usize, denylist: Vec, ) { @@ -114,7 +113,7 @@ impl Aggregator { // Calling `send` on this Sink requires Unpin. There may be a nicer way than this, // but pinning by boxing is the easy solution for now: - Box::pin(tx_to_aggregator.with(move |msg| async move { + Box::pin(tx_to_aggregator.into_sink().with(move |msg| async move { Ok(inner_loop::ToAggregator::FromShardWebsocket( shard_conn_id.into(), msg, @@ -141,7 +140,7 @@ impl Aggregator { // but pinning by boxing is the easy solution for now: ( feed_conn_id, - Box::pin(tx_to_aggregator.with(move |msg| async move { + Box::pin(tx_to_aggregator.into_sink().with(move |msg| async move { Ok(inner_loop::ToAggregator::FromFeedWebsocket( feed_conn_id.into(), msg, diff --git a/backend/telemetry_core/src/aggregator/inner_loop.rs b/backend/telemetry_core/src/aggregator/inner_loop.rs index abfd1fc..88cb402 100644 --- a/backend/telemetry_core/src/aggregator/inner_loop.rs +++ b/backend/telemetry_core/src/aggregator/inner_loop.rs @@ -20,13 +20,11 @@ use crate::find_location; use crate::state::{self, NodeId, State}; use bimap::BiMap; use common::{ - channel::metered_unbounded, internal_messages::{self, MuteReason, ShardNodeId}, node_message, node_types::BlockHash, time, }; -use futures::channel::mpsc; use futures::StreamExt; use std::collections::{HashMap, HashSet}; use std::{ @@ -48,7 +46,7 @@ pub enum FromShardWebsocket { /// When the socket is opened, it'll send this first /// so that we have a way to communicate back to it. Initialize { - channel: mpsc::UnboundedSender, + channel: flume::Sender, }, /// Tell the aggregator about a new node. Add { @@ -86,7 +84,7 @@ pub enum FromFeedWebsocket { /// Unbounded so that slow feeds don't block aggregato /// progress. Initialize { - channel: mpsc::UnboundedSender, + channel: flume::Sender, }, /// The feed can subscribe to a chain to receive /// messages relating to it. @@ -135,9 +133,9 @@ pub struct InnerLoop { node_ids: BiMap, /// Keep track of how to send messages out to feeds. - feed_channels: HashMap>, + feed_channels: HashMap>, /// Keep track of how to send messages out to shards. - shard_channels: HashMap>, + shard_channels: HashMap>, /// Which chain is a feed subscribed to? /// Feed Connection ID -> Chain Genesis Hash @@ -150,7 +148,7 @@ pub struct InnerLoop { feed_conn_id_finality: HashSet, /// Send messages here to make geographical location requests. - tx_to_locator: mpsc::UnboundedSender<(NodeId, Ipv4Addr)>, + tx_to_locator: flume::Sender<(NodeId, Ipv4Addr)>, /// How big can the queue of messages coming in to the aggregator get before messages /// are prioritised and dropped to try and get back on track. @@ -160,7 +158,7 @@ pub struct InnerLoop { impl InnerLoop { /// Create a new inner loop handler with the various state it needs. pub fn new( - tx_to_locator: mpsc::UnboundedSender<(NodeId, Ipv4Addr)>, + tx_to_locator: flume::Sender<(NodeId, Ipv4Addr)>, denylist: Vec, max_queue_len: usize, ) -> Self { @@ -177,14 +175,16 @@ impl InnerLoop { } } - /// Start handling and responding to incoming messages. Owing to unbounded channels, we actually - /// only have a single `.await` (in this function). This helps to make it clear that the aggregator loop - /// will be able to make progress quickly without any potential yield points. - pub async fn handle(mut self, mut rx_from_external: mpsc::UnboundedReceiver) { + /// Start handling and responding to incoming messages. + pub async fn handle(mut self, rx_from_external: flume::Receiver) { let max_queue_len = self.max_queue_len; - let (metered_tx, mut metered_rx) = metered_unbounded(); + let (metered_tx, metered_rx) = flume::unbounded(); + // Actually handle all of our messages, but before we get here, we + // check the length of the queue below to decide whether or not to + // pass the message on to this. tokio::spawn(async move { + let mut metered_rx = metered_rx.into_stream(); while let Some(msg) = metered_rx.next().await { match msg { ToAggregator::FromFeedWebsocket(feed_conn_id, msg) => { @@ -215,9 +215,11 @@ impl InnerLoop { }); }); + let mut rx_from_external = rx_from_external.into_stream(); while let Some(msg) = rx_from_external.next().await { // ignore node updates if we have too many messages to handle, in an attempt - // to reduce the queue length back to something reasonable. + // to reduce the queue length back to something reasonable, lest it get out of + // control and start consuming a load of memory. if metered_tx.len() > max_queue_len { if matches!( msg, @@ -227,7 +229,7 @@ impl InnerLoop { } } - if let Err(e) = metered_tx.unbounded_send(msg) { + if let Err(e) = metered_tx.send(msg) { log::error!("Cannot send message into aggregator: {}", e); break; } @@ -277,7 +279,7 @@ impl InnerLoop { match self.node_state.add_node(genesis_hash, node) { state::AddNodeResult::ChainOnDenyList => { if let Some(shard_conn) = self.shard_channels.get_mut(&shard_conn_id) { - let _ = shard_conn.unbounded_send(ToShardWebsocket::Mute { + let _ = shard_conn.send(ToShardWebsocket::Mute { local_id, reason: MuteReason::ChainNotAllowed, }); @@ -285,7 +287,7 @@ impl InnerLoop { } state::AddNodeResult::ChainOverQuota => { if let Some(shard_conn) = self.shard_channels.get_mut(&shard_conn_id) { - let _ = shard_conn.unbounded_send(ToShardWebsocket::Mute { + let _ = shard_conn.send(ToShardWebsocket::Mute { local_id, reason: MuteReason::Overquota, }); @@ -326,7 +328,7 @@ impl InnerLoop { // Ask for the grographical location of the node. // Currently we only geographically locate IPV4 addresses so ignore IPV6. if let IpAddr::V4(ip_v4) = ip { - let _ = self.tx_to_locator.unbounded_send((node_id, ip_v4)); + let _ = self.tx_to_locator.send((node_id, ip_v4)); } } } @@ -409,7 +411,7 @@ impl InnerLoop { // Send this to the channel that subscribed: if let Some(bytes) = feed_serializer.into_finalized() { - let _ = channel.unbounded_send(ToFeedWebsocket::Bytes(bytes)); + let _ = channel.send(ToFeedWebsocket::Bytes(bytes)); } } FromFeedWebsocket::Ping { value } => { @@ -422,7 +424,7 @@ impl InnerLoop { let mut feed_serializer = FeedMessageSerializer::new(); feed_serializer.push(feed_message::Pong(&value)); if let Some(bytes) = feed_serializer.into_finalized() { - let _ = feed_channel.unbounded_send(ToFeedWebsocket::Bytes(bytes)); + let _ = feed_channel.send(ToFeedWebsocket::Bytes(bytes)); } } FromFeedWebsocket::Subscribe { chain } => { @@ -470,7 +472,7 @@ impl InnerLoop { new_chain.finalized_block().hash, )); if let Some(bytes) = feed_serializer.into_finalized() { - let _ = feed_channel.unbounded_send(ToFeedWebsocket::Bytes(bytes)); + let _ = feed_channel.send(ToFeedWebsocket::Bytes(bytes)); } // If many (eg 10k) nodes are connected, serializing all of their info takes time. @@ -505,7 +507,7 @@ impl InnerLoop { }) .collect(); for bytes in all_feed_messages { - let _ = feed_channel.unbounded_send(ToFeedWebsocket::Bytes(bytes)); + let _ = feed_channel.send(ToFeedWebsocket::Bytes(bytes)); } // Actually make a note of the new chain subsciption: @@ -620,7 +622,7 @@ impl InnerLoop { if let Some(feeds) = self.chain_to_feed_conn_ids.get(genesis_hash) { for &feed_id in feeds { if let Some(chan) = self.feed_channels.get_mut(&feed_id) { - let _ = chan.unbounded_send(message.clone()); + let _ = chan.send(message.clone()); } } } @@ -636,7 +638,7 @@ impl InnerLoop { /// Send a message to everybody. fn broadcast_to_all_feeds(&mut self, message: ToFeedWebsocket) { for chan in self.feed_channels.values_mut() { - let _ = chan.unbounded_send(message.clone()); + let _ = chan.send(message.clone()); } } @@ -662,7 +664,7 @@ impl InnerLoop { // are also subscribed to receive finality updates. for &feed_id in feeds.union(&self.feed_conn_id_finality) { if let Some(chan) = self.feed_channels.get_mut(&feed_id) { - let _ = chan.unbounded_send(message.clone()); + let _ = chan.send(message.clone()); } } } diff --git a/backend/telemetry_core/src/find_location.rs b/backend/telemetry_core/src/find_location.rs index 680a123..326bff1 100644 --- a/backend/telemetry_core/src/find_location.rs +++ b/backend/telemetry_core/src/find_location.rs @@ -17,7 +17,6 @@ use std::net::Ipv4Addr; use std::sync::Arc; -use futures::channel::mpsc; use futures::{Sink, SinkExt, StreamExt}; use parking_lot::RwLock; use rustc_hash::FxHashMap; @@ -31,12 +30,13 @@ pub type Location = Option>; /// This is responsible for taking an IP address and attempting /// to find a geographical location from this -pub fn find_location(response_chan: R) -> mpsc::UnboundedSender<(Id, Ipv4Addr)> +pub fn find_location(response_chan: R) -> flume::Sender<(Id, Ipv4Addr)> where R: Sink<(Id, Option>)> + Unpin + Send + Clone + 'static, Id: Clone + Send + 'static, { - let (tx, mut rx) = mpsc::unbounded(); + let (tx, rx) = flume::unbounded(); + let mut rx = rx.into_stream(); // cache entries let mut cache: FxHashMap>> = FxHashMap::default(); diff --git a/backend/telemetry_core/src/main.rs b/backend/telemetry_core/src/main.rs index d96ab9d..c55ade2 100644 --- a/backend/telemetry_core/src/main.rs +++ b/backend/telemetry_core/src/main.rs @@ -29,7 +29,7 @@ use bincode::Options; use common::http_utils; use common::internal_messages; use common::ready_chunks_all::ReadyChunksAll; -use futures::{channel::mpsc, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use hyper::{Method, Response}; use simple_logger::SimpleLogger; use structopt::StructOpt; @@ -201,7 +201,8 @@ async fn handle_shard_websocket_connection( where S: futures::Sink + Unpin + Send + 'static, { - let (tx_to_shard_conn, mut rx_from_aggregator) = mpsc::unbounded(); + let (tx_to_shard_conn, rx_from_aggregator) = flume::unbounded(); + let mut rx_from_aggregator = rx_from_aggregator.into_stream(); // Tell the aggregator about this new connection, and give it a way to send messages to us: let init_msg = FromShardWebsocket::Initialize { @@ -343,8 +344,8 @@ where S: futures::Sink + Unpin + Send + 'static, { // unbounded channel so that slow feeds don't block aggregator progress: - let (tx_to_feed_conn, rx_from_aggregator) = mpsc::unbounded(); - let mut rx_from_aggregator_chunks = ReadyChunksAll::new(rx_from_aggregator); + let (tx_to_feed_conn, rx_from_aggregator) = flume::unbounded(); + let mut rx_from_aggregator_chunks = ReadyChunksAll::new(rx_from_aggregator.into_stream()); // Tell the aggregator about this new connection, and give it a way to send messages to us: let init_msg = FromFeedWebsocket::Initialize {