use flume throughout telemetry_core

This commit is contained in:
James Wilson
2021-08-10 10:26:13 +01:00
parent 8268cf2afe
commit 703a9ddc4e
9 changed files with 98 additions and 247 deletions
+54
View File
@@ -392,6 +392,19 @@ dependencies = [
"static_assertions",
]
[[package]]
name = "flume"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e90cc80fad5bb391b38127896b0fa27d97e7fef74742797f4da518d67e1292f"
dependencies = [
"futures-core",
"futures-sink",
"nanorand",
"pin-project",
"spinning_top",
]
[[package]]
name = "fnv"
version = "1.0.7"
@@ -540,8 +553,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi",
"wasm-bindgen",
]
[[package]]
@@ -827,6 +842,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "nanorand"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "729eb334247daa1803e0a094d0a5c55711b85571179f5ec6e53eccfdf7008958"
dependencies = [
"getrandom",
]
[[package]]
name = "native-tls"
version = "0.2.7"
@@ -977,6 +1001,26 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e"
[[package]]
name = "pin-project"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "576bc800220cc65dac09e99e97b08b358cfab6e17078de8dc5fee223bd2d0c08"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e8fe8163d14ce7f0cdac2e040116f22eac817edabff0be91e8aff7e9accf389"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "pin-project-lite"
version = "0.2.7"
@@ -1458,6 +1502,15 @@ dependencies = [
"sha-1",
]
[[package]]
name = "spinning_top"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75adad84ee84b521fb2cca2d4fd0f1dab1d8d026bda3c5bea4ca63b5f9f9293c"
dependencies = [
"lock_api",
]
[[package]]
name = "static_assertions"
version = "1.1.0"
@@ -1521,6 +1574,7 @@ dependencies = [
"bytes",
"common",
"criterion",
"flume",
"futures",
"hex",
"http",
@@ -1,202 +0,0 @@
use futures::channel::mpsc::{
unbounded, SendError, TrySendError, UnboundedReceiver, UnboundedSender,
};
use futures::{Sink, SinkExt, Stream, StreamExt};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::Poll;
/// Create an unbounded channel where we record the current length of the message queue.
///
/// Returns a `(sender, receiver)` pair that share a single atomic counter:
/// the sender increments it on each successful send and the receiver
/// decrements it as messages are yielded, so either side can report `len()`.
pub fn metered_unbounded<T>() -> (MeteredUnboundedSender<T>, MeteredUnboundedReceiver<T>) {
    let (tx, rx) = unbounded();
    // One shared counter, cloned into both halves of the channel.
    let len = Arc::new(AtomicUsize::new(0));
    let tx = MeteredUnboundedSender {
        inner: tx,
        // `Arc::clone` makes the cheap refcount bump explicit.
        len: Arc::clone(&len),
    };
    let rx = MeteredUnboundedReceiver { inner: rx, len };
    (tx, rx)
}
/// This is similar to an `UnboundedSender`, except that we keep track
/// of the length of the internal message buffer.
#[derive(Debug, Clone)]
pub struct MeteredUnboundedSender<T> {
    // The underlying futures unbounded sender that all messages are forwarded to.
    inner: UnboundedSender<T>,
    // Count of messages currently queued. Shared with the receiver half,
    // which decrements it as messages are taken out of the channel.
    len: Arc<AtomicUsize>,
}
impl<T> MeteredUnboundedSender<T> {
    /// The current number of messages in the queue.
    ///
    /// The counter is maintained with `Ordering::Relaxed`, so this is a
    /// close approximation rather than a perfectly synchronised snapshot.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }

    /// Send a message.
    ///
    /// The queue-length counter is only incremented when the send actually
    /// succeeds. (Previously it was incremented unconditionally, so a failed
    /// send — e.g. after the receiver was dropped — left the counter
    /// permanently inflated.)
    pub fn unbounded_send(&self, item: T) -> Result<(), TrySendError<T>> {
        let res = self.inner.unbounded_send(item);
        if res.is_ok() {
            self.len.fetch_add(1, Ordering::Relaxed);
        }
        res
    }
}
impl<T> Sink<T> for MeteredUnboundedSender<T> {
    type Error = SendError;
    // Readiness is delegated straight to the inner unbounded sender.
    fn poll_ready(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }
    // Route the item through our own `unbounded_send` so the shared length
    // counter is kept up to date, converting the `TrySendError` into the
    // plain `SendError` that the `Sink` trait expects.
    fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        self.unbounded_send(item).map_err(|e| e.into_send_error())
    }
    // Flushing is delegated to the inner sender.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_flush_unpin(cx)
    }
    // Closing is delegated to the inner sender.
    fn poll_close(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_close_unpin(cx)
    }
}
impl<T> Stream for MeteredUnboundedReceiver<T> {
    type Item = T;
    /// Pull the next message from the inner receiver, decrementing the
    /// shared queue-length counter whenever a message is actually yielded.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        match self.inner.poll_next_unpin(cx) {
            Poll::Ready(Some(item)) => {
                // One message has left the queue; account for it.
                self.len.fetch_sub(1, Ordering::Relaxed);
                Poll::Ready(Some(item))
            }
            // `Pending` and end-of-stream pass through untouched.
            other => other,
        }
    }
}
/// This is similar to an `UnboundedReceiver`, except that we keep track
/// of the length of the internal message buffer.
#[derive(Debug)]
pub struct MeteredUnboundedReceiver<T> {
    // The underlying futures unbounded receiver that messages are pulled from.
    inner: UnboundedReceiver<T>,
    // Count of messages currently queued; shared with the sender half,
    // which increments it on each send.
    len: Arc<AtomicUsize>,
}
impl<T> MeteredUnboundedReceiver<T> {
    /// The current number of messages in the queue.
    ///
    /// The counter is maintained with `Ordering::Relaxed`, so this is a
    /// close approximation rather than a perfectly synchronised snapshot.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }

    /// Whether the queue currently has no messages in it.
    ///
    /// Provided alongside `len` to satisfy the conventional pairing
    /// (clippy's `len_without_is_empty`); purely additive and
    /// backward-compatible.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Sending via `unbounded_send` and receiving via the `Stream` impl
    /// should keep `len()` in step with the number of queued messages.
    #[tokio::test]
    async fn channel_len_consistent_with_msgs() {
        let (tx, mut rx) = metered_unbounded();
        assert_eq!(tx.len(), 0);
        tx.unbounded_send(1).unwrap();
        assert_eq!(tx.len(), 1);
        tx.unbounded_send(2).unwrap();
        assert_eq!(tx.len(), 2);
        tx.unbounded_send(3).unwrap();
        assert_eq!(tx.len(), 3);
        rx.next().await.unwrap();
        assert_eq!(tx.len(), 2);
        rx.next().await.unwrap();
        assert_eq!(tx.len(), 1);
        rx.next().await.unwrap();
        assert_eq!(tx.len(), 0);
    }

    /// Same as above, but sending through the `Sink` impl
    /// (`SinkExt::send`) rather than calling `unbounded_send` directly.
    #[tokio::test]
    async fn channel_len_consistent_with_msgs_sink() {
        let (mut tx, mut rx) = metered_unbounded::<usize>();
        assert_eq!(tx.len(), 0);
        tx.send(1).await.unwrap();
        assert_eq!(tx.len(), 1);
        tx.send(2).await.unwrap();
        assert_eq!(tx.len(), 2);
        tx.send(3).await.unwrap();
        assert_eq!(tx.len(), 3);
        rx.next().await.unwrap();
        assert_eq!(tx.len(), 2);
        rx.next().await.unwrap();
        assert_eq!(tx.len(), 1);
        rx.next().await.unwrap();
        assert_eq!(tx.len(), 0);
    }

    /// The counter is atomic, so concurrent senders on real OS threads
    /// must not lose any increments (50 threads x 10_000 sends each).
    #[tokio::test]
    async fn channel_len_consistent_when_send_parallelised() {
        let (tx, _rx) = metered_unbounded::<usize>();

        // Send lots of messages on a bunch of real threads:
        let mut join_handles = vec![];
        for _ in 0..50 {
            let tx = tx.clone();
            let join_handle = std::thread::spawn(move || {
                for i in 0..10000 {
                    tx.unbounded_send(i).unwrap();
                }
            });
            join_handles.push(join_handle);
        }

        // When they are done, our len should be accurate:
        for handle in join_handles {
            handle.join().unwrap();
        }
        assert_eq!(tx.len(), 50 * 10_000);
    }

    /// Increments (sender threads) racing against decrements (this task
    /// receiving) should still leave the counter at exactly zero once all
    /// 500_000 messages have been both sent and received.
    #[tokio::test]
    async fn channel_len_consistent_when_send_and_recv_parallelised() {
        let (tx, mut rx) = metered_unbounded::<usize>();

        // Send lots of messages on a bunch of real threads:
        let mut join_handles = vec![];
        for _ in 0..50 {
            let tx = tx.clone();
            let join_handle = std::thread::spawn(move || {
                for i in 0..10000 {
                    tx.unbounded_send(i).unwrap();
                }
            });
            join_handles.push(join_handle);
        }

        // While this is happening, we are trying to receive that same number of msgs:
        for _ in 0..500_000 {
            rx.next().await.unwrap();
        }

        // When they are done, our len should be accurate:
        for handle in join_handles {
            handle.join().unwrap();
        }
        assert_eq!(tx.len(), 0);
    }
}
-3
View File
@@ -1,3 +0,0 @@
mod metered_unbounded;
pub use metered_unbounded::*;
-1
View File
@@ -15,7 +15,6 @@
// along with this program. If not, see <https://www.gnu.org/licenses/>.
pub mod byte_size;
pub mod channel;
pub mod http_utils;
pub mod id_type;
pub mod internal_messages;
+1
View File
@@ -15,6 +15,7 @@ bimap = "0.6.1"
bincode = "1.3.3"
bytes = "1.0.1"
common = { path = "../common" }
flume = "0.10.8"
futures = "0.3.15"
hex = "0.4.3"
http = "0.2.4"
@@ -18,7 +18,6 @@ use super::inner_loop;
use crate::find_location::find_location;
use crate::state::NodeId;
use common::id_type;
use futures::channel::mpsc;
use futures::{future, Sink, SinkExt};
use std::net::Ipv4Addr;
use std::sync::atomic::AtomicU64;
@@ -54,17 +53,17 @@ struct AggregatorInternal {
/// Send messages in to the aggregator from the outside via this. This is
/// stored here so that anybody holding an `Aggregator` handle can
/// make use of it.
tx_to_aggregator: mpsc::UnboundedSender<inner_loop::ToAggregator>,
tx_to_aggregator: flume::Sender<inner_loop::ToAggregator>,
}
impl Aggregator {
/// Spawn a new Aggregator. This connects to the telemetry backend
pub async fn spawn(opts: AggregatorOpts) -> anyhow::Result<Aggregator> {
let (tx_to_aggregator, rx_from_external) = mpsc::unbounded();
let (tx_to_aggregator, rx_from_external) = flume::unbounded();
// Kick off a locator task to locate nodes, which hands back a channel to make location requests
let tx_to_locator = find_location(tx_to_aggregator.clone().with(|(node_id, msg)| {
future::ok::<_, mpsc::SendError>(inner_loop::ToAggregator::FromFindLocation(
let tx_to_locator = find_location(tx_to_aggregator.clone().into_sink().with(|(node_id, msg)| {
future::ok::<_, flume::SendError<_>>(inner_loop::ToAggregator::FromFindLocation(
node_id, msg,
))
}));
@@ -89,8 +88,8 @@ impl Aggregator {
/// in to the aggregator. If nobody is holding the tx side of the channel
/// any more, this task will gracefully end.
async fn handle_messages(
rx_from_external: mpsc::UnboundedReceiver<inner_loop::ToAggregator>,
tx_to_aggregator: mpsc::UnboundedSender<(NodeId, Ipv4Addr)>,
rx_from_external: flume::Receiver<inner_loop::ToAggregator>,
tx_to_aggregator: flume::Sender<(NodeId, Ipv4Addr)>,
max_queue_len: usize,
denylist: Vec<String>,
) {
@@ -114,7 +113,7 @@ impl Aggregator {
// Calling `send` on this Sink requires Unpin. There may be a nicer way than this,
// but pinning by boxing is the easy solution for now:
Box::pin(tx_to_aggregator.with(move |msg| async move {
Box::pin(tx_to_aggregator.into_sink().with(move |msg| async move {
Ok(inner_loop::ToAggregator::FromShardWebsocket(
shard_conn_id.into(),
msg,
@@ -141,7 +140,7 @@ impl Aggregator {
// but pinning by boxing is the easy solution for now:
(
feed_conn_id,
Box::pin(tx_to_aggregator.with(move |msg| async move {
Box::pin(tx_to_aggregator.into_sink().with(move |msg| async move {
Ok(inner_loop::ToAggregator::FromFeedWebsocket(
feed_conn_id.into(),
msg,
@@ -20,13 +20,11 @@ use crate::find_location;
use crate::state::{self, NodeId, State};
use bimap::BiMap;
use common::{
channel::metered_unbounded,
internal_messages::{self, MuteReason, ShardNodeId},
node_message,
node_types::BlockHash,
time,
};
use futures::channel::mpsc;
use futures::StreamExt;
use std::collections::{HashMap, HashSet};
use std::{
@@ -48,7 +46,7 @@ pub enum FromShardWebsocket {
/// When the socket is opened, it'll send this first
/// so that we have a way to communicate back to it.
Initialize {
channel: mpsc::UnboundedSender<ToShardWebsocket>,
channel: flume::Sender<ToShardWebsocket>,
},
/// Tell the aggregator about a new node.
Add {
@@ -86,7 +84,7 @@ pub enum FromFeedWebsocket {
/// Unbounded so that slow feeds don't block aggregator
/// progress.
Initialize {
channel: mpsc::UnboundedSender<ToFeedWebsocket>,
channel: flume::Sender<ToFeedWebsocket>,
},
/// The feed can subscribe to a chain to receive
/// messages relating to it.
@@ -135,9 +133,9 @@ pub struct InnerLoop {
node_ids: BiMap<NodeId, (ConnId, ShardNodeId)>,
/// Keep track of how to send messages out to feeds.
feed_channels: HashMap<ConnId, mpsc::UnboundedSender<ToFeedWebsocket>>,
feed_channels: HashMap<ConnId, flume::Sender<ToFeedWebsocket>>,
/// Keep track of how to send messages out to shards.
shard_channels: HashMap<ConnId, mpsc::UnboundedSender<ToShardWebsocket>>,
shard_channels: HashMap<ConnId, flume::Sender<ToShardWebsocket>>,
/// Which chain is a feed subscribed to?
/// Feed Connection ID -> Chain Genesis Hash
@@ -150,7 +148,7 @@ pub struct InnerLoop {
feed_conn_id_finality: HashSet<ConnId>,
/// Send messages here to make geographical location requests.
tx_to_locator: mpsc::UnboundedSender<(NodeId, Ipv4Addr)>,
tx_to_locator: flume::Sender<(NodeId, Ipv4Addr)>,
/// How big can the queue of messages coming in to the aggregator get before messages
/// are prioritised and dropped to try and get back on track.
@@ -160,7 +158,7 @@ pub struct InnerLoop {
impl InnerLoop {
/// Create a new inner loop handler with the various state it needs.
pub fn new(
tx_to_locator: mpsc::UnboundedSender<(NodeId, Ipv4Addr)>,
tx_to_locator: flume::Sender<(NodeId, Ipv4Addr)>,
denylist: Vec<String>,
max_queue_len: usize,
) -> Self {
@@ -177,14 +175,16 @@ impl InnerLoop {
}
}
/// Start handling and responding to incoming messages. Owing to unbounded channels, we actually
/// only have a single `.await` (in this function). This helps to make it clear that the aggregator loop
/// will be able to make progress quickly without any potential yield points.
pub async fn handle(mut self, mut rx_from_external: mpsc::UnboundedReceiver<ToAggregator>) {
/// Start handling and responding to incoming messages.
pub async fn handle(mut self, rx_from_external: flume::Receiver<ToAggregator>) {
let max_queue_len = self.max_queue_len;
let (metered_tx, mut metered_rx) = metered_unbounded();
let (metered_tx, metered_rx) = flume::unbounded();
// Actually handle all of our messages, but before we get here, we
// check the length of the queue below to decide whether or not to
// pass the message on to this.
tokio::spawn(async move {
let mut metered_rx = metered_rx.into_stream();
while let Some(msg) = metered_rx.next().await {
match msg {
ToAggregator::FromFeedWebsocket(feed_conn_id, msg) => {
@@ -215,9 +215,11 @@ impl InnerLoop {
});
});
let mut rx_from_external = rx_from_external.into_stream();
while let Some(msg) = rx_from_external.next().await {
// ignore node updates if we have too many messages to handle, in an attempt
// to reduce the queue length back to something reasonable.
// to reduce the queue length back to something reasonable, lest it get out of
// control and start consuming a load of memory.
if metered_tx.len() > max_queue_len {
if matches!(
msg,
@@ -227,7 +229,7 @@ impl InnerLoop {
}
}
if let Err(e) = metered_tx.unbounded_send(msg) {
if let Err(e) = metered_tx.send(msg) {
log::error!("Cannot send message into aggregator: {}", e);
break;
}
@@ -277,7 +279,7 @@ impl InnerLoop {
match self.node_state.add_node(genesis_hash, node) {
state::AddNodeResult::ChainOnDenyList => {
if let Some(shard_conn) = self.shard_channels.get_mut(&shard_conn_id) {
let _ = shard_conn.unbounded_send(ToShardWebsocket::Mute {
let _ = shard_conn.send(ToShardWebsocket::Mute {
local_id,
reason: MuteReason::ChainNotAllowed,
});
@@ -285,7 +287,7 @@ impl InnerLoop {
}
state::AddNodeResult::ChainOverQuota => {
if let Some(shard_conn) = self.shard_channels.get_mut(&shard_conn_id) {
let _ = shard_conn.unbounded_send(ToShardWebsocket::Mute {
let _ = shard_conn.send(ToShardWebsocket::Mute {
local_id,
reason: MuteReason::Overquota,
});
@@ -326,7 +328,7 @@ impl InnerLoop {
// Ask for the geographical location of the node.
// Currently we only geographically locate IPV4 addresses so ignore IPV6.
if let IpAddr::V4(ip_v4) = ip {
let _ = self.tx_to_locator.unbounded_send((node_id, ip_v4));
let _ = self.tx_to_locator.send((node_id, ip_v4));
}
}
}
@@ -409,7 +411,7 @@ impl InnerLoop {
// Send this to the channel that subscribed:
if let Some(bytes) = feed_serializer.into_finalized() {
let _ = channel.unbounded_send(ToFeedWebsocket::Bytes(bytes));
let _ = channel.send(ToFeedWebsocket::Bytes(bytes));
}
}
FromFeedWebsocket::Ping { value } => {
@@ -422,7 +424,7 @@ impl InnerLoop {
let mut feed_serializer = FeedMessageSerializer::new();
feed_serializer.push(feed_message::Pong(&value));
if let Some(bytes) = feed_serializer.into_finalized() {
let _ = feed_channel.unbounded_send(ToFeedWebsocket::Bytes(bytes));
let _ = feed_channel.send(ToFeedWebsocket::Bytes(bytes));
}
}
FromFeedWebsocket::Subscribe { chain } => {
@@ -470,7 +472,7 @@ impl InnerLoop {
new_chain.finalized_block().hash,
));
if let Some(bytes) = feed_serializer.into_finalized() {
let _ = feed_channel.unbounded_send(ToFeedWebsocket::Bytes(bytes));
let _ = feed_channel.send(ToFeedWebsocket::Bytes(bytes));
}
// If many (eg 10k) nodes are connected, serializing all of their info takes time.
@@ -505,7 +507,7 @@ impl InnerLoop {
})
.collect();
for bytes in all_feed_messages {
let _ = feed_channel.unbounded_send(ToFeedWebsocket::Bytes(bytes));
let _ = feed_channel.send(ToFeedWebsocket::Bytes(bytes));
}
// Actually make a note of the new chain subscription:
@@ -620,7 +622,7 @@ impl InnerLoop {
if let Some(feeds) = self.chain_to_feed_conn_ids.get(genesis_hash) {
for &feed_id in feeds {
if let Some(chan) = self.feed_channels.get_mut(&feed_id) {
let _ = chan.unbounded_send(message.clone());
let _ = chan.send(message.clone());
}
}
}
@@ -636,7 +638,7 @@ impl InnerLoop {
/// Send a message to everybody.
fn broadcast_to_all_feeds(&mut self, message: ToFeedWebsocket) {
for chan in self.feed_channels.values_mut() {
let _ = chan.unbounded_send(message.clone());
let _ = chan.send(message.clone());
}
}
@@ -662,7 +664,7 @@ impl InnerLoop {
// are also subscribed to receive finality updates.
for &feed_id in feeds.union(&self.feed_conn_id_finality) {
if let Some(chan) = self.feed_channels.get_mut(&feed_id) {
let _ = chan.unbounded_send(message.clone());
let _ = chan.send(message.clone());
}
}
}
+3 -3
View File
@@ -17,7 +17,6 @@
use std::net::Ipv4Addr;
use std::sync::Arc;
use futures::channel::mpsc;
use futures::{Sink, SinkExt, StreamExt};
use parking_lot::RwLock;
use rustc_hash::FxHashMap;
@@ -31,12 +30,13 @@ pub type Location = Option<Arc<NodeLocation>>;
/// This is responsible for taking an IP address and attempting
/// to find a geographical location from this
pub fn find_location<Id, R>(response_chan: R) -> mpsc::UnboundedSender<(Id, Ipv4Addr)>
pub fn find_location<Id, R>(response_chan: R) -> flume::Sender<(Id, Ipv4Addr)>
where
R: Sink<(Id, Option<Arc<NodeLocation>>)> + Unpin + Send + Clone + 'static,
Id: Clone + Send + 'static,
{
let (tx, mut rx) = mpsc::unbounded();
let (tx, rx) = flume::unbounded();
let mut rx = rx.into_stream();
// cache entries
let mut cache: FxHashMap<Ipv4Addr, Option<Arc<NodeLocation>>> = FxHashMap::default();
+5 -4
View File
@@ -29,7 +29,7 @@ use bincode::Options;
use common::http_utils;
use common::internal_messages;
use common::ready_chunks_all::ReadyChunksAll;
use futures::{channel::mpsc, SinkExt, StreamExt};
use futures::{SinkExt, StreamExt};
use hyper::{Method, Response};
use simple_logger::SimpleLogger;
use structopt::StructOpt;
@@ -201,7 +201,8 @@ async fn handle_shard_websocket_connection<S>(
where
S: futures::Sink<FromShardWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
{
let (tx_to_shard_conn, mut rx_from_aggregator) = mpsc::unbounded();
let (tx_to_shard_conn, rx_from_aggregator) = flume::unbounded();
let mut rx_from_aggregator = rx_from_aggregator.into_stream();
// Tell the aggregator about this new connection, and give it a way to send messages to us:
let init_msg = FromShardWebsocket::Initialize {
@@ -343,8 +344,8 @@ where
S: futures::Sink<FromFeedWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
{
// unbounded channel so that slow feeds don't block aggregator progress:
let (tx_to_feed_conn, rx_from_aggregator) = mpsc::unbounded();
let mut rx_from_aggregator_chunks = ReadyChunksAll::new(rx_from_aggregator);
let (tx_to_feed_conn, rx_from_aggregator) = flume::unbounded();
let mut rx_from_aggregator_chunks = ReadyChunksAll::new(rx_from_aggregator.into_stream());
// Tell the aggregator about this new connection, and give it a way to send messages to us:
let init_msg = FromFeedWebsocket::Initialize {