diff --git a/backend/common/src/channel/metered_unbounded.rs b/backend/common/src/channel/metered_unbounded.rs
new file mode 100644
index 0000000..f76e440
--- /dev/null
+++ b/backend/common/src/channel/metered_unbounded.rs
@@ -0,0 +1,191 @@
+use futures::channel::mpsc::{ SendError, TrySendError, UnboundedSender, UnboundedReceiver, unbounded };
+use futures::{ Sink, Stream, SinkExt, StreamExt };
+use std::sync::atomic::{ AtomicUsize, Ordering };
+use std::sync::Arc;
+use std::task::Poll;
+
+/// Create an unbounded channel where we record the current length of the message queue.
+pub fn metered_unbounded<T>() -> (MeteredUnboundedSender<T>, MeteredUnboundedReceiver<T>) {
+    let (tx, rx) = unbounded();
+    let len = Arc::new(AtomicUsize::new(0));
+    let len2 = Arc::clone(&len);
+
+    let tx = MeteredUnboundedSender {
+        inner: tx,
+        len: len
+    };
+    let rx = MeteredUnboundedReceiver {
+        inner: rx,
+        len: len2
+    };
+
+    (tx, rx)
+}
+
+/// This is similar to an `UnboundedSender`, except that we keep track
+/// of the length of the internal message buffer.
+#[derive(Debug, Clone)]
+pub struct MeteredUnboundedSender<T> {
+    inner: UnboundedSender<T>,
+    len: Arc<AtomicUsize>,
+}
+
+impl<T> MeteredUnboundedSender<T> {
+    /// The current number of messages in the queue.
+    pub fn len(&self) -> usize {
+        self.len.load(Ordering::Relaxed)
+    }
+
+    /// Send a message.
+    pub fn unbounded_send(&self, item: T) -> Result<(), TrySendError<T>> {
+        self.len.fetch_add(1, Ordering::Relaxed);
+        self.inner.unbounded_send(item)
+    }
+}
+
+impl<T> Sink<T> for MeteredUnboundedSender<T> {
+    type Error = SendError;
+
+    fn poll_ready(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
+        self.unbounded_send(item).map_err(|e| e.into_send_error())
+    }
+
+    fn poll_flush(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.inner.poll_flush_unpin(cx)
+    }
+
+    fn poll_close(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.inner.poll_close_unpin(cx)
+    }
+}
+
+impl<T> Stream for MeteredUnboundedReceiver<T> {
+    type Item = T;
+
+    fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
+        let res = self.inner.poll_next_unpin(cx);
+        if matches!(res, Poll::Ready(Some(..))) {
+            self.len.fetch_sub(1, Ordering::Relaxed);
+        }
+        res
+    }
+}
+
+/// This is similar to an `UnboundedReceiver`, except that we keep track
+/// of the length of the internal message buffer.
+#[derive(Debug)]
+pub struct MeteredUnboundedReceiver<T> {
+    inner: UnboundedReceiver<T>,
+    len: Arc<AtomicUsize>,
+}
+
+impl<T> MeteredUnboundedReceiver<T> {
+    /// The current number of messages in the queue.
+    pub fn len(&self) -> usize {
+        self.len.load(Ordering::Relaxed)
+    }
+}
+
+#[cfg(test)]
+mod test {
+
+    use super::*;
+
+    #[tokio::test]
+    async fn channel_len_consistent_with_msgs() {
+        let (tx, mut rx) = metered_unbounded();
+
+        assert_eq!(tx.len(), 0);
+        tx.unbounded_send(1).unwrap();
+        assert_eq!(tx.len(), 1);
+        tx.unbounded_send(2).unwrap();
+        assert_eq!(tx.len(), 2);
+        tx.unbounded_send(3).unwrap();
+        assert_eq!(tx.len(), 3);
+
+        rx.next().await.unwrap();
+        assert_eq!(tx.len(), 2);
+        rx.next().await.unwrap();
+        assert_eq!(tx.len(), 1);
+        rx.next().await.unwrap();
+        assert_eq!(tx.len(), 0);
+    }
+
+    #[tokio::test]
+    async fn channel_len_consistent_with_msgs_sink() {
+        let (mut tx, mut rx) = metered_unbounded::<usize>();
+
+        assert_eq!(tx.len(), 0);
+        tx.send(1).await.unwrap();
+        assert_eq!(tx.len(), 1);
+        tx.send(2).await.unwrap();
+        assert_eq!(tx.len(), 2);
+        tx.send(3).await.unwrap();
+        assert_eq!(tx.len(), 3);
+
+        rx.next().await.unwrap();
+        assert_eq!(tx.len(), 2);
+        rx.next().await.unwrap();
+        assert_eq!(tx.len(), 1);
+        rx.next().await.unwrap();
+        assert_eq!(tx.len(), 0);
+    }
+
+    #[tokio::test]
+    async fn channel_len_consistent_when_send_parallelised() {
+        let (mut tx, mut rx) = metered_unbounded::<usize>();
+
+        // Send lots of messages on a bunch of real threads:
+        let mut join_handles = vec![];
+        for _ in 0..50 {
+            let tx = tx.clone();
+            let join_handle = std::thread::spawn(move || {
+                for i in 0..10000 {
+                    tx.unbounded_send(i).unwrap();
+                }
+            });
+            join_handles.push(join_handle);
+        }
+
+        // When they are done, our len should be accurate:
+        for handle in join_handles {
+            handle.join().unwrap();
+        }
+        assert_eq!(tx.len(), 50 * 10_000);
+
+    }
+
+    #[tokio::test]
+    async fn channel_len_consistent_when_send_and_recv_parallelised() {
+        let (mut tx, mut rx) = metered_unbounded::<usize>();
+
+        // Send lots of messages on a bunch of real threads:
+        let mut join_handles = vec![];
+        for _ in 0..50 {
+            let tx = tx.clone();
+            let join_handle = std::thread::spawn(move || {
+                for i in 0..10000 {
+                    tx.unbounded_send(i).unwrap();
+                }
+            });
+            join_handles.push(join_handle);
+        }
+
+        // While this is happening, we are trying to receive that same number of msgs:
+        for _ in 0..500_000 {
+            rx.next().await.unwrap();
+        }
+
+        // When they are done, our len should be accurate:
+        for handle in join_handles {
+            handle.join().unwrap();
+        }
+        assert_eq!(tx.len(), 0);
+
+    }
+
+}
\ No newline at end of file
diff --git a/backend/common/src/channel/mod.rs b/backend/common/src/channel/mod.rs
new file mode 100644
index 0000000..dd4b412
--- /dev/null
+++ b/backend/common/src/channel/mod.rs
@@ -0,0 +1,3 @@
+mod metered_unbounded;
+
+pub use metered_unbounded::*;
\ No newline at end of file
diff --git a/backend/common/src/lib.rs b/backend/common/src/lib.rs
index 85e14fa..b863e45 100644
--- a/backend/common/src/lib.rs
+++ b/backend/common/src/lib.rs
@@ -24,6 +24,7 @@ pub mod ready_chunks_all;
 pub mod rolling_total;
 pub mod time;
 pub mod ws_client;
+pub mod channel;
 
 mod assign_id;
 mod dense_map;
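For reference, a minimal usage sketch of the new metered channel (not part of the diff). It assumes the `common` crate path added in lib.rs above, plus tokio (with macros and time features) and futures as dependencies; the `u64` message type and counts are illustrative only:

use common::channel::metered_unbounded;
use futures::StreamExt;

#[tokio::main]
async fn main() {
    // Illustrative only: a metered channel carrying u64 messages.
    let (tx, mut rx) = metered_unbounded::<u64>();

    // Sends never block; each one bumps the shared length counter.
    for i in 0..3 {
        tx.unbounded_send(i).unwrap();
    }
    assert_eq!(tx.len(), 3);

    // Each message received decrements the counter again.
    for _ in 0..3 {
        rx.next().await.unwrap();
    }
    assert_eq!(tx.len(), 0);
}
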
diff --git a/backend/telemetry_core/src/aggregator/aggregator.rs b/backend/telemetry_core/src/aggregator/aggregator.rs
index 7c975ee..172a405 100644
--- a/backend/telemetry_core/src/aggregator/aggregator.rs
+++ b/backend/telemetry_core/src/aggregator/aggregator.rs
@@ -74,16 +74,16 @@ impl Aggregator {
         })))
     }
 
-    // This is spawned into a separate task and handles any messages coming
-    // in to the aggregator. If nobody is holding the tx side of the channel
-    // any more, this task will gracefully end.
+    /// This is spawned into a separate task and handles any messages coming
+    /// in to the aggregator. If nobody is holding the tx side of the channel
+    /// any more, this task will gracefully end.
     async fn handle_messages(
         rx_from_external: mpsc::UnboundedReceiver<ToAggregator>,
         tx_to_aggregator: mpsc::UnboundedSender<(NodeId, Ipv4Addr)>,
         denylist: Vec<String>,
     ) {
-        inner_loop::InnerLoop::new(rx_from_external, tx_to_aggregator, denylist)
-            .handle()
+        inner_loop::InnerLoop::new(tx_to_aggregator, denylist)
+            .handle(rx_from_external)
             .await;
     }
 
diff --git a/backend/telemetry_core/src/aggregator/inner_loop.rs b/backend/telemetry_core/src/aggregator/inner_loop.rs
index 6cfec73..1b6e0fe 100644
--- a/backend/telemetry_core/src/aggregator/inner_loop.rs
+++ b/backend/telemetry_core/src/aggregator/inner_loop.rs
@@ -23,6 +23,7 @@ use common::{
     internal_messages::{self, MuteReason, ShardNodeId},
     node_message,
     node_types::BlockHash,
+    channel::metered_unbounded,
     time,
 };
 use futures::channel::mpsc;
@@ -127,9 +128,6 @@ pub enum ToFeedWebsocket {
 /// Instances of this are responsible for handling incoming and
 /// outgoing messages in the main aggregator loop.
 pub struct InnerLoop {
-    /// Messages from the outside world come into this:
-    rx_from_external: mpsc::UnboundedReceiver<ToAggregator>,
-
     /// The state of our chains and nodes lives here:
     node_state: State,
     /// We maintain a mapping between NodeId and ConnId+LocalId, so that we know
@@ -158,12 +156,10 @@ impl InnerLoop {
 
     /// Create a new inner loop handler with the various state it needs.
     pub fn new(
-        rx_from_external: mpsc::UnboundedReceiver<ToAggregator>,
         tx_to_locator: mpsc::UnboundedSender<(NodeId, Ipv4Addr)>,
         denylist: Vec<String>,
     ) -> Self {
         InnerLoop {
-            rx_from_external,
             node_state: State::new(denylist),
             node_ids: BiMap::new(),
             feed_channels: HashMap::new(),
@@ -178,19 +174,40 @@ impl InnerLoop {
 
     /// Start handling and responding to incoming messages. Owing to unbounded channels, we actually
     /// only have a single `.await` (in this function). This helps to make it clear that the aggregator loop
     /// will be able to make progress quickly without any potential yield points.
-    pub async fn handle(mut self) {
-        while let Some(msg) = self.rx_from_external.next().await {
-            match msg {
-                ToAggregator::FromFeedWebsocket(feed_conn_id, msg) => {
-                    self.handle_from_feed(feed_conn_id, msg)
-                }
-                ToAggregator::FromShardWebsocket(shard_conn_id, msg) => {
-                    self.handle_from_shard(shard_conn_id, msg)
-                }
-                ToAggregator::FromFindLocation(node_id, location) => {
-                    self.handle_from_find_location(node_id, location)
+    pub async fn handle(mut self, mut rx_from_external: mpsc::UnboundedReceiver<ToAggregator>) {
+
+        let (metered_tx, mut metered_rx) = metered_unbounded();
+
+        tokio::spawn(async move {
+            while let Some(msg) = metered_rx.next().await {
+                match msg {
+                    ToAggregator::FromFeedWebsocket(feed_conn_id, msg) => {
+                        self.handle_from_feed(feed_conn_id, msg)
+                    }
+                    ToAggregator::FromShardWebsocket(shard_conn_id, msg) => {
+                        self.handle_from_shard(shard_conn_id, msg)
+                    }
+                    ToAggregator::FromFindLocation(node_id, location) => {
+                        self.handle_from_find_location(node_id, location)
+                    }
                 }
             }
+        });
+
+        // TEMP: let's monitor message queue len out of interest
+        let tx_len = metered_tx.clone();
+        tokio::spawn(async move {
+            loop {
+                println!("Queue len: {}", tx_len.len());
+                tokio::time::sleep(tokio::time::Duration::from_secs(1)).await
+            }
+        });
+
+        while let Some(msg) = rx_from_external.next().await {
+            if let Err(e) = metered_tx.unbounded_send(msg) {
+                log::error!("Cannot send message into aggregator: {}", e);
+                break;
+            }
         }
     }
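For context, the forwarding pattern introduced in `handle` above can be distilled into a small self-contained sketch (not part of the diff; the `Msg` type, the single forwarded message, and the one-second reporting interval are illustrative assumptions, not taken from the PR):

use common::channel::metered_unbounded;
use futures::channel::mpsc;
use futures::StreamExt;

#[derive(Debug)]
enum Msg { Ping }

#[tokio::main]
async fn main() {
    // The "external" unbounded channel that the outside world writes into.
    let (external_tx, mut external_rx) = mpsc::unbounded::<Msg>();

    // The metered channel sitting in front of the actual message handler.
    let (metered_tx, mut metered_rx) = metered_unbounded::<Msg>();

    // Consumer task: drains the metered queue and handles each message.
    tokio::spawn(async move {
        while let Some(msg) = metered_rx.next().await {
            println!("handling {:?}", msg);
        }
    });

    // Monitoring task: periodically reports how far the consumer is lagging.
    let tx_len = metered_tx.clone();
    tokio::spawn(async move {
        loop {
            println!("queue len: {}", tx_len.len());
            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        }
    });

    // Forwarding loop: every external message is pushed through the metered
    // channel, so the monitoring task sees an accurate queue length.
    external_tx.unbounded_send(Msg::Ping).unwrap();
    drop(external_tx);
    while let Some(msg) = external_rx.next().await {
        if metered_tx.unbounded_send(msg).is_err() {
            break;
        }
    }
}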