Add/test max_nodes_per_connection feature

This commit is contained in:
James Wilson
2021-07-27 12:37:32 +01:00
parent ecf5fccaab
commit a44f39cfaf
4 changed files with 203 additions and 30 deletions
+117 -13
View File
@@ -6,7 +6,7 @@ use std::time::Duration;
use test_utils::{
assert_contains_matches,
feed_message_de::{FeedMessage, NodeDetails},
workspace::{ start_server, CoreOpts, start_server_debug }
workspace::{ start_server, CoreOpts, ShardOpts, start_server_debug }
};
/// The simplest test we can run; the main benefit of this test (since we check similar)
@@ -483,7 +483,10 @@ async fn slow_feeds_are_disconnected() {
// Start server in release mode with a 1s feed timeout (to make the test run faster):
let mut server = start_server(
true,
CoreOpts { feed_timeout: Some(1) }
// Timeout faster so the test can be quicker:
CoreOpts { feed_timeout: Some(1) },
// Allow us to send more messages in more easily:
ShardOpts { max_nodes_per_connection: Some(100_000) }
).await;
// Give us a shard to talk to:
@@ -491,7 +494,9 @@ async fn slow_feeds_are_disconnected() {
let (mut node_tx, _node_rx) = server.get_shard(shard_id).unwrap().connect_node().await.unwrap();
// Add a load of nodes from this shard so there's plenty of data to give to a feed.
// We want to exhaust any buffers between core and feed (eg BufWriters).
// We want to exhaust any buffers between core and feed (eg BufWriters). If the number
// is too low, data will happily be sent into a buffer and the connection won't need to
// be closed.
for n in 1..50_000 {
node_tx.send_json_text(json!({
"id":n,
@@ -521,18 +526,117 @@ async fn slow_feeds_are_disconnected() {
// be booted after ~a second.
tokio::time::sleep(Duration::from_secs(2)).await;
let mut v = Vec::new();
// Drain anything out and expect to hit a "closed" error, rather than get stuck
// waiting to receive more data (or see some other error).
loop {
let mut v = Vec::new();
let data = tokio::time::timeout(
Duration::from_secs(1),
raw_feed_rx.receive_data(&mut v)
).await;
// Drain anything out and expect to hit a "closed" error.
let res = loop {
if let Err(e) = raw_feed_rx.receive_data(&mut v).await {
break e
match data {
Ok(Ok(_)) => {
continue; // Drain data
}
Ok(Err(soketto::connection::Error::Closed)) => {
break; // End loop; success!
},
Ok(Err(e)) => {
panic!("recv should be closed but instead we saw this error: {}", e);
},
Err(_) => {
panic!("recv should be closed but seems to be happy waiting for more data");
},
}
};
assert!(
matches!(res, soketto::connection::Error::Closed),
"Should be Closed error, but is {:?}", res
);
}
// Tidy up:
server.shutdown().await;
}
/// If something connects to the `/submit` endpoint, there is a limit to the number
/// of different message IDs it can send telemetry about, to prevent a malicious actor from
/// spamming a load of message IDs and exhausting our memory.
#[tokio::test]
async fn max_nodes_per_connection_is_enforced() {
let mut server = start_server(
false,
CoreOpts::default(),
// Limit max nodes per connection to 2; any other msgs should be ignored.
ShardOpts { max_nodes_per_connection: Some(2) }
).await;
// Connect to a shard
let shard_id = server.add_shard().await.unwrap();
let (mut node_tx, _node_rx) = server.get_shard(shard_id).unwrap().connect_node().await.unwrap();
// Connect a feed.
let (mut feed_tx, mut feed_rx) = server.get_core().connect_feed().await.unwrap();
// We'll send these messages from the node:
let json_msg = |n| json!({
"id":n,
"ts":"2021-07-12T10:37:47.714666+01:00",
"payload": {
"authority":true,
"chain":"Test Chain",
"config":"",
"genesis_hash": BlockHash::from_low_u64_ne(1),
"implementation":"Polkadot",
"msg":"system.connected",
"name": format!("Alice {}", n),
"network_id":"12D3KooWEyoppNCUx8Yx66oV9fJnriXwCcXwDDUA2kj6vnc6iDEp",
"startup_time":"1625565542717",
"version":"2.0.0-07a1af348-aarch64-macos"
}
});
// First message ID should lead to feed messages:
node_tx.send_json_text(json_msg(1)).unwrap();
assert_ne!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0);
// Second message ID should lead to feed messages as well:
node_tx.send_json_text(json_msg(2)).unwrap();
assert_ne!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0);
// Third message ID should be ignored:
node_tx.send_json_text(json_msg(3)).unwrap();
assert_eq!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0);
// Fourth message ID should be ignored as well:
node_tx.send_json_text(json_msg(4)).unwrap();
assert_eq!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0);
// (now that the chain "Test Chain" is known about, subscribe to it for update messages.
// This wasn't needed to receive messages re the above since everybody hears about node
// count changes)
feed_tx.send_command("subscribe", "Test Chain").unwrap();
feed_rx.recv_feed_messages().await.unwrap();
// Update about non-ignored IDs should still lead to feed output:
node_tx.send_json_text(json!(
{"id":1, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:38:48.330433+01:00" }
)).unwrap();
assert_ne!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0);
node_tx.send_json_text(json!(
{"id":2, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:38:48.330433+01:00" }
)).unwrap();
assert_ne!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0);
// Updates about ignored IDs are still ignored:
node_tx.send_json_text(json!(
{"id":3, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:38:48.330433+01:00" }
)).unwrap();
assert_eq!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0);
node_tx.send_json_text(json!(
{"id":4, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:38:48.330433+01:00" }
)).unwrap();
assert_eq!(feed_rx.recv_feed_messages_timeout(Duration::from_secs(1)).await.unwrap().len(), 0);
// Tidy up:
server.shutdown().await;
+30 -4
View File
@@ -5,7 +5,7 @@ mod connection;
mod json_message;
mod real_ip;
use std::net::IpAddr;
use std::{collections::HashSet, net::IpAddr};
use aggregator::{Aggregator, FromWebsocket};
use common::node_message;
@@ -32,7 +32,7 @@ struct Opts {
socket: std::net::SocketAddr,
/// The desired log level; one of 'error', 'warn', 'info', 'debug' or 'trace', where
/// 'error' only logs errors and 'trace' logs everything.
#[structopt(required = false, long = "log", default_value = "info")]
#[structopt(long = "log", default_value = "info")]
log_level: log::LevelFilter,
/// Url to the Backend Core endpoint accepting shard connections
#[structopt(
@@ -41,6 +41,13 @@ struct Opts {
default_value = "ws://127.0.0.1:8000/shard_submit/"
)]
core_url: Uri,
/// How many different nodes is a given connection to the /submit endpoint allowed to
/// tell us about before we ignore the rest?
///
/// This is important because without a limit, a single connection could exhaust
/// RAM by suggesting that it accounts for billions of nodes.
#[structopt(long, default_value = "20")]
max_nodes_per_connection: usize
}
#[tokio::main]
@@ -62,7 +69,10 @@ async fn main() {
/// Declare our routes and start the server.
async fn start_server(opts: Opts) -> anyhow::Result<()> {
let aggregator = Aggregator::spawn(opts.core_url).await?;
let server = http_utils::start_server(opts.socket, move |addr, req| {
let socket_addr = opts.socket;
let max_nodes_per_connection = opts.max_nodes_per_connection;
let server = http_utils::start_server(socket_addr, move |addr, req| {
let aggregator = aggregator.clone();
async move {
match (req.method(), req.uri().path().trim_end_matches('/')) {
@@ -76,7 +86,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> {
Ok(http_utils::upgrade_to_websocket(req, move |ws_send, ws_recv| async move {
let tx_to_aggregator = aggregator.subscribe_node();
let (mut tx_to_aggregator, mut ws_send)
= handle_node_websocket_connection(real_addr, ws_send, ws_recv, tx_to_aggregator).await;
= handle_node_websocket_connection(real_addr, ws_send, ws_recv, tx_to_aggregator, max_nodes_per_connection).await;
log::info!("Closing /submit connection from {:?}", addr);
// Tell the aggregator that this connection has closed, so it can tidy up.
let _ = tx_to_aggregator.send(FromWebsocket::Disconnected).await;
@@ -104,10 +114,15 @@ async fn handle_node_websocket_connection<S>(
ws_send: http_utils::WsSender,
mut ws_recv: http_utils::WsReceiver,
mut tx_to_aggregator: S,
max_nodes_per_connection: usize
) -> (S, http_utils::WsSender)
where
S: futures::Sink<FromWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
{
// Track all of the message IDs that we've seen so far. If we exceed the
// max_nodes_per_connection limit we ignore subsequent message IDs.
let mut message_ids_seen = HashSet::new();
// This could be a oneshot channel, but it's useful to be able to clone
// messages, and we can't clone oneshot channel senders.
let (close_connection_tx, mut close_connection_rx) = mpsc::channel(0);
@@ -166,6 +181,17 @@ where
let message_id = node_message.id();
let payload = node_message.into_payload();
// Ignore messages from IDs that exceed our limit:
if message_ids_seen.contains(&message_id) {
// continue on; we're happy
} else if message_ids_seen.len() >= max_nodes_per_connection {
// ignore this message; it's not a "seen" ID and we've hit our limit.
continue;
} else {
// not seen ID, not hit limit; make note of new ID
message_ids_seen.insert(message_id);
}
// Until the aggregator receives an `Add` message, which we can create once
// we see one of these SystemConnected ones, it will ignore messages with
// the corresponding message_id.
+31 -8
View File
@@ -208,12 +208,13 @@ impl FeedReceiver {
/// Prefer [`FeedReceiver::recv_feed_messages`]; tests should generally be
/// robust in assuming that messages may not all be delivered at once (unless we are
/// specifically testing which messages are buffered together).
pub async fn recv_feed_messages_once(&mut self) -> Result<Vec<FeedMessage>, anyhow::Error> {
let msg = self
.0
.next()
.await
.ok_or_else(|| anyhow::anyhow!("Stream closed: no more messages"))??;
pub async fn recv_feed_messages_once_timeout(&mut self, timeout: Duration) -> Result<Vec<FeedMessage>, anyhow::Error> {
let msg = match tokio::time::timeout(timeout, self.0.next()).await {
// Timeout elapsed; no messages back:
Err(_) => return Ok(Vec::new()),
// Something back; complain if it's an error or the stream has closed:
Ok(res) => res.ok_or_else(|| anyhow::anyhow!("Stream closed: no more messages"))??
};
match msg {
ws_client::RecvMessage::Binary(data) => {
@@ -227,11 +228,26 @@ impl FeedReceiver {
}
}
/// Wait for the next set of feed messages to arrive, using a default timeout.
///
/// This simply calls `recv_feed_messages_once_timeout` with a timeout of 30 seconds;
/// see that method for details.
pub async fn recv_feed_messages_once(&mut self) -> Result<Vec<FeedMessage>, anyhow::Error> {
    // A 30 second default timeout means that a test waiting here will eventually end.
    let default_timeout = Duration::from_secs(30);
    self.recv_feed_messages_once_timeout(default_timeout).await
}
/// Wait for feed messages to be sent back, building up a list of output messages until
/// the channel goes quiet for a short while.
pub async fn recv_feed_messages(&mut self) -> Result<Vec<FeedMessage>, anyhow::Error> {
///
/// If no new messages are received within the timeout given, bail with whatever we have so far.
/// This differs from `recv_feed_messages` and `recv_feed_messages_once`, which do not take a
/// timeout from the caller and instead fall back to a default timeout of 30 seconds.
pub async fn recv_feed_messages_timeout(&mut self, timeout: Duration) -> Result<Vec<FeedMessage>, anyhow::Error> {
// Block as long as needed for messages to start coming in:
let mut feed_messages = self.recv_feed_messages_once().await?;
let mut feed_messages = match tokio::time::timeout(timeout, self.recv_feed_messages_once()).await {
Ok(msgs) => msgs?,
Err(_) => return Ok(Vec::new()),
};
// Then, loop a little to make sure we catch any additional messages that are sent soon after:
loop {
match tokio::time::timeout(Duration::from_millis(250), self.recv_feed_messages_once())
@@ -250,4 +266,11 @@ impl FeedReceiver {
}
}
}
/// Wait for feed messages until nothing else arrives in a timely fashion, using a
/// default timeout.
///
/// This simply calls `recv_feed_messages_timeout` with a timeout of 30 seconds;
/// see that method for details.
pub async fn recv_feed_messages(&mut self) -> Result<Vec<FeedMessage>, anyhow::Error> {
    // A 30 second default timeout means that a test waiting here will eventually end.
    let default_timeout = Duration::from_secs(30);
    self.recv_feed_messages_timeout(default_timeout).await
}
}
@@ -1,7 +1,7 @@
use super::commands;
use crate::server::{self, Server, Command};
/// Additional options to pass to the feed command.
/// Additional options to pass to the core command.
pub struct CoreOpts {
pub feed_timeout: Option<u64>
}
@@ -14,6 +14,19 @@ impl Default for CoreOpts {
}
}
/// Additional options to pass to the shard command.
///
/// `ShardOpts::default()` leaves every option unset, so no extra arguments
/// are appended to the shard command.
#[derive(Debug, Clone, Default)]
pub struct ShardOpts {
    /// When `Some`, the value is passed to the shard command via the
    /// `--max-nodes-per-connection` argument. When `None`, the argument is
    /// omitted and the shard binary's own default applies.
    pub max_nodes_per_connection: Option<usize>,
}
/// Start a telemetry server. We'll use `cargo run` by default, but you can also provide
/// env vars to configure the binary that runs for the shard and core process. Either:
///
@@ -31,7 +44,7 @@ impl Default for CoreOpts {
/// - `TELEMETRY_SUBMIT_HOSTS` - hosts (comma separated) to connect to for telemetry `/submit`s.
/// - `TELEMETRY_FEED_HOST` - host to connect to for feeds (eg 127.0.0.1:3000)
///
pub async fn start_server(release_mode: bool, core_opts: CoreOpts) -> Server {
pub async fn start_server(release_mode: bool, core_opts: CoreOpts, shard_opts: ShardOpts) -> Server {
// Start to a single process:
if let Ok(bin) = std::env::var("TELEMETRY_BIN") {
return Server::start(server::StartOpts::SingleProcess {
@@ -52,10 +65,17 @@ pub async fn start_server(release_mode: bool, core_opts: CoreOpts) -> Server {
}
// Build the shard command
let shard_command = std::env::var("TELEMETRY_SHARD_BIN")
let mut shard_command = std::env::var("TELEMETRY_SHARD_BIN")
.map(|val| Command::new(val))
.unwrap_or_else(|_| commands::cargo_run_telemetry_shard(release_mode).expect("must be in rust workspace to run shard command"));
// Append additional opts to the shard command
if let Some(max_nodes_per_connection) = shard_opts.max_nodes_per_connection {
shard_command = shard_command
.arg("--max-nodes-per-connection")
.arg(max_nodes_per_connection.to_string());
}
// Build the core command
let mut core_command = std::env::var("TELEMETRY_CORE_BIN")
.map(|val| Command::new(val))
@@ -77,10 +97,10 @@ pub async fn start_server(release_mode: bool, core_opts: CoreOpts) -> Server {
/// Start a telemetry core server in debug mode. See [`start_server`] for details.
pub async fn start_server_debug() -> Server {
start_server(false, CoreOpts::default()).await
start_server(false, CoreOpts::default(), ShardOpts::default()).await
}
/// Start a telemetry core server in release mode. See [`start_server`] for details.
pub async fn start_server_release() -> Server {
start_server(true, CoreOpts::default()).await
start_server(true, CoreOpts::default(), ShardOpts::default()).await
}