diff --git a/backend/telemetry_core/tests/e2e_tests.rs b/backend/telemetry_core/tests/e2e_tests.rs index ddc1238..dec48c2 100644 --- a/backend/telemetry_core/tests/e2e_tests.rs +++ b/backend/telemetry_core/tests/e2e_tests.rs @@ -6,7 +6,7 @@ use std::time::Duration; use test_utils::{ assert_contains_matches, feed_message_de::{FeedMessage, NodeDetails}, - workspace::{ start_server, CoreOpts, start_server_debug } + workspace::{ start_server, CoreOpts, ShardOpts, start_server_debug } }; /// The simplest test we can run; the main benefit of this test (since we check similar) @@ -483,7 +483,10 @@ async fn slow_feeds_are_disconnected() { // Start server in release mode with a 1s feed timeout (to make the test run faster): let mut server = start_server( true, - CoreOpts { feed_timeout: Some(1) } + // Timeout faster so the test can be quicker: + CoreOpts { feed_timeout: Some(1) }, + // Allow us to send more messages in more easily: + ShardOpts { max_nodes_per_connection: Some(100_000) } ).await; // Give us a shard to talk to: @@ -491,7 +494,9 @@ async fn slow_feeds_are_disconnected() { let (mut node_tx, _node_rx) = server.get_shard(shard_id).unwrap().connect_node().await.unwrap(); // Add a load of nodes from this shard so there's plenty of data to give to a feed. - // We want to exhaust any buffers between core and feed (eg BufWriters). + // We want to exhaust any buffers between core and feed (eg BufWriters). If the number + // is too low, data will happily be sent into a buffer and the connection won't need to + // be closed. for n in 1..50_000 { node_tx.send_json_text(json!({ "id":n, @@ -521,18 +526,117 @@ async fn slow_feeds_are_disconnected() { // be booted after ~a second. tokio::time::sleep(Duration::from_secs(2)).await; - let mut v = Vec::new(); + // Drain anything out and expect to hit a "closed" error, rather than get stuck + // waiting to receive more data (or see some other error). 
+ loop { + let mut v = Vec::new(); + let data = tokio::time::timeout( + Duration::from_secs(1), + raw_feed_rx.receive_data(&mut v) + ).await; - // Drain anything out and expect to hit a "closed" error. - let res = loop { - if let Err(e) = raw_feed_rx.receive_data(&mut v).await { - break e + match data { + Ok(Ok(_)) => { + continue; // Drain data + } + Ok(Err(soketto::connection::Error::Closed)) => { + break; // End loop; success! + }, + Ok(Err(e)) => { + panic!("recv should be closed but instead we saw this error: {}", e); + }, + Err(_) => { + panic!("recv should be closed but seems to be happy waiting for more data"); + }, } - }; - assert!( - matches!(res, soketto::connection::Error::Closed), - "Should be Closed error, but is {:?}", res - ); + } + + // Tidy up: + server.shutdown().await; +} + +/// If something connects to the `/submit` endpoint, there is a limit to the number +/// of different message IDs it can send telemetry about, to prevent a malicious actor from +/// spamming a load of message IDs and exhausting our memory. +#[tokio::test] +async fn max_nodes_per_connection_is_enforced() { + let mut server = start_server( + false, + CoreOpts::default(), + // Limit max nodes per connection to 2; any other msgs should be ignored. + ShardOpts { max_nodes_per_connection: Some(2) } + ).await; + + // Connect to a shard + let shard_id = server.add_shard().await.unwrap(); + let (mut node_tx, _node_rx) = server.get_shard(shard_id).unwrap().connect_node().await.unwrap(); + + // Connect a feed. 
+ let (mut feed_tx, mut feed_rx) = server.get_core().connect_feed().await.unwrap(); + + // We'll send these messages from the node: + let json_msg = |n| json!({ + "id":n, + "ts":"2021-07-12T10:37:47.714666+01:00", + "payload": { + "authority":true, + "chain":"Test Chain", + "config":"", + "genesis_hash": BlockHash::from_low_u64_ne(1), + "implementation":"Polkadot", + "msg":"system.connected", + "name": format!("Alice {}", n), + "network_id":"12D3KooWEyoppNCUx8Yx66oV9fJnriXwCcXwDDUA2kj6vnc6iDEp", + "startup_time":"1625565542717", + "version":"2.0.0-07a1af348-aarch64-macos" + } + }); + + // First message ID should lead to feed messages: + node_tx.send_json_text(json_msg(1)).unwrap(); + assert_ne!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0); + + // Second message ID should lead to feed messages as well: + node_tx.send_json_text(json_msg(2)).unwrap(); + assert_ne!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0); + + // Third message ID should be ignored: + node_tx.send_json_text(json_msg(3)).unwrap(); + assert_eq!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0); + + // Fourth message ID should be ignored as well: + node_tx.send_json_text(json_msg(4)).unwrap(); + assert_eq!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0); + + // (now that the chain "Test Chain" is known about, subscribe to it for update messages. 
+ // This wasn't needed to receive messages re the above since everybody hears about node + // count changes) + feed_tx.send_command("subscribe", "Test Chain").unwrap(); + feed_rx.recv_feed_messages().await.unwrap(); + + // Update about non-ignored IDs should still lead to feed output: + + node_tx.send_json_text(json!( + {"id":1, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:38:48.330433+01:00" } + )).unwrap(); + assert_ne!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0); + + node_tx.send_json_text(json!( + {"id":2, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:38:48.330433+01:00" } + )).unwrap(); + assert_ne!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0); + + // Updates about ignored IDs are still ignored: + + node_tx.send_json_text(json!( + {"id":3, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:38:48.330433+01:00" } + )).unwrap(); + assert_eq!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0); + + node_tx.send_json_text(json!( + {"id":4, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:38:48.330433+01:00" } + )).unwrap(); + assert_eq!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0); // Tidy up: server.shutdown().await; diff --git a/backend/telemetry_shard/src/main.rs b/backend/telemetry_shard/src/main.rs index 711a9f2..1a57d79 100644 --- a/backend/telemetry_shard/src/main.rs +++ b/backend/telemetry_shard/src/main.rs @@ -5,7 +5,7 @@ mod connection; mod json_message; mod real_ip; -use std::net::IpAddr; +use std::{collections::HashSet, net::IpAddr}; use aggregator::{Aggregator, FromWebsocket}; use common::node_message; @@ -32,7 +32,7 @@ struct Opts { 
socket: std::net::SocketAddr, /// The desired log level; one of 'error', 'warn', 'info', 'debug' or 'trace', where /// 'error' only logs errors and 'trace' logs everything. - #[structopt(required = false, long = "log", default_value = "info")] + #[structopt(long = "log", default_value = "info")] log_level: log::LevelFilter, /// Url to the Backend Core endpoint accepting shard connections #[structopt( @@ -41,6 +41,13 @@ struct Opts { default_value = "ws://127.0.0.1:8000/shard_submit/" )] core_url: Uri, + /// How many different nodes is a given connection to the /submit endpoint allowed to + /// tell us about before we ignore the rest? + /// + /// This is important because without a limit, a single connection could exhaust + /// RAM by suggesting that it accounts for billions of nodes. + #[structopt(long, default_value = "20")] + max_nodes_per_connection: usize } #[tokio::main] @@ -62,7 +69,10 @@ async fn main() { /// Declare our routes and start the server. async fn start_server(opts: Opts) -> anyhow::Result<()> { let aggregator = Aggregator::spawn(opts.core_url).await?; - let server = http_utils::start_server(opts.socket, move |addr, req| { + let socket_addr = opts.socket; + let max_nodes_per_connection = opts.max_nodes_per_connection; + + let server = http_utils::start_server(socket_addr, move |addr, req| { let aggregator = aggregator.clone(); async move { match (req.method(), req.uri().path().trim_end_matches('/')) { @@ -76,7 +86,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> { Ok(http_utils::upgrade_to_websocket(req, move |ws_send, ws_recv| async move { let tx_to_aggregator = aggregator.subscribe_node(); let (mut tx_to_aggregator, mut ws_send) - = handle_node_websocket_connection(real_addr, ws_send, ws_recv, tx_to_aggregator).await; + = handle_node_websocket_connection(real_addr, ws_send, ws_recv, tx_to_aggregator, max_nodes_per_connection).await; log::info!("Closing /submit connection from {:?}", addr); // Tell the aggregator that this connection 
has closed, so it can tidy up. let _ = tx_to_aggregator.send(FromWebsocket::Disconnected).await; @@ -104,10 +114,15 @@ async fn handle_node_websocket_connection( ws_send: http_utils::WsSender, mut ws_recv: http_utils::WsReceiver, mut tx_to_aggregator: S, + max_nodes_per_connection: usize ) -> (S, http_utils::WsSender) where S: futures::Sink + Unpin + Send + 'static, { + // Track all of the message IDs that we've seen so far. If we exceed the + // max_nodes_per_connection limit we ignore subsequent message IDs. + let mut message_ids_seen = HashSet::new(); + // This could be a oneshot channel, but it's useful to be able to clone // messages, and we can't clone oneshot channel senders. let (close_connection_tx, mut close_connection_rx) = mpsc::channel(0); @@ -166,6 +181,17 @@ where let message_id = node_message.id(); let payload = node_message.into_payload(); + // Ignore messages from IDs that exceed our limit: + if message_ids_seen.contains(&message_id) { + // continue on; we're happy + } else if message_ids_seen.len() >= max_nodes_per_connection { + // ignore this message; it's not a "seen" ID and we've hit our limit. + continue; + } else { + // not seen ID, not hit limit; make note of new ID + message_ids_seen.insert(message_id); + } + // Until the aggregator receives an `Add` message, which we can create once // we see one of these SystemConnected ones, it will ignore messages with // the corresponding message_id. diff --git a/backend/test_utils/src/server/channels.rs b/backend/test_utils/src/server/channels.rs index 2480987..7643a14 100644 --- a/backend/test_utils/src/server/channels.rs +++ b/backend/test_utils/src/server/channels.rs @@ -208,12 +208,13 @@ impl FeedReceiver { /// Prefer [`FeedReceiver::recv_feed_messages`]; tests should generally be /// robust in assuming that messages may not all be delivered at once (unless we are /// specifically testing which messages are buffered together). 
- pub async fn recv_feed_messages_once(&mut self) -> Result, anyhow::Error> { - let msg = self - .0 - .next() - .await - .ok_or_else(|| anyhow::anyhow!("Stream closed: no more messages"))??; + pub async fn recv_feed_messages_once_timeout(&mut self, timeout: Duration) -> Result, anyhow::Error> { + let msg = match tokio::time::timeout(timeout, self.0.next()).await { + // Timeout elapsed; no messages back: + Err(_) => return Ok(Vec::new()), + // Something back; complain if error or stream closed: + Ok(res) => res.ok_or_else(|| anyhow::anyhow!("Stream closed: no more messages"))?? + }; match msg { ws_client::RecvMessage::Binary(data) => { @@ -227,11 +228,26 @@ impl FeedReceiver { } } + /// Wait for the next set of feed messages to arrive. + /// See `recv_feed_messages_once_timeout`. + pub async fn recv_feed_messages_once(&mut self) -> Result, anyhow::Error> { + // Default to a timeout of 30 seconds, meaning that the test will eventually end. + self.recv_feed_messages_once_timeout(Duration::from_secs(30)).await + } + /// Wait for feed messages to be sent back, building up a list of output messages until /// the channel goes quiet for a short while. + /// + /// If no new messages are received within the timeout given, bail with whatever we have so far. 
+ /// This differs from `recv_feed_messages` and `recv_feed_messages_once`, which will block indefinitely + /// waiting for something to arrive + pub async fn recv_feed_messages_timeout(&mut self, timeout: Duration) -> Result, anyhow::Error> { // Block as long as needed for messages to start coming in: - let mut feed_messages = self.recv_feed_messages_once().await?; + let mut feed_messages = match tokio::time::timeout(timeout, self.recv_feed_messages_once()).await { + Ok(msgs) => msgs?, + Err(_) => return Ok(Vec::new()), + }; + // Then, loop a little to make sure we catch any additional messages that are sent soon after: loop { match tokio::time::timeout(Duration::from_millis(250), self.recv_feed_messages_once()) @@ -250,4 +266,11 @@ impl FeedReceiver { } } } + + /// Wait for feed messages until nothing else arrives in a timely fashion. + /// See `recv_feed_messages_timeout`. + pub async fn recv_feed_messages(&mut self) -> Result, anyhow::Error> { + // Default to a timeout of 30 seconds, meaning that the test will eventually end, + self.recv_feed_messages_timeout(Duration::from_secs(30)).await + } } diff --git a/backend/test_utils/src/workspace/start_server.rs b/backend/test_utils/src/workspace/start_server.rs index 1c0f4e4..c6a5b5e 100644 --- a/backend/test_utils/src/workspace/start_server.rs +++ b/backend/test_utils/src/workspace/start_server.rs @@ -1,7 +1,7 @@ use super::commands; use crate::server::{self, Server, Command}; -/// Additional options to pass to the feed command. +/// Additional options to pass to the core command. pub struct CoreOpts { pub feed_timeout: Option } @@ -14,6 +14,19 @@ impl Default for CoreOpts { } } +/// Additional options to pass to the shard command. +pub struct ShardOpts { + pub max_nodes_per_connection: Option +} + +impl Default for ShardOpts { + fn default() -> Self { + Self { + max_nodes_per_connection: None + } + } +} + /// Start a telemetry server. 
We'll use `cargo run` by default, but you can also provide /// env vars to configure the binary that runs for the shard and core process. Either: /// @@ -31,7 +44,7 @@ impl Default for CoreOpts { /// - `TELEMETRY_SUBMIT_HOSTS` - hosts (comma separated) to connect to for telemetry `/submit`s. /// - `TELEMETRY_FEED_HOST` - host to connect to for feeds (eg 127.0.0.1:3000) /// -pub async fn start_server(release_mode: bool, core_opts: CoreOpts) -> Server { +pub async fn start_server(release_mode: bool, core_opts: CoreOpts, shard_opts: ShardOpts) -> Server { // Start to a single process: if let Ok(bin) = std::env::var("TELEMETRY_BIN") { return Server::start(server::StartOpts::SingleProcess { @@ -52,10 +65,17 @@ pub async fn start_server(release_mode: bool, core_opts: CoreOpts) -> Server { } // Build the shard command - let shard_command = std::env::var("TELEMETRY_SHARD_BIN") + let mut shard_command = std::env::var("TELEMETRY_SHARD_BIN") .map(|val| Command::new(val)) .unwrap_or_else(|_| commands::cargo_run_telemetry_shard(release_mode).expect("must be in rust workspace to run shard command")); + // Append additional opts to the shard command + if let Some(max_nodes_per_connection) = shard_opts.max_nodes_per_connection { + shard_command = shard_command + .arg("--max-nodes-per-connection") + .arg(max_nodes_per_connection.to_string()); + } + // Build the core command let mut core_command = std::env::var("TELEMETRY_CORE_BIN") .map(|val| Command::new(val)) @@ -77,10 +97,10 @@ pub async fn start_server(release_mode: bool, core_opts: CoreOpts) -> Server { /// Start a telemetry core server in debug mode. see [`start_server`] for details. pub async fn start_server_debug() -> Server { - start_server(false, CoreOpts::default()).await + start_server(false, CoreOpts::default(), ShardOpts::default()).await } /// Start a telemetry core server in release mode. see [`start_server`] for details. 
pub async fn start_server_release() -> Server { - start_server(true, CoreOpts::default()).await + start_server(true, CoreOpts::default(), ShardOpts::default()).await } \ No newline at end of file