Add rolling total and allow control over bytes per second allowed from node connections

This commit is contained in:
James Wilson
2021-07-28 16:08:02 +01:00
parent 9ec48adcaa
commit 83d31ef0b3
5 changed files with 382 additions and 2 deletions
+94
View File
@@ -0,0 +1,94 @@
use anyhow::{ anyhow, Error };
/// A number of bytes, wrapped in its own type so that it can be given
/// a human-friendly string representation (see the `FromStr` impl).
#[derive(Clone, Copy, Debug)]
pub struct ByteSize(usize);

impl ByteSize {
    /// Wrap a raw byte count.
    pub fn new(bytes: usize) -> ByteSize {
        Self(bytes)
    }

    /// Unwrap back into the raw byte count.
    pub fn into_bytes(self) -> usize {
        let ByteSize(bytes) = self;
        bytes
    }
}

impl From<ByteSize> for usize {
    /// Equivalent to [`ByteSize::into_bytes`].
    fn from(b: ByteSize) -> Self {
        b.into_bytes()
    }
}
impl std::str::FromStr for ByteSize {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s.trim();
match s.find(|c| !char::is_ascii_digit(&c)) {
// No non-numeric chars; assume bytes then
None => Ok(ByteSize(s.parse().expect("all ascii digits"))),
// First non-numeric char
Some(idx) => {
let n = s[..idx].parse().expect("all ascii digits");
let suffix = s[idx..].trim();
let n = match suffix {
"B" | "b" => n,
"kB" | "K" | "k" => n * 1000,
"MB" | "M" | "m" => n * 1000 * 1000,
"GB" | "G" | "g" => n * 1000 * 1000 * 1000,
"KiB" | "Ki" => n * 1024,
"MiB" | "Mi" => n * 1024 * 1024,
"GiB" | "Gi" => n * 1024 * 1024 * 1024,
_ => return Err(anyhow!("\
Cannot parse into bytes; suffix is '{}', but expecting one of \
B,b, kB,K,k, MB,M,m, GB,G,g, KiB,Ki, MiB,Mi, GiB,Gi", suffix))
};
Ok(ByteSize(n))
}
}
}
}
#[cfg(test)]
mod test {
    use crate::byte_size::ByteSize;

    /// Every supported suffix (and whitespace placement) parses to the
    /// expected number of bytes.
    #[test]
    fn can_parse_valid_strings() {
        // (input, expected number of bytes)
        let cases: &[(&str, usize)] = &[
            // Plain bytes:
            ("100", 100),
            ("100B", 100),
            ("100b", 100),
            // Decimal (power of 1000) units:
            ("20kB", 20 * 1000),
            ("20 kB", 20 * 1000),
            ("20K", 20 * 1000),
            (" 20k", 20 * 1000),
            ("1MB", 1 * 1000 * 1000),
            ("1M", 1 * 1000 * 1000),
            ("1m", 1 * 1000 * 1000),
            ("1 m", 1 * 1000 * 1000),
            ("1GB", 1 * 1000 * 1000 * 1000),
            ("1G", 1 * 1000 * 1000 * 1000),
            ("1g", 1 * 1000 * 1000 * 1000),
            // Binary (power of 1024) units:
            ("1KiB", 1 * 1024),
            ("1Ki", 1 * 1024),
            ("1MiB", 1 * 1024 * 1024),
            ("1Mi", 1 * 1024 * 1024),
            ("1GiB", 1 * 1024 * 1024 * 1024),
            ("1Gi", 1 * 1024 * 1024 * 1024),
            (" 1 Gi ", 1 * 1024 * 1024 * 1024),
        ];

        for &(input, expected) in cases {
            let parsed = input.parse::<ByteSize>().unwrap();
            assert_eq!(parsed.into_bytes(), expected);
        }
    }
}
+2
View File
@@ -6,6 +6,8 @@ pub mod node_types;
pub mod ready_chunks_all;
pub mod time;
pub mod ws_client;
pub mod rolling_total;
pub mod byte_size;
mod assign_id;
mod dense_map;
+257
View File
@@ -0,0 +1,257 @@
use std::time::{ Duration, Instant };
use num_traits::{ Zero, SaturatingAdd, SaturatingSub };
use std::collections::VecDeque;
/// Builder for a [`RollingTotal`], which keeps track of the sum of the values
/// pushed to it within a sliding window of time.
///
/// The window is divided into buckets that are each one `granularity` wide, and
/// the window spans `window_size_multiple` of them. The resulting [`RollingTotal`]
/// caches its running total, so reading it does not require summing the buckets.
pub struct RollingTotalBuilder<Time: TimeSource = SystemTimeSource> {
// How many granularity-sized steps make up the window.
window_size_multiple: usize,
// The width of a single bucket of time.
granularity: Duration,
// Where the current time is read from.
time_source: Time
}
impl RollingTotalBuilder {
    /// Begin configuring a [`RollingTotal`]. Unless overridden, the
    /// granularity is 1 second, the window covers 10 of those (ie 10
    /// seconds), and time is read from the system clock.
    pub fn new() -> RollingTotalBuilder<SystemTimeSource> {
        RollingTotalBuilder {
            time_source: SystemTimeSource,
            granularity: Duration::from_secs(1),
            window_size_multiple: 10,
        }
    }

    /// Swap in a different source of time. System time is used by default.
    pub fn time_source<Time: TimeSource>(self, val: Time) -> RollingTotalBuilder<Time> {
        let RollingTotalBuilder {
            window_size_multiple,
            granularity,
            time_source: _,
        } = self;

        RollingTotalBuilder {
            window_size_multiple,
            granularity,
            time_source: val,
        }
    }

    /// Set how far back in time we look when summing values to produce
    /// the current total. This is expressed as a multiple of the
    /// granularity; with a granularity of 1s, a value of 10 here gives
    /// a 10 second window.
    pub fn window_size_multiple(self, val: usize) -> Self {
        Self {
            window_size_multiple: val,
            ..self
        }
    }

    /// Set the granularity of the window, ie the size of the steps in which
    /// it moves forward. With a granularity of 5 seconds, the window shifts
    /// forward every 5 seconds to take in the next 5 seconds worth of data.
    /// Larger granularities are more efficient but less accurate than
    /// smaller ones.
    pub fn granularity(self, val: Duration) -> Self {
        Self {
            granularity: val,
            ..self
        }
    }
}
impl <Time: TimeSource> RollingTotalBuilder<Time> {
/// Create a [`RollingTotal`] with these setings, starting from the
/// instant provided.
pub fn start<T>(self) -> RollingTotal<T, Time>
where T: Zero + SaturatingAdd + SaturatingSub
{
let mut averages = VecDeque::new();
averages.push_back((self.time_source.now(), T::zero()));
RollingTotal {
window_size_multiple: self.window_size_multiple,
time_source: self.time_source,
granularity: self.granularity,
averages,
total: T::zero()
}
}
}
/// Keeps a running total of the values pushed to it within a sliding window
/// of time. Create one via [`RollingTotalBuilder`].
///
/// Values are grouped into buckets that are `granularity` wide; buckets that
/// fall out of the window are discarded (and subtracted from the total) when
/// a later value is pushed.
pub struct RollingTotal<Val, Time = SystemTimeSource> {
// How many granularity-sized steps make up the window.
window_size_multiple: usize,
// Where the current time is read from on each push.
time_source: Time,
// The width of a single bucket of time.
granularity: Duration,
// One entry per bucket: the bucket's start time and the sum of the values
// pushed into it (despite the field name, these are sums, not averages).
// Never empty.
averages: VecDeque<(Instant, Val)>,
// Cached sum of every bucket currently held in `averages`.
total: Val
}
impl<Val, Time> RollingTotal<Val, Time>
where
    Val: SaturatingAdd + SaturatingSub + Copy + std::fmt::Debug,
    Time: TimeSource,
{
    /// Add a new value, timestamped with the current time from the time source.
    ///
    /// If the current time still falls within the most recent bucket, the
    /// value is added to that bucket. Otherwise a new bucket is created, a
    /// whole number of granularity steps after the previous one, and any
    /// buckets that are now outside of the window are discarded.
    ///
    /// This never panics on odd configuration or clock behaviour:
    /// - a time source that goes backwards is treated as if no time passed,
    /// - a zero granularity is treated as the smallest possible step,
    /// - a `window_size_multiple` of 0 behaves like 1 (only the newest bucket
    ///   is kept),
    /// - a window too large to represent simply never evicts anything.
    pub fn push(&mut self, value: Val) {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        let time = self.time_source.now();
        let last_time = self.averages.back().expect("always 1 value").0;

        // Saturates to zero if the time source went backwards, rather than
        // relying on toolchain-specific behaviour of `duration_since`.
        let since_last_nanos = time.saturating_duration_since(last_time).as_nanos();
        // Avoid dividing by zero below if a zero granularity was configured.
        let granularity_nanos = self.granularity.as_nanos().max(1);

        if since_last_nanos >= granularity_nanos {
            // New time doesn't fit into the last bucket. Round the elapsed time
            // down to a whole number of granularity steps (eg with a granularity
            // of 5, an elapsed time of 7 becomes 5) to find the start of the
            // bucket that this value belongs in. Doing this in u128 nanoseconds
            // avoids truncating the step count or the granularity.
            let aligned_nanos = since_last_nanos - (since_last_nanos % granularity_nanos);
            // Lossless: `aligned_nanos` is no larger than an actual `Duration`.
            let aligned = Duration::new(
                (aligned_nanos / NANOS_PER_SEC) as u64,
                (aligned_nanos % NANOS_PER_SEC) as u32,
            );
            // Cannot overflow: `new_time` is never later than `time`.
            let new_time = last_time + aligned;

            self.total = self.total.saturating_add(&value);
            self.averages.push_back((new_time, value));

            // Remove any old buckets no longer within our window. If
            // window_size_multiple is 1, then we only keep the just-pushed
            // bucket, hence the "- 1". If the start of the window can't be
            // represented (the span overflows, or it would be earlier than the
            // earliest `Instant` the platform supports), nothing can be old
            // enough to evict.
            let older_buckets = self
                .window_size_multiple
                .saturating_sub(1)
                .min(u32::MAX as usize) as u32;
            let oldest_time_in_window = self
                .granularity
                .checked_mul(older_buckets)
                .and_then(|span| new_time.checked_sub(span));

            if let Some(oldest_time_in_window) = oldest_time_in_window {
                // The just-pushed bucket is never older than the window start,
                // so this always leaves at least one bucket behind.
                while let Some(&(bucket_time, bucket_value)) = self.averages.front() {
                    if bucket_time >= oldest_time_in_window {
                        break;
                    }
                    self.averages.pop_front();
                    // Keep the cached total in step with the buckets we hold.
                    self.total = self.total.saturating_sub(&bucket_value);
                }
            }
        } else {
            // New time fits into our last bucket, so just add it on. We don't
            // need to worry about bucket cleanup since the number/times of
            // buckets hasn't changed.
            let (_, last_val) = self.averages.back_mut().expect("always 1 value");
            *last_val = last_val.saturating_add(&value);
            self.total = self.total.saturating_add(&value);
        }
    }

    /// Fetch the current rolling total that we've accumulated. Note that this
    /// is based on the last seen times and values, and is not affected by the
    /// time at which it is called; old values only drop out of the total when
    /// a newer value is pushed.
    pub fn total(&self) -> Val {
        self.total
    }

    /// Fetch the current time source, in case we need to modify it.
    pub fn time_source(&mut self) -> &mut Time {
        &mut self.time_source
    }

    /// The buckets (start time and summed value) currently being tracked.
    #[cfg(test)]
    pub fn averages(&self) -> &VecDeque<(Instant, Val)> {
        &self.averages
    }
}
/// Something that can tell a rolling total what the current time is.
/// Reading the time from here, rather than passing it in alongside each
/// pushed value, keeps time out of the `push` call and makes it a little
/// harder to accidentally hand over an older time and cause a panic.
pub trait TimeSource {
    /// The current time according to this source.
    fn now(&self) -> Instant;
}

/// A [`TimeSource`] backed by the system clock.
pub struct SystemTimeSource;

impl TimeSource for SystemTimeSource {
    fn now(&self) -> Instant {
        Instant::now()
    }
}

/// A [`TimeSource`] whose time only changes when explicitly told to;
/// handy for driving a rolling total by hand.
pub struct UserTimeSource(Instant);

impl UserTimeSource {
    /// Create a time source that reports `time` until it is changed.
    pub fn new(time: Instant) -> Self {
        Self(time)
    }

    /// Replace the reported time with `time`.
    pub fn set_time(&mut self, time: Instant) {
        let UserTimeSource(current) = self;
        *current = time;
    }

    /// Move the reported time forward by `duration`.
    pub fn increment_by(&mut self, duration: Duration) {
        let UserTimeSource(current) = self;
        *current += duration;
    }
}

impl TimeSource for UserTimeSource {
    fn now(&self) -> Instant {
        let UserTimeSource(current) = self;
        *current
    }
}
#[cfg(test)]
mod test {
    use super::*;

    type ManualRollingTotal = RollingTotal<i32, UserTimeSource>;

    /// A rolling total over a window of 10 x 1 second, whose clock only
    /// moves when we tell it to.
    fn manual_rolling_total() -> ManualRollingTotal {
        let start_time = Instant::now();
        RollingTotalBuilder::new()
            .granularity(Duration::from_secs(1))
            .window_size_multiple(10)
            .time_source(UserTimeSource(start_time))
            .start()
    }

    /// Move the clock forward by `delay` and then push `value`.
    fn push_after(rolling_total: &mut ManualRollingTotal, delay: Duration, value: i32) {
        rolling_total.time_source().increment_by(delay);
        rolling_total.push(value);
    }

    #[test]
    fn gets_correct_total_within_granularity() {
        let mut rolling_total = manual_rolling_total();

        // Three pushes, all within the first second, share one bucket.
        for &value in &[1, 10, -5] {
            push_after(&mut rolling_total, Duration::from_millis(300), value);
        }

        assert_eq!(rolling_total.total(), 6);
        assert_eq!(rolling_total.averages().len(), 1);
    }

    #[test]
    fn gets_correct_total_within_window() {
        let mut rolling_total = manual_rolling_total();

        // No time has passed, so this lands in the initial bucket.
        rolling_total.push(4);
        assert_eq!(rolling_total.averages().len(), 1);
        assert_eq!(rolling_total.total(), 4);

        push_after(&mut rolling_total, Duration::from_secs(3), 1);
        assert_eq!(rolling_total.averages().len(), 2);
        assert_eq!(rolling_total.total(), 5);

        push_after(&mut rolling_total, Duration::from_secs(1), 10);
        assert_eq!(rolling_total.averages().len(), 3);
        assert_eq!(rolling_total.total(), 15);

        // Jump precisely to the end of the window. Now, pushing a
        // value will displace the first one (4). Note: if no value
        // is pushed, this time change will have no effect.
        push_after(&mut rolling_total, Duration::from_secs(8), 20);
        assert_eq!(rolling_total.averages().len(), 3);
        assert_eq!(rolling_total.total(), 15 + 20 - 4);

        // Jump so that only the last value is still within the window:
        push_after(&mut rolling_total, Duration::from_secs(9), 1);
        assert_eq!(rolling_total.averages().len(), 2);
        assert_eq!(rolling_total.total(), 21);

        // Jump so that everything is out of scope (just about!):
        push_after(&mut rolling_total, Duration::from_secs(10), 1);
        assert_eq!(rolling_total.averages().len(), 1);
        assert_eq!(rolling_total.total(), 1);
    }
}
+1 -1
View File
@@ -26,7 +26,7 @@ const ABOUT: &str = "This is the Telemetry Backend Core that receives telemetry
#[derive(StructOpt, Debug)]
#[structopt(name = NAME, version = VERSION, author = AUTHORS, about = ABOUT)]
struct Opts {
/// This is the socket address that Telemetryis listening to. This is restricted to
/// This is the socket address that Telemetry is listening to. This is restricted to
/// localhost (127.0.0.1) by default and should be fine for most use cases. If
/// you are using Telemetry in a container, you likely want to set this to '0.0.0.0:8000'
#[structopt(short = "l", long = "listen", default_value = "127.0.0.1:8000")]
+28 -1
View File
@@ -4,11 +4,13 @@ mod connection;
mod json_message;
mod real_ip;
use std::{collections::HashSet, net::IpAddr};
use std::{collections::HashSet, net::IpAddr, time::Duration};
use aggregator::{Aggregator, FromWebsocket};
use common::http_utils;
use common::node_message;
use common::byte_size::ByteSize;
use common::rolling_total::RollingTotalBuilder;
use futures::{channel::mpsc, SinkExt, StreamExt};
use http::Uri;
use hyper::{Method, Response};
@@ -47,6 +49,12 @@ struct Opts {
/// RAM by suggesting that it accounts for billions of nodes.
#[structopt(long, default_value = "20")]
max_nodes_per_connection: usize,
/// What is the maximum number of bytes per second, on average, that a connection from a
/// node is allowed to send to a shard before it gets booted. This is averaged over a
/// rolling window of 10 seconds, and so spikes beyond this limit are allowed as long as
/// the average traffic in the last 10 seconds falls below this value.
#[structopt(long, default_value = "512k")]
max_node_data_per_second: ByteSize
}
#[tokio::main]
@@ -70,6 +78,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> {
let aggregator = Aggregator::spawn(opts.core_url).await?;
let socket_addr = opts.socket;
let max_nodes_per_connection = opts.max_nodes_per_connection;
let bytes_per_second = opts.max_node_data_per_second;
let server = http_utils::start_server(socket_addr, move |addr, req| {
let aggregator = aggregator.clone();
@@ -91,6 +100,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> {
ws_recv,
tx_to_aggregator,
max_nodes_per_connection,
bytes_per_second
)
.await;
log::info!("Closing /submit connection from {:?}", addr);
@@ -120,10 +130,19 @@ async fn handle_node_websocket_connection<S>(
mut ws_recv: http_utils::WsReceiver,
mut tx_to_aggregator: S,
max_nodes_per_connection: usize,
bytes_per_second: ByteSize
) -> (S, http_utils::WsSender)
where
S: futures::Sink<FromWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
{
// Limit the number of bytes based on a rolling total and the incoming bytes per second
// that has been configured via the CLI opts.
let bytes_per_second = bytes_per_second.into_bytes();
let mut rolling_total_bytes = RollingTotalBuilder::new()
.granularity(Duration::from_secs(1))
.window_size_multiple(10)
.start();
// Track all of the message IDs that we've seen so far. If we exceed the
// max_nodes_per_connection limit we ignore subsequent message IDs.
let mut message_ids_seen = HashSet::new();
@@ -165,6 +184,14 @@ where
break;
}
// Keep track of total bytes and bail if average over last 10 secs exceeds preference.
rolling_total_bytes.push(bytes.len());
let this_bytes_per_second = rolling_total_bytes.total() / 10;
if this_bytes_per_second > bytes_per_second {
log::error!("Shutting down websocket connection: Too much traffic ({}bps)", this_bytes_per_second);
break;
}
// Deserialize from JSON, warning in debug mode if deserialization fails:
let node_message: json_message::NodeMessage = match serde_json::from_slice(&bytes) {
Ok(node_message) => node_message,