mirror of
https://github.com/pezkuwichain/pezkuwi-telemetry.git
synced 2026-06-09 20:21:01 +00:00
Add rolling total and allow control over bytes per second allowed from node connections
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
use anyhow::{ anyhow, Error };
|
||||
|
||||
#[derive(Copy,Clone,Debug)]
|
||||
pub struct ByteSize(usize);
|
||||
|
||||
impl ByteSize {
|
||||
pub fn new(bytes: usize) -> ByteSize {
|
||||
ByteSize(bytes)
|
||||
}
|
||||
pub fn into_bytes(self) -> usize {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ByteSize> for usize {
|
||||
fn from(b: ByteSize) -> Self {
|
||||
b.0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for ByteSize {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let s = s.trim();
|
||||
match s.find(|c| !char::is_ascii_digit(&c)) {
|
||||
// No non-numeric chars; assume bytes then
|
||||
None => Ok(ByteSize(s.parse().expect("all ascii digits"))),
|
||||
// First non-numeric char
|
||||
Some(idx) => {
|
||||
let n = s[..idx].parse().expect("all ascii digits");
|
||||
let suffix = s[idx..].trim();
|
||||
let n = match suffix {
|
||||
"B" | "b" => n,
|
||||
"kB" | "K" | "k" => n * 1000,
|
||||
"MB" | "M" | "m" => n * 1000 * 1000,
|
||||
"GB" | "G" | "g" => n * 1000 * 1000 * 1000,
|
||||
"KiB" | "Ki" => n * 1024,
|
||||
"MiB" | "Mi" => n * 1024 * 1024,
|
||||
"GiB" | "Gi" => n * 1024 * 1024 * 1024,
|
||||
_ => return Err(anyhow!("\
|
||||
Cannot parse into bytes; suffix is '{}', but expecting one of \
|
||||
B,b, kB,K,k, MB,M,m, GB,G,g, KiB,Ki, MiB,Mi, GiB,Gi", suffix))
|
||||
};
|
||||
Ok(ByteSize(n))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::byte_size::ByteSize;
|
||||
|
||||
|
||||
#[test]
|
||||
fn can_parse_valid_strings() {
|
||||
let cases = vec![
|
||||
("100", 100),
|
||||
("100B", 100),
|
||||
("100b", 100),
|
||||
|
||||
("20kB", 20 * 1000),
|
||||
("20 kB", 20 * 1000),
|
||||
("20K", 20 * 1000),
|
||||
(" 20k", 20 * 1000),
|
||||
|
||||
("1MB", 1 * 1000 * 1000),
|
||||
("1M", 1 * 1000 * 1000),
|
||||
("1m", 1 * 1000 * 1000),
|
||||
("1 m", 1 * 1000 * 1000),
|
||||
|
||||
("1GB", 1 * 1000 * 1000 * 1000),
|
||||
("1G", 1 * 1000 * 1000 * 1000),
|
||||
("1g", 1 * 1000 * 1000 * 1000),
|
||||
|
||||
("1KiB", 1 * 1024),
|
||||
("1Ki", 1 * 1024),
|
||||
|
||||
("1MiB", 1 * 1024 * 1024),
|
||||
("1Mi", 1 * 1024 * 1024),
|
||||
|
||||
("1GiB", 1 * 1024 * 1024 * 1024),
|
||||
("1Gi", 1 * 1024 * 1024 * 1024),
|
||||
(" 1 Gi ", 1 * 1024 * 1024 * 1024),
|
||||
];
|
||||
|
||||
for (s, expected) in cases {
|
||||
let b: ByteSize = s.parse().unwrap();
|
||||
assert_eq!(b.into_bytes(), expected);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -6,6 +6,8 @@ pub mod node_types;
|
||||
pub mod ready_chunks_all;
|
||||
pub mod time;
|
||||
pub mod ws_client;
|
||||
pub mod rolling_total;
|
||||
pub mod byte_size;
|
||||
|
||||
mod assign_id;
|
||||
mod dense_map;
|
||||
|
||||
@@ -0,0 +1,257 @@
|
||||
use std::time::{ Duration, Instant };
|
||||
use num_traits::{ Zero, SaturatingAdd, SaturatingSub };
|
||||
use std::collections::VecDeque;
|
||||
|
||||
/// Build an object responsible for keeping track of a rolling total.
|
||||
/// It does this in constant time and using memory proportional to the
|
||||
/// granularity * window size multiple that we set.
|
||||
pub struct RollingTotalBuilder<Time: TimeSource = SystemTimeSource> {
|
||||
window_size_multiple: usize,
|
||||
granularity: Duration,
|
||||
time_source: Time
|
||||
}
|
||||
|
||||
impl RollingTotalBuilder {
|
||||
/// Build a [`RollingTotal`] struct. By default,
|
||||
/// the window_size is 10s, the granularity is 1s,
|
||||
/// and system time is used.
|
||||
pub fn new() -> RollingTotalBuilder<SystemTimeSource> {
|
||||
Self {
|
||||
window_size_multiple: 10,
|
||||
granularity: Duration::from_secs(1),
|
||||
time_source: SystemTimeSource
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the source of time we'll use. By default, we use system time.
|
||||
pub fn time_source<Time: TimeSource>(self, val: Time) -> RollingTotalBuilder<Time> {
|
||||
RollingTotalBuilder {
|
||||
window_size_multiple: self.window_size_multiple,
|
||||
granularity: self.granularity,
|
||||
time_source: val
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the size of the window of time that we'll look back on
|
||||
/// to sum up values over to give us the current total. The size
|
||||
/// is set as a multiple of the granularity; a granulatiry of 1s
|
||||
/// and a size of 10 means the window size will be 10 seconds.
|
||||
pub fn window_size_multiple(mut self, val: usize) -> Self {
|
||||
self.window_size_multiple = val;
|
||||
self
|
||||
}
|
||||
|
||||
/// What is the granulatiry of our windows of time. For example, a
|
||||
/// granularity of 5 seconds means that every 5 seconds the window
|
||||
/// that we look at shifts forward to the next 5 seconds worth of data.
|
||||
/// A larger granularity is more efficient but less accurate than a
|
||||
/// smaller one.
|
||||
pub fn granularity(mut self, val: Duration) -> Self {
|
||||
self.granularity = val;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl <Time: TimeSource> RollingTotalBuilder<Time> {
|
||||
/// Create a [`RollingTotal`] with these setings, starting from the
|
||||
/// instant provided.
|
||||
pub fn start<T>(self) -> RollingTotal<T, Time>
|
||||
where T: Zero + SaturatingAdd + SaturatingSub
|
||||
{
|
||||
let mut averages = VecDeque::new();
|
||||
averages.push_back((self.time_source.now(), T::zero()));
|
||||
|
||||
RollingTotal {
|
||||
window_size_multiple: self.window_size_multiple,
|
||||
time_source: self.time_source,
|
||||
granularity: self.granularity,
|
||||
averages,
|
||||
total: T::zero()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RollingTotal<Val, Time = SystemTimeSource> {
|
||||
window_size_multiple: usize,
|
||||
time_source: Time,
|
||||
granularity: Duration,
|
||||
averages: VecDeque<(Instant, Val)>,
|
||||
total: Val
|
||||
}
|
||||
|
||||
impl <Val, Time: TimeSource> RollingTotal<Val, Time>
|
||||
where
|
||||
Val: SaturatingAdd + SaturatingSub + Copy + std::fmt::Debug,
|
||||
Time: TimeSource
|
||||
{
|
||||
|
||||
/// Add a new value at some time.
|
||||
pub fn push(&mut self, value: Val) {
|
||||
let time = self.time_source.now();
|
||||
let (last_time, last_val) = self.averages
|
||||
.back_mut()
|
||||
.expect("always 1 value");
|
||||
|
||||
let since_last_nanos = time.duration_since(*last_time).as_nanos();
|
||||
let granularity_nanos = self.granularity.as_nanos();
|
||||
|
||||
if since_last_nanos >= granularity_nanos {
|
||||
// New time doesn't fit into last bucket; create a new bucket with a time
|
||||
// that is some number of granularity steps from the last, and add the
|
||||
// value to that.
|
||||
|
||||
// This rounds down, eg 7 / 5 = 1. Find the number of granularity steps
|
||||
// to jump from the last time such that the jump can fit this new value.
|
||||
let steps = since_last_nanos / granularity_nanos;
|
||||
|
||||
// Create a new time this number of jumps forward, and push it.
|
||||
let new_time = *last_time + Duration::from_nanos(granularity_nanos as u64) * steps as u32;
|
||||
self.total = self.total.saturating_add(&value);
|
||||
self.averages.push_back((new_time, value));
|
||||
|
||||
// Remove any old times/values no longer within our window size. If window_size_multiple
|
||||
// is 1, then we only keep the just-pushed time, hence the "-1". Remember to keep our
|
||||
// cached total up to date if we remove things.
|
||||
let oldest_time_in_window = new_time - (self.granularity * (self.window_size_multiple - 1) as u32);
|
||||
while self.averages.front().expect("always 1 value").0 < oldest_time_in_window {
|
||||
let value = self.averages.pop_front().expect("always 1 value").1;
|
||||
self.total = self.total.saturating_sub(&value);
|
||||
}
|
||||
} else {
|
||||
// New time fits into our last bucket, so just add it on. We don't need to worry
|
||||
// about bucket cleanup since number/times of buckets hasn't changed.
|
||||
*last_val = last_val.saturating_add(&value);
|
||||
self.total = self.total.saturating_add(&value);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Fetch the current rolling total that we've accumulated. Note that this
|
||||
/// is based on the last seen times and values, and is not affected by the time
|
||||
/// that it is called.
|
||||
pub fn total(&self) -> Val {
|
||||
self.total
|
||||
}
|
||||
|
||||
/// Fetch the current time source, incase we need to modify it.
|
||||
pub fn time_source(&mut self) -> &mut Time {
|
||||
&mut self.time_source
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn averages(&self) -> &VecDeque<(Instant, Val)> {
|
||||
&self.averages
|
||||
}
|
||||
}
|
||||
|
||||
/// A source of time that we can use in our rolling total.
|
||||
/// This allows us to avoid explicitly mentioning time when pushing
|
||||
/// new values, and makes it a little harder to accidentally pass
|
||||
/// an older time and cause a panic.
|
||||
pub trait TimeSource {
|
||||
fn now(&self) -> Instant;
|
||||
}
|
||||
|
||||
pub struct SystemTimeSource;
|
||||
impl TimeSource for SystemTimeSource {
|
||||
fn now(&self) -> Instant {
|
||||
Instant::now()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UserTimeSource(Instant);
|
||||
impl UserTimeSource {
|
||||
pub fn new(time: Instant) -> Self {
|
||||
UserTimeSource(time)
|
||||
}
|
||||
pub fn set_time(&mut self, time: Instant) {
|
||||
self.0 = time;
|
||||
}
|
||||
pub fn increment_by(&mut self, duration: Duration) {
|
||||
self.0 += duration;
|
||||
}
|
||||
}
|
||||
impl TimeSource for UserTimeSource {
|
||||
fn now(&self) -> Instant {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn gets_correct_total_within_granularity() {
|
||||
let start_time = Instant::now();
|
||||
let mut rolling_total = RollingTotalBuilder::new()
|
||||
.granularity(Duration::from_secs(1))
|
||||
.window_size_multiple(10)
|
||||
.time_source(UserTimeSource(start_time))
|
||||
.start();
|
||||
|
||||
rolling_total.time_source().increment_by(Duration::from_millis(300));
|
||||
rolling_total.push(1);
|
||||
|
||||
rolling_total.time_source().increment_by(Duration::from_millis(300));
|
||||
rolling_total.push(10);
|
||||
|
||||
rolling_total.time_source().increment_by(Duration::from_millis(300));
|
||||
rolling_total.push(-5);
|
||||
|
||||
assert_eq!(rolling_total.total(), 6);
|
||||
assert_eq!(rolling_total.averages().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gets_correct_total_within_window() {
|
||||
let start_time = Instant::now();
|
||||
let mut rolling_total = RollingTotalBuilder::new()
|
||||
.granularity(Duration::from_secs(1))
|
||||
.window_size_multiple(10)
|
||||
.time_source(UserTimeSource(start_time))
|
||||
.start();
|
||||
|
||||
rolling_total.push(4);
|
||||
|
||||
assert_eq!(rolling_total.averages().len(), 1);
|
||||
assert_eq!(rolling_total.total(), 4);
|
||||
|
||||
rolling_total.time_source().increment_by(Duration::from_secs(3));
|
||||
rolling_total.push(1);
|
||||
|
||||
assert_eq!(rolling_total.averages().len(), 2);
|
||||
assert_eq!(rolling_total.total(), 5);
|
||||
|
||||
rolling_total.time_source().increment_by(Duration::from_secs(1));
|
||||
rolling_total.push(10);
|
||||
|
||||
assert_eq!(rolling_total.averages().len(), 3);
|
||||
assert_eq!(rolling_total.total(), 15);
|
||||
|
||||
// Jump precisely to the end of the window. Now, pushing a
|
||||
// value will displace the first one (4). Note: if no value
|
||||
// is pushed, this time change will have no effect.
|
||||
rolling_total.time_source().increment_by(Duration::from_secs(8));
|
||||
rolling_total.push(20);
|
||||
|
||||
assert_eq!(rolling_total.averages().len(), 3);
|
||||
assert_eq!(rolling_total.total(), 15 + 20 - 4);
|
||||
|
||||
// Jump so that only the last value is still within the window:
|
||||
rolling_total.time_source().increment_by(Duration::from_secs(9));
|
||||
rolling_total.push(1);
|
||||
|
||||
assert_eq!(rolling_total.averages().len(), 2);
|
||||
assert_eq!(rolling_total.total(), 21);
|
||||
|
||||
// Jump so that everything is out of scope (just about!):
|
||||
rolling_total.time_source().increment_by(Duration::from_secs(10));
|
||||
rolling_total.push(1);
|
||||
|
||||
assert_eq!(rolling_total.averages().len(), 1);
|
||||
assert_eq!(rolling_total.total(), 1);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -26,7 +26,7 @@ const ABOUT: &str = "This is the Telemetry Backend Core that receives telemetry
|
||||
#[derive(StructOpt, Debug)]
|
||||
#[structopt(name = NAME, version = VERSION, author = AUTHORS, about = ABOUT)]
|
||||
struct Opts {
|
||||
/// This is the socket address that Telemetryis listening to. This is restricted to
|
||||
/// This is the socket address that Telemetry is listening to. This is restricted to
|
||||
/// localhost (127.0.0.1) by default and should be fine for most use cases. If
|
||||
/// you are using Telemetry in a container, you likely want to set this to '0.0.0.0:8000'
|
||||
#[structopt(short = "l", long = "listen", default_value = "127.0.0.1:8000")]
|
||||
|
||||
@@ -4,11 +4,13 @@ mod connection;
|
||||
mod json_message;
|
||||
mod real_ip;
|
||||
|
||||
use std::{collections::HashSet, net::IpAddr};
|
||||
use std::{collections::HashSet, net::IpAddr, time::Duration};
|
||||
|
||||
use aggregator::{Aggregator, FromWebsocket};
|
||||
use common::http_utils;
|
||||
use common::node_message;
|
||||
use common::byte_size::ByteSize;
|
||||
use common::rolling_total::RollingTotalBuilder;
|
||||
use futures::{channel::mpsc, SinkExt, StreamExt};
|
||||
use http::Uri;
|
||||
use hyper::{Method, Response};
|
||||
@@ -47,6 +49,12 @@ struct Opts {
|
||||
/// RAM by suggesting that it accounts for billions of nodes.
|
||||
#[structopt(long, default_value = "20")]
|
||||
max_nodes_per_connection: usize,
|
||||
/// What is the maximum number of bytes per second, on average, that a connection from a
|
||||
/// node is allowed to send to a shard before it gets booted. This is averaged over a
|
||||
/// rolling window of 10 seconds, and so spikes beyond this limit are allowed as long as
|
||||
/// the average traffic in the last 10 seconds falls below this value.
|
||||
#[structopt(long, default_value = "512k")]
|
||||
max_node_data_per_second: ByteSize
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@@ -70,6 +78,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> {
|
||||
let aggregator = Aggregator::spawn(opts.core_url).await?;
|
||||
let socket_addr = opts.socket;
|
||||
let max_nodes_per_connection = opts.max_nodes_per_connection;
|
||||
let bytes_per_second = opts.max_node_data_per_second;
|
||||
|
||||
let server = http_utils::start_server(socket_addr, move |addr, req| {
|
||||
let aggregator = aggregator.clone();
|
||||
@@ -91,6 +100,7 @@ async fn start_server(opts: Opts) -> anyhow::Result<()> {
|
||||
ws_recv,
|
||||
tx_to_aggregator,
|
||||
max_nodes_per_connection,
|
||||
bytes_per_second
|
||||
)
|
||||
.await;
|
||||
log::info!("Closing /submit connection from {:?}", addr);
|
||||
@@ -120,10 +130,19 @@ async fn handle_node_websocket_connection<S>(
|
||||
mut ws_recv: http_utils::WsReceiver,
|
||||
mut tx_to_aggregator: S,
|
||||
max_nodes_per_connection: usize,
|
||||
bytes_per_second: ByteSize
|
||||
) -> (S, http_utils::WsSender)
|
||||
where
|
||||
S: futures::Sink<FromWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
|
||||
{
|
||||
// Limit the number of bytes based on a rolling total and the incoming bytes per second
|
||||
// that has been configured via the CLI opts.
|
||||
let bytes_per_second = bytes_per_second.into_bytes();
|
||||
let mut rolling_total_bytes = RollingTotalBuilder::new()
|
||||
.granularity(Duration::from_secs(1))
|
||||
.window_size_multiple(10)
|
||||
.start();
|
||||
|
||||
// Track all of the message IDs that we've seen so far. If we exceed the
|
||||
// max_nodes_per_connection limit we ignore subsequent message IDs.
|
||||
let mut message_ids_seen = HashSet::new();
|
||||
@@ -165,6 +184,14 @@ where
|
||||
break;
|
||||
}
|
||||
|
||||
// Keep track of total bytes and bail if average over last 10 secs exceeds preference.
|
||||
rolling_total_bytes.push(bytes.len());
|
||||
let this_bytes_per_second = rolling_total_bytes.total() / 10;
|
||||
if this_bytes_per_second > bytes_per_second {
|
||||
log::error!("Shutting down websocket connection: Too much traffic ({}bps)", this_bytes_per_second);
|
||||
break;
|
||||
}
|
||||
|
||||
// Deserialize from JSON, warning in debug mode if deserialization fails:
|
||||
let node_message: json_message::NodeMessage = match serde_json::from_slice(&bytes) {
|
||||
Ok(node_message) => node_message,
|
||||
|
||||
Reference in New Issue
Block a user