Add rolling total and allow control over bytes per second allowed from node connections

This commit is contained in:
James Wilson
2021-07-28 16:08:02 +01:00
parent 9ec48adcaa
commit 83d31ef0b3
5 changed files with 382 additions and 2 deletions
+28 -1
View File
@@ -4,11 +4,13 @@ mod connection;
mod json_message;
mod real_ip;
use std::{collections::HashSet, net::IpAddr};
use std::{collections::HashSet, net::IpAddr, time::Duration};
use aggregator::{Aggregator, FromWebsocket};
use common::http_utils;
use common::node_message;
use common::byte_size::ByteSize;
use common::rolling_total::RollingTotalBuilder;
use futures::{channel::mpsc, SinkExt, StreamExt};
use http::Uri;
use hyper::{Method, Response};
@@ -47,6 +49,12 @@ struct Opts {
/// RAM by suggesting that it accounts for billions of nodes.
#[structopt(long, default_value = "20")]
max_nodes_per_connection: usize,
/// What is the maximum number of bytes per second, on average, that a connection from a
/// node is allowed to send to a shard before it gets booted. This is averaged over a
/// rolling window of 10 seconds, and so spikes beyond this limit are allowed as long as
/// the average traffic in the last 10 seconds falls below this value.
#[structopt(long, default_value = "512k")]
max_node_data_per_second: ByteSize
}
#[tokio::main]
@@ -70,6 +78,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> {
let aggregator = Aggregator::spawn(opts.core_url).await?;
let socket_addr = opts.socket;
let max_nodes_per_connection = opts.max_nodes_per_connection;
let bytes_per_second = opts.max_node_data_per_second;
let server = http_utils::start_server(socket_addr, move |addr, req| {
let aggregator = aggregator.clone();
@@ -91,6 +100,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> {
ws_recv,
tx_to_aggregator,
max_nodes_per_connection,
bytes_per_second
)
.await;
log::info!("Closing /submit connection from {:?}", addr);
@@ -120,10 +130,19 @@ async fn handle_node_websocket_connection<S>(
mut ws_recv: http_utils::WsReceiver,
mut tx_to_aggregator: S,
max_nodes_per_connection: usize,
bytes_per_second: ByteSize
) -> (S, http_utils::WsSender)
where
S: futures::Sink<FromWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
{
// Limit the number of bytes based on a rolling total and the incoming bytes per second
// that has been configured via the CLI opts.
let bytes_per_second = bytes_per_second.into_bytes();
let mut rolling_total_bytes = RollingTotalBuilder::new()
.granularity(Duration::from_secs(1))
.window_size_multiple(10)
.start();
// Track all of the message IDs that we've seen so far. If we exceed the
// max_nodes_per_connection limit we ignore subsequent message IDs.
let mut message_ids_seen = HashSet::new();
@@ -165,6 +184,14 @@ where
break;
}
// Keep track of total bytes and bail if average over last 10 secs exceeds preference.
rolling_total_bytes.push(bytes.len());
let this_bytes_per_second = rolling_total_bytes.total() / 10;
if this_bytes_per_second > bytes_per_second {
log::error!("Shutting down websocket connection: Too much traffic ({}bps)", this_bytes_per_second);
break;
}
// Deserialize from JSON, warning in debug mode if deserialization fails:
let node_message: json_message::NodeMessage = match serde_json::from_slice(&bytes) {
Ok(node_message) => node_message,