diff --git a/backend/common/src/either_sink.rs b/backend/common/src/either_sink.rs new file mode 100644 index 0000000..6e911af --- /dev/null +++ b/backend/common/src/either_sink.rs @@ -0,0 +1,57 @@ +use futures::sink::Sink; +use pin_project_lite::pin_project; + +pin_project! { + #[project = EitherSinkInner] + pub enum EitherSink { + A { #[pin] inner: A }, + B { #[pin] inner: B } + } +} + +/// A simple enum that delegates implementation to one of +/// the two possible sinks contained within. +impl EitherSink { + pub fn a(val: A) -> Self { + EitherSink::A { inner: val } + } + pub fn b(val: B) -> Self { + EitherSink::B { inner: val } + } +} + +impl Sink for EitherSink +where + A: Sink, + B: Sink +{ + type Error = Error; + + fn poll_ready(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + match self.project() { + EitherSinkInner::A{ inner } => inner.poll_ready(cx), + EitherSinkInner::B{ inner } => inner.poll_ready(cx) + } + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { + match self.project() { + EitherSinkInner::A{ inner } => inner.start_send(item), + EitherSinkInner::B{ inner } => inner.start_send(item) + } + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + match self.project() { + EitherSinkInner::A{ inner } => inner.poll_flush(cx), + EitherSinkInner::B{ inner } => inner.poll_flush(cx) + } + } + + fn poll_close(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + match self.project() { + EitherSinkInner::A{ inner } => inner.poll_close(cx), + EitherSinkInner::B{ inner } => inner.poll_close(cx) + } + } +} diff --git a/backend/common/src/lib.rs b/backend/common/src/lib.rs index f859036..cd7013b 100644 --- a/backend/common/src/lib.rs +++ b/backend/common/src/lib.rs @@ -30,6 +30,7 @@ mod dense_map; mod mean_list; mod most_seen; mod num_stats; +mod either_sink; // Export a bunch of common bits at the top level for ease of import: pub use assign_id::AssignId; @@ -37,3 +38,4 @@ pub use dense_map::DenseMap; pub use mean_list::MeanList; pub use most_seen::MostSeen; pub use num_stats::NumStats; +pub use either_sink::EitherSink; diff --git a/backend/telemetry_core/src/aggregator/aggregator_set.rs b/backend/telemetry_core/src/aggregator/aggregator_set.rs index 334cc5d..fb07fc3 100644 --- a/backend/telemetry_core/src/aggregator/aggregator_set.rs +++ b/backend/telemetry_core/src/aggregator/aggregator_set.rs @@ -4,6 +4,7 @@ use futures::{Sink, SinkExt, StreamExt}; use inner_loop::FromShardWebsocket; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use common::EitherSink; #[derive(Clone)] pub struct AggregatorSet(Arc); @@ -19,6 +20,8 @@ impl AggregatorSet { num_aggregators: usize, denylist: Vec, ) -> anyhow::Result { + assert_ne!(num_aggregators, 0, "You must have 1 or more aggregator"); + let aggregators = futures::future::try_join_all( (0..num_aggregators).map(|_| Aggregator::spawn(denylist.clone())), ) @@ -35,6 +38,13 @@ impl AggregatorSet { &self, ) -> impl Sink + Send + Sync + Unpin + 'static { + // Special case 1 aggregator to avoid the extra indurection and so on + // if we don't actually need it. + if self.0.aggregators.len() == 1 { + let sub = self.0.aggregators[0].subscribe_shard(); + return EitherSink::a(sub) + } + let mut conns: Vec<_> = self .0 .aggregators @@ -57,7 +67,7 @@ impl AggregatorSet { } }); - tx.sink_map_err(|e| anyhow::anyhow!("{}", e)) + EitherSink::b(tx.sink_map_err(|e| anyhow::anyhow!("{}", e))) } /// Return a sink that a feed can send messages into to be handled by a single aggregator.