Add/test max_nodes_per_connection feature

This commit is contained in:
James Wilson
2021-07-27 12:37:32 +01:00
parent ecf5fccaab
commit a44f39cfaf
4 changed files with 203 additions and 30 deletions
+30 -4
View File
@@ -5,7 +5,7 @@ mod connection;
mod json_message;
mod real_ip;
use std::net::IpAddr;
use std::{collections::HashSet, net::IpAddr};
use aggregator::{Aggregator, FromWebsocket};
use common::node_message;
@@ -32,7 +32,7 @@ struct Opts {
socket: std::net::SocketAddr,
/// The desired log level; one of 'error', 'warn', 'info', 'debug' or 'trace', where
/// 'error' only logs errors and 'trace' logs everything.
#[structopt(required = false, long = "log", default_value = "info")]
#[structopt(long = "log", default_value = "info")]
log_level: log::LevelFilter,
/// Url to the Backend Core endpoint accepting shard connections
#[structopt(
@@ -41,6 +41,13 @@ struct Opts {
default_value = "ws://127.0.0.1:8000/shard_submit/"
)]
core_url: Uri,
/// How many different nodes is a given connection to the /submit endpoint allowed to
/// tell us about before we ignore the rest?
///
/// This is important because without a limit, a single connection could exhaust
/// RAM by suggesting that it accounts for billions of nodes.
#[structopt(long, default_value = "20")]
max_nodes_per_connection: usize
}
#[tokio::main]
@@ -62,7 +69,10 @@ async fn main() {
/// Declare our routes and start the server.
async fn start_server(opts: Opts) -> anyhow::Result<()> {
let aggregator = Aggregator::spawn(opts.core_url).await?;
let server = http_utils::start_server(opts.socket, move |addr, req| {
let socket_addr = opts.socket;
let max_nodes_per_connection = opts.max_nodes_per_connection;
let server = http_utils::start_server(socket_addr, move |addr, req| {
let aggregator = aggregator.clone();
async move {
match (req.method(), req.uri().path().trim_end_matches('/')) {
@@ -76,7 +86,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> {
Ok(http_utils::upgrade_to_websocket(req, move |ws_send, ws_recv| async move {
let tx_to_aggregator = aggregator.subscribe_node();
let (mut tx_to_aggregator, mut ws_send)
= handle_node_websocket_connection(real_addr, ws_send, ws_recv, tx_to_aggregator).await;
= handle_node_websocket_connection(real_addr, ws_send, ws_recv, tx_to_aggregator, max_nodes_per_connection).await;
log::info!("Closing /submit connection from {:?}", addr);
// Tell the aggregator that this connection has closed, so it can tidy up.
let _ = tx_to_aggregator.send(FromWebsocket::Disconnected).await;
@@ -104,10 +114,15 @@ async fn handle_node_websocket_connection<S>(
ws_send: http_utils::WsSender,
mut ws_recv: http_utils::WsReceiver,
mut tx_to_aggregator: S,
max_nodes_per_connection: usize
) -> (S, http_utils::WsSender)
where
S: futures::Sink<FromWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
{
// Track all of the message IDs that we've seen so far. If we exceed the
// max_nodes_per_connection limit we ignore subsequent message IDs.
let mut message_ids_seen = HashSet::new();
// This could be a oneshot channel, but it's useful to be able to clone
// messages, and we can't clone oneshot channel senders.
let (close_connection_tx, mut close_connection_rx) = mpsc::channel(0);
@@ -166,6 +181,17 @@ where
let message_id = node_message.id();
let payload = node_message.into_payload();
// Ignore messages from IDs that exceed our limit:
if message_ids_seen.contains(&message_id) {
// continue on; we're happy
} else if message_ids_seen.len() >= max_nodes_per_connection {
// ignore this message; it's not a "seen" ID and we've hit our limit.
continue;
} else {
// not seen ID, not hit limit; make note of new ID
message_ids_seen.insert(message_id);
}
// Until the aggregator receives an `Add` message, which we can create once
// we see one of these SystemConnected ones, it will ignore messages with
// the corresponding message_id.