From aa63e2960c5f2e206c2381b962db6ec1171a7728 Mon Sep 17 00:00:00 2001 From: James Wilson Date: Wed, 28 Jul 2021 16:43:51 +0100 Subject: [PATCH] block addresses from reconnecting if they send too much data --- backend/telemetry_shard/src/blocked_addrs.rs | 50 ++++++++++++++++++++ backend/telemetry_shard/src/main.rs | 19 ++++++++ 2 files changed, 69 insertions(+) create mode 100644 backend/telemetry_shard/src/blocked_addrs.rs diff --git a/backend/telemetry_shard/src/blocked_addrs.rs b/backend/telemetry_shard/src/blocked_addrs.rs new file mode 100644 index 0000000..93e92b3 --- /dev/null +++ b/backend/telemetry_shard/src/blocked_addrs.rs @@ -0,0 +1,50 @@ +use std::collections::HashMap; +use std::time::{Duration, Instant}; +use std::net::IpAddr; +use std::sync::{ Mutex, Arc }; + +/// Keep track of nodes that have been blocked. +#[derive(Debug, Clone)] +pub struct BlockedAddrs(Arc); + +#[derive(Debug)] +struct BlockAddrsInner { + block_duration: Duration, + inner: Mutex> +} + +impl BlockedAddrs { + /// Create a new block list. Nodes are blocked for the duration + /// provided here. + pub fn new(block_duration: Duration) -> BlockedAddrs { + BlockedAddrs(Arc::new(BlockAddrsInner { + block_duration, + inner: Mutex::new(HashMap::new()) + })) + } + + /// Block a new address + pub fn block_addr(&self, addr: IpAddr, reason: &'static str) { + let now = Instant::now(); + self.0.inner.lock().unwrap().insert(addr, (reason, now)); + } + + /// Find out whether an address has been blocked. If it has, a reason + /// will be returned. Else, we'll get None back. This function may also + /// perform cleanup if the item was blocked and the block has expired. + pub fn blocked_reason(&self, addr: &IpAddr) -> Option<&'static str> { + let mut map = self.0.inner.lock().unwrap(); + + let (reason, time) = match map.get(addr) { + Some(&(reason,time)) => (reason, time), + None => return None + }; + + if time + self.0.block_duration < Instant::now() { + map.remove(addr); + None + } else { + Some(reason) + } + } +} \ No newline at end of file diff --git a/backend/telemetry_shard/src/main.rs b/backend/telemetry_shard/src/main.rs index 825c3e4..7c403c2 100644 --- a/backend/telemetry_shard/src/main.rs +++ b/backend/telemetry_shard/src/main.rs @@ -3,6 +3,7 @@ mod aggregator; mod connection; mod json_message; mod real_ip; +mod blocked_addrs; use std::{collections::HashSet, net::IpAddr, time::Duration}; @@ -16,6 +17,7 @@ use http::Uri; use hyper::{Method, Response}; use simple_logger::SimpleLogger; use structopt::StructOpt; +use blocked_addrs::BlockedAddrs; const VERSION: &str = env!("CARGO_PKG_VERSION"); const AUTHORS: &str = env!("CARGO_PKG_AUTHORS"); @@ -55,6 +57,10 @@ struct Opts { /// the average traffic in the last 10 seconds falls below this value. #[structopt(long, default_value = "512k")] max_node_data_per_second: ByteSize, + /// How many seconds is a "/feed" connection that violates the '--max-node-data-per-second' + /// value prevented from reconnecting to this shard for, in seconds. + #[structopt(long, default_value = "600")] + node_block_seconds: u64 } #[tokio::main] @@ -75,6 +81,7 @@ async fn main() { /// Declare our routes and start the server. async fn start_server(opts: Opts) -> anyhow::Result<()> { + let block_list = BlockedAddrs::new(Duration::from_secs(opts.node_block_seconds)); let aggregator = Aggregator::spawn(opts.core_url).await?; let socket_addr = opts.socket; let max_nodes_per_connection = opts.max_nodes_per_connection; @@ -82,6 +89,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> { let server = http_utils::start_server(socket_addr, move |addr, req| { let aggregator = aggregator.clone(); + let block_list = block_list.clone(); async move { match (req.method(), req.uri().path().trim_end_matches('/')) { // Check that the server is up and running: @@ -89,6 +97,14 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> { // Nodes send messages here: (&Method::GET, "/submit") => { let real_addr = real_ip::real_ip(addr, req.headers()); + + if let Some(reason) = block_list.blocked_reason(&real_addr) { + return Ok(Response::builder() + .status(403) + .body(reason.into()) + .unwrap()) + } + Ok(http_utils::upgrade_to_websocket( req, move |ws_send, ws_recv| async move { @@ -101,6 +117,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> { tx_to_aggregator, max_nodes_per_connection, bytes_per_second, + block_list ) .await; log::info!("Closing /submit connection from {:?}", addr); @@ -131,6 +148,7 @@ async fn handle_node_websocket_connection( mut tx_to_aggregator: S, max_nodes_per_connection: usize, bytes_per_second: ByteSize, + block_list: BlockedAddrs ) -> (S, http_utils::WsSender) where S: futures::Sink + Unpin + Send + 'static, @@ -188,6 +206,7 @@ where rolling_total_bytes.push(bytes.len()); let this_bytes_per_second = rolling_total_bytes.total() / 10; if this_bytes_per_second > bytes_per_second { + block_list.block_addr(real_addr, "Too much traffic"); log::error!("Shutting down websocket connection: Too much traffic ({}bps)", this_bytes_per_second); break; }