From 0eff32d10ef2bc067d16bb689f597bf84ef8c835 Mon Sep 17 00:00:00 2001 From: James Wilson Date: Fri, 23 Jul 2021 15:57:50 +0100 Subject: [PATCH] Migrate telemetry_core to Hyper+Soketto --- backend/Cargo.lock | 13 +- backend/common/Cargo.toml | 4 + backend/common/src/http_utils.rs | 139 ++++++ backend/common/src/lib.rs | 1 + backend/telemetry_core/Cargo.toml | 1 + .../src/aggregator/aggregator.rs | 4 +- backend/telemetry_core/src/main.rs | 451 +++++++++--------- 7 files changed, 389 insertions(+), 224 deletions(-) create mode 100644 backend/common/src/http_utils.rs diff --git a/backend/Cargo.lock b/backend/Cargo.lock index d8e0ca9..bcc187d 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -20,9 +20,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.41" +version = "1.0.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15af2628f6890fe2609a3b91bef4c83450512802e59489f9c1cb1fa5df064a61" +checksum = "595d3cfa7a60d4555cb5067b99f07142a08ea778de5cf993f7b75c7d8fabc486" [[package]] name = "arrayvec" @@ -205,6 +205,8 @@ dependencies = [ name = "common" version = "0.1.0" dependencies = [ + "anyhow", + "base64", "bimap", "bincode", "bytes", @@ -212,6 +214,7 @@ dependencies = [ "futures", "hex", "http", + "hyper", "log", "num-traits", "pin-project-lite", @@ -219,6 +222,7 @@ dependencies = [ "rustc-hash", "serde", "serde_json", + "sha-1", "soketto", "thiserror", "tokio", @@ -675,9 +679,9 @@ checksum = "6456b8a6c8f33fee7d958fcd1b60d55b11940a79e63ae87013e6d22e26034440" [[package]] name = "hyper" -version = "0.14.9" +version = "0.14.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07d6baa1b441335f3ce5098ac421fb6547c46dda735ca1bc6d0153c838f9dd83" +checksum = "0b61cf2d1aebcf6e6352c97b81dc2244ca29194be1b276f5d8ad5c6330fffb11" dependencies = [ "bytes", "futures-channel", @@ -1689,6 +1693,7 @@ dependencies = [ "futures", "hex", "http", + "hyper", "log", "once_cell", "parking_lot", diff --git a/backend/common/Cargo.toml b/backend/common/Cargo.toml index 1b38b09..a6fcd58 100644 --- a/backend/common/Cargo.toml +++ b/backend/common/Cargo.toml @@ -6,12 +6,15 @@ edition = "2018" license = "GPL-3.0" [dependencies] +anyhow = "1.0.42" +base64 = { default-features = false, features = ["alloc"], version = "0.13" } bimap = "0.6.1" bytes = "1.0.1" fnv = "1.0.7" futures = "0.3.15" hex = "0.4.3" http = "0.2.4" +hyper = { version = "0.14.11", features = ["full"] } log = "0.4" num-traits = "0.2" pin-project-lite = "0.2.7" @@ -19,6 +22,7 @@ primitive-types = { version = "0.9.0", features = ["serde"] } rustc-hash = "1.1.0" serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0", features = ["raw_value"] } +sha-1 = { default-features = false, version = "0.9" } soketto = "0.6.0" thiserror = "1.0.24" tokio = { version = "1.8.2", features = ["full"] } diff --git a/backend/common/src/http_utils.rs b/backend/common/src/http_utils.rs new file mode 100644 index 0000000..bf2d145 --- /dev/null +++ b/backend/common/src/http_utils.rs @@ -0,0 +1,139 @@ +use std::net::SocketAddr; +use hyper::{ Server, Request, Response, Body }; +use std::future::Future; +use tokio_util::compat::{Compat,TokioAsyncReadCompatExt}; +use futures::io::{BufReader, BufWriter}; +use hyper::server::conn::AddrStream; + +/// A convenience function to start up a Hyper server and handle requests. +pub async fn start_server(addr: SocketAddr, handler: H) -> Result<(), anyhow::Error> +where + H: Clone + Send + Sync + 'static + FnMut(SocketAddr, Request) -> F, + F: Send + 'static + Future, anyhow::Error>> +{ + let service = + hyper::service::make_service_fn(move |addr: &AddrStream| { + let mut handler = handler.clone(); + let addr = addr.remote_addr(); + async move { Ok::<_, hyper::Error>(hyper::service::service_fn(move |r| handler(addr, r))) } + }); + let server = Server::bind(&addr).serve(service); + + log::info!("listening on http://{}", server.local_addr()); + server.await?; + + Ok(()) + +} + +type WsStream = BufReader>>; +pub type WsSender = soketto::connection::Sender; +pub type WsReceiver = soketto::connection::Receiver; + +/// A convenience function to upgrade a Hyper request into a Soketto Websocket. +pub fn upgrade_to_websocket(req: Request, on_upgrade: H) -> hyper::Response +where + H: 'static + Send + FnOnce(WsSender, WsReceiver) -> F, + F: Send + Future +{ + if !is_upgrade_request(&req) { + return basic_response(400, "Expecting WebSocket upgrade headers"); + } + + let key = match req.headers().get("Sec-WebSocket-Key") { + Some(key) => key, + None => return basic_response(400, "Upgrade to websocket connection failed; Sec-WebSocket-Key header not provided") + }; + + if req.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") { + return basic_response(400, "Sec-WebSocket-Version header should have a value of 13"); + } + + // Just a little ceremony we need to go to to return the correct response key: + let mut accept_key_buf = [0; 32]; + let accept_key = generate_websocket_accept_key(key.as_bytes(), &mut accept_key_buf); + + // Tell the client that we accept the upgrade-to-WS request: + let response = Response::builder() + .status(hyper::StatusCode::SWITCHING_PROTOCOLS) + .header(hyper::header::CONNECTION, "upgrade") + .header(hyper::header::UPGRADE, "websocket") + .header("Sec-WebSocket-Accept", accept_key) + .body(Body::empty()) + .expect("bug: failed to build response"); + + // Spawn our handler to work with the WS connection: + tokio::spawn(async move { + // Get our underlying TCP stream: + let stream = match hyper::upgrade::on(req).await { + Ok(stream) => stream, + Err(e) => { + log::error!("Error upgrading connection to websocket: {}", e); + return; + } + }; + + // Start a Soketto server with it: + let server = soketto::handshake::Server::new( + BufReader::new(BufWriter::new(stream.compat())) + ); + + // Get hold of a way to send and receive messages: + let (sender, receiver) + = server.into_builder().finish(); + + // Pass these to our when-upgraded handler: + on_upgrade(sender, receiver).await; + }); + + response +} + +/// A helper to return a basic HTTP response with a code and text body. +fn basic_response(code: u16, msg: impl AsRef) -> Response { + Response::builder() + .status(code) + .body(Body::from(msg.as_ref().to_owned())) + .expect("bug: failed to build response body") +} + +/// Defined in RFC 6455. this is how we convert the Sec-WebSocket-Key in a request into a +/// Sec-WebSocket-Accept that we return in the response. +fn generate_websocket_accept_key<'a>(key: &[u8], buf: &'a mut [u8; 32]) -> &'a [u8] { + // Defined in RFC 6455, we append this to the key to generate the response: + const KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + use sha1::{Digest, Sha1}; + let mut digest = Sha1::new(); + digest.update(key); + digest.update(KEY); + let d = digest.finalize(); + + let n = base64::encode_config_slice(&d, base64::STANDARD, buf); + &buf[..n] +} + +/// Check if a request is a websocket upgrade request. +fn is_upgrade_request(request: &hyper::Request) -> bool { + header_contains_value(request.headers(), hyper::header::CONNECTION, b"upgrade") + && header_contains_value(request.headers(), hyper::header::UPGRADE, b"websocket") +} + +/// Check if there is a header of the given name containing the wanted value. +fn header_contains_value(headers: &hyper::HeaderMap, header: hyper::header::HeaderName, value: &[u8]) -> bool { + pub fn trim(x: &[u8]) -> &[u8] { + let from = match x.iter().position(|x| !x.is_ascii_whitespace()) { + Some(i) => i, + None => return &x[0..0], + }; + let to = x.iter().rposition(|x| !x.is_ascii_whitespace()).unwrap(); + &x[from..=to] + } + + for header in headers.get_all(header) { + if header.as_bytes().split(|&c| c == b',').any(|x| trim(x).eq_ignore_ascii_case(value)) { + return true; + } + } + false +} \ No newline at end of file diff --git a/backend/common/src/lib.rs b/backend/common/src/lib.rs index 4eaab74..19ff1cb 100644 --- a/backend/common/src/lib.rs +++ b/backend/common/src/lib.rs @@ -5,6 +5,7 @@ pub mod node_types; pub mod time; pub mod ws_client; pub mod ready_chunks_all; +pub mod http_utils; mod assign_id; mod dense_map; diff --git a/backend/telemetry_core/Cargo.toml b/backend/telemetry_core/Cargo.toml index c751052..5aca032 100644 --- a/backend/telemetry_core/Cargo.toml +++ b/backend/telemetry_core/Cargo.toml @@ -14,6 +14,7 @@ common = { path = "../common" } futures = "0.3.15" hex = "0.4.3" http = "0.2.4" +hyper = "0.14.11" log = "0.4.14" once_cell = "1.8.0" parking_lot = "0.11.1" diff --git a/backend/telemetry_core/src/aggregator/aggregator.rs b/backend/telemetry_core/src/aggregator/aggregator.rs index 7e9a4e1..2a049c1 100644 --- a/backend/telemetry_core/src/aggregator/aggregator.rs +++ b/backend/telemetry_core/src/aggregator/aggregator.rs @@ -74,7 +74,7 @@ impl Aggregator { /// Return a sink that a shard can send messages into to be handled by the aggregator. pub fn subscribe_shard( &self, - ) -> impl Sink + Unpin { + ) -> impl Sink + Send + Sync + Unpin + 'static { // Assign a unique aggregator-local ID to each connection that subscribes, and pass // that along with every message to the aggregator loop: let shard_conn_id = self @@ -96,7 +96,7 @@ impl Aggregator { /// Return a sink that a feed can send messages into to be handled by the aggregator. pub fn subscribe_feed( &self, - ) -> impl Sink + Unpin { + ) -> impl Sink + Send + Sync + Unpin + 'static { // Assign a unique aggregator-local ID to each connection that subscribes, and pass // that along with every message to the aggregator loop: let feed_conn_id = self diff --git a/backend/telemetry_core/src/main.rs b/backend/telemetry_core/src/main.rs index 6f1d00e..fb0c77d 100644 --- a/backend/telemetry_core/src/main.rs +++ b/backend/telemetry_core/src/main.rs @@ -2,8 +2,6 @@ mod aggregator; mod feed_message; mod find_location; mod state; - -use std::net::SocketAddr; use std::str::FromStr; use aggregator::{ @@ -15,8 +13,8 @@ use common::ready_chunks_all::ReadyChunksAll; use futures::{channel::mpsc, SinkExt, StreamExt}; use simple_logger::SimpleLogger; use structopt::StructOpt; -use warp::filters::ws; -use warp::Filter; +use hyper::{ Response, Method }; +use common::http_utils; const VERSION: &str = env!("CARGO_PKG_VERSION"); const AUTHORS: &str = env!("CARGO_PKG_AUTHORS"); @@ -60,65 +58,63 @@ async fn main() { /// Declare our routes and start the server. async fn start_server(opts: Opts) -> anyhow::Result<()> { - let shard_aggregator = Aggregator::spawn(opts.denylist).await?; - let feed_aggregator = shard_aggregator.clone(); + let aggregator = Aggregator::spawn(opts.denylist).await?; + let server = http_utils::start_server(opts.socket, move |addr, req| { + let aggregator = aggregator.clone(); + println!("REQUEST: {:?}", (req.method(), req.uri().path())); + async move { + match (req.method(), req.uri().path().trim_end_matches('/')) { + // Check that the server is up and running: + (&Method::GET, "/health") => { + Ok(Response::new("OK".into())) + }, + // Subscribe to feed messages: + (&Method::GET, "/feed") => { + Ok(http_utils::upgrade_to_websocket(req, move |ws_send, ws_recv| async move { + let tx_to_aggregator = aggregator.subscribe_feed(); + let (mut tx_to_aggregator, mut ws_send) + = handle_feed_websocket_connection(ws_send, ws_recv, tx_to_aggregator).await; + log::info!("Closing /feed connection from {:?}", addr); + // Tell the aggregator that this connection has closed, so it can tidy up. + let _ = tx_to_aggregator.send(FromFeedWebsocket::Disconnected).await; + let _ = ws_send.close().await; + })) + }, + // Subscribe to shard messages: + (&Method::GET, "/shard_submit") => { + Ok(http_utils::upgrade_to_websocket(req, move |ws_send, ws_recv| async move { + let tx_to_aggregator = aggregator.subscribe_shard(); + let (mut tx_to_aggregator, mut ws_send) + = handle_shard_websocket_connection(ws_send, ws_recv, tx_to_aggregator).await; + log::info!("Closing /shard_submit connection from {:?}", addr); + // Tell the aggregator that this connection has closed, so it can tidy up. + let _ = tx_to_aggregator.send(FromShardWebsocket::Disconnected).await; + let _ = ws_send.close().await; + })) + }, + // 404 for anything else: + _ => { + Ok(Response::builder() + .status(404) + .body("Not found".into()) + .unwrap()) + } + } + } + }); - // Handle requests to /health by returning OK. - let health_route = warp::path("health").map(|| "OK"); - - // Handle websocket requests from shards. - let ws_shard_submit_route = warp::path("shard_submit") - .and(warp::ws()) - .and(warp::filters::addr::remote()) - .map(move |ws: ws::Ws, addr: Option| { - let tx_to_aggregator = shard_aggregator.subscribe_shard(); - log::info!("Opening /shard_submit connection from {:?}", addr); - ws.on_upgrade(move |websocket| async move { - let (mut tx_to_aggregator, websocket) = - handle_shard_websocket_connection(websocket, tx_to_aggregator).await; - log::info!("Closing /shard_submit connection from {:?}", addr); - // Tell the aggregator that this connection has closed, so it can tidy up. - let _ = tx_to_aggregator - .send(FromShardWebsocket::Disconnected) - .await; - let _ = websocket.close().await; - }) - }); - - // Handle websocket requests from frontends. - let ws_feed_route = warp::path("feed") - .and(warp::ws()) - .and(warp::filters::addr::remote()) - .map(move |ws: ws::Ws, addr: Option| { - let tx_to_aggregator = feed_aggregator.subscribe_feed(); - log::info!("Opening /feed connection from {:?}", addr); - - // We can decide how many messages can be buffered to be sent, but not specifically how - // large those messages are cumulatively allowed to be: - ws.max_send_queue(1_000) - .on_upgrade(move |websocket| async move { - let (mut tx_to_aggregator, websocket) = - handle_feed_websocket_connection(websocket, tx_to_aggregator).await; - log::info!("Closing /feed connection from {:?}", addr); - // Tell the aggregator that this connection has closed, so it can tidy up. - let _ = tx_to_aggregator.send(FromFeedWebsocket::Disconnected).await; - let _ = websocket.close().await; - }) - }); - - // Merge the routes and start our server: - let routes = ws_shard_submit_route.or(ws_feed_route).or(health_route); - warp::serve(routes).run(opts.socket).await; + server.await?; Ok(()) } /// This handles messages coming to/from a shard connection async fn handle_shard_websocket_connection( - mut websocket: ws::WebSocket, + mut ws_send: http_utils::WsSender, + mut ws_recv: http_utils::WsReceiver, mut tx_to_aggregator: S, -) -> (S, ws::WebSocket) +) -> (S, http_utils::WsSender) where - S: futures::Sink + Unpin, + S: futures::Sink + Unpin + Send + 'static, { let (tx_to_shard_conn, mut rx_from_aggregator) = mpsc::unbounded(); @@ -128,102 +124,119 @@ where }; if let Err(e) = tx_to_aggregator.send(init_msg).await { log::error!("Error sending message to aggregator: {}", e); - return (tx_to_aggregator, websocket); + return (tx_to_aggregator, ws_send); } - // Loop, handling new messages from the shard or from the aggregator: - loop { - tokio::select! { - // AGGREGATOR -> SHARD - msg = rx_from_aggregator.next() => { - // End the loop when connection from aggregator ends: - let msg = match msg { - Some(msg) => msg, - None => break - }; + // Channels to notify each loop if the other closes: + let (recv_closer_tx, mut recv_closer_rx) = tokio::sync::oneshot::channel::<()>(); + let (send_closer_tx, mut send_closer_rx) = tokio::sync::oneshot::channel::<()>(); - let internal_msg = match msg { - ToShardWebsocket::Mute { local_id, reason } => { - internal_messages::FromTelemetryCore::Mute { local_id, reason } - } - }; + // Receive messages from a shard: + let recv_handle = tokio::spawn(async move { + loop { + let mut bytes = Vec::new(); - let bytes = bincode::options() - .serialize(&internal_msg) - .expect("message to shard should serialize"); + // Receive a message, or bail if closer called. We don't care about cancel safety; + // if we're halfway through receiving a message, no biggie since we're closing the + // connection anyway. + let msg_info = tokio::select! { + msg_info = ws_recv.receive_data(&mut bytes) => msg_info, + _ = &mut recv_closer_rx => { break } + }; - if let Err(e) = websocket.send(ws::Message::binary(bytes)).await { - log::error!("Error sending message to shard; booting it: {}", e); - break - } + // Handle the socket closing, or errors receiving the message. + if let Err(soketto::connection::Error::Closed) = msg_info { + break; + } + if let Err(e) = msg_info { + log::error!("Shutting down websocket connection: Failed to receive data: {}", e); + break; } - // SHARD -> AGGREGATOR - msg = websocket.next() => { - // End the loop when connection from shard ends: - let msg = match msg { - Some(msg) => msg, - None => break - }; - let msg = match msg { - Err(e) => { - log::error!("Error receiving message from shard; booting it: {}", e); - break; - }, - Ok(msg) => msg - }; - - // Close message? Break and allow connection to be dropped. - if msg.is_close() { + let msg: internal_messages::FromShardAggregator = match bincode::options().deserialize(&bytes) { + Ok(msg) => msg, + Err(e) => { + log::error!("Failed to deserialize message from shard; booting it: {}", e); break; } + }; - // If the message isn't something we want to handle, just ignore it. - // This includes system messages like "pings" and such, so don't log anything. - if !msg.is_binary() && !msg.is_text() { - continue; - } + // Convert and send to the aggregator: + let aggregator_msg = match msg { + internal_messages::FromShardAggregator::AddNode { ip, node, local_id, genesis_hash } => { + FromShardWebsocket::Add { ip, node, genesis_hash, local_id } + }, + internal_messages::FromShardAggregator::UpdateNode { payload, local_id } => { + FromShardWebsocket::Update { local_id, payload } + }, + internal_messages::FromShardAggregator::RemoveNode { local_id } => { + FromShardWebsocket::Remove { local_id } + }, + }; - let bytes = msg.as_bytes(); - let msg: internal_messages::FromShardAggregator = match bincode::options().deserialize(bytes) { - Ok(msg) => msg, - Err(e) => { - log::error!("Failed to deserialize message from shard; booting it: {}", e); - break; - } - }; - - // Convert and send to the aggregator: - let aggregator_msg = match msg { - internal_messages::FromShardAggregator::AddNode { ip, node, local_id, genesis_hash } => { - FromShardWebsocket::Add { ip, node, genesis_hash, local_id } - }, - internal_messages::FromShardAggregator::UpdateNode { payload, local_id } => { - FromShardWebsocket::Update { local_id, payload } - }, - internal_messages::FromShardAggregator::RemoveNode { local_id } => { - FromShardWebsocket::Remove { local_id } - }, - }; - if let Err(e) = tx_to_aggregator.send(aggregator_msg).await { - log::error!("Failed to send message to aggregator; closing shard: {}", e); - break; - } + if let Err(e) = tx_to_aggregator.send(aggregator_msg).await { + log::error!("Failed to send message to aggregator; closing shard: {}", e); + break; } } - } + + drop(send_closer_tx); // Kill the send task if this recv task ends + tx_to_aggregator + }); + + // Send messages to the shard: + let send_handle = tokio::spawn(async move { + loop { + let msg = tokio::select! { + msg = rx_from_aggregator.next() => msg, + _ = &mut send_closer_rx => { break } + }; + + let msg = match msg { + Some(msg) => msg, + None => break + }; + + let internal_msg = match msg { + ToShardWebsocket::Mute { local_id, reason } => { + internal_messages::FromTelemetryCore::Mute { local_id, reason } + } + }; + + let bytes = bincode::options() + .serialize(&internal_msg) + .expect("message to shard should serialize"); + + if let Err(e) = ws_send.send_binary(bytes).await { + log::error!("Failed to send message to aggregator; closing shard: {}", e) + } + if let Err(e) = ws_send.flush().await { + log::error!("Failed to flush message to aggregator; closing shard: {}", e) + } + + } + + drop(recv_closer_tx); // Kill the recv task if this send task ends + ws_send + }); + + // If our send/recv tasks are stopped (if one of them dies, they both will), + // collect the bits we need to hand back from them: + let ws_send = send_handle.await.unwrap(); + let tx_to_aggregator = recv_handle.await.unwrap(); // loop ended; give socket back to parent: - (tx_to_aggregator, websocket) + (tx_to_aggregator, ws_send) } /// This handles messages coming from a feed connection async fn handle_feed_websocket_connection( - mut websocket: ws::WebSocket, + mut ws_send: http_utils::WsSender, + mut ws_recv: http_utils::WsReceiver, mut tx_to_aggregator: S, -) -> (S, ws::WebSocket) +) -> (S, http_utils::WsSender) where - S: futures::Sink + Unpin, + S: futures::Sink + Unpin + Send + 'static, { // unbounded channel so that slow feeds don't block aggregator progress: let (tx_to_feed_conn, rx_from_aggregator) = mpsc::unbounded(); @@ -235,106 +248,108 @@ where }; if let Err(e) = tx_to_aggregator.send(init_msg).await { log::error!("Error sending message to aggregator: {}", e); - return (tx_to_aggregator, websocket); + return (tx_to_aggregator, ws_send); } - // Loop, handling new messages from the shard or from the aggregator: - loop { - // Without any special handling, if messages come in every ~2.5ms to each feed, the select! loop - // has to wake up 400 times a second to poll things. If we have 1000 feeds, that's 400,000 wakeups - // per second. Even without any work in the loop, that uses a bunch of CPU. As an example, try - // replacing the loop with this: - // - // ``` - // let s = tokio::time::sleep(tokio::time::Duration::from_micros(2500)); - // tokio::select! { - // _ = s => {}, - // _ = websocket.next() => {} - // } - // continue; - // ``` - // - // To combat this, we add a small wait to reduce how often the select loop will be woken up under high load. We - // buffer messages to feeds so that we do as much work as possible during each wakeup, and if the - // wakeup lasts longer than 75ms we don't wait before polling again. This knocks ~80% of a CPU worth of usage - // off on my machine running a soak test with 500 feeds, 4 shards and 100 nodes, doesn't seem to impact - // memory usage much, and still ensures that messages are delivered in a timely fashion. - // - // Increasing the wait to 100ms or more doesn't seem to have much more of a positive impact anyway. - let debounce = tokio::time::sleep_until(tokio::time::Instant::now() + std::time::Duration::from_millis(75)); + // Channels to notify each loop if the other closes: + let (recv_closer_tx, mut recv_closer_rx) = tokio::sync::oneshot::channel::<()>(); + let (send_closer_tx, mut send_closer_rx) = tokio::sync::oneshot::channel::<()>(); - tokio::select! {biased; + // Receive messages from the feed: + let recv_handle = tokio::spawn(async move { + loop { + let mut bytes = Vec::new(); - // FRONTEND -> AGGREGATOR (relay messages to the aggregator). Biased, so messages - // from the UI will have priority (especially important with our debounce delay). - msg = websocket.next() => { - // End the loop when connection from feed ends: - let msg = match msg { - Some(msg) => msg, - None => break - }; + // Receive a message, or bail if closer called. We don't care about cancel safety; + // if we're halfway through receiving a message, no biggie since we're closing the + // connection anyway. + let msg_info = tokio::select! { + msg_info = ws_recv.receive_data(&mut bytes) => msg_info, + _ = &mut recv_closer_rx => { break } + }; - // If we see any errors, log them and end our loop: - let msg = match msg { - Err(e) => { - log::error!("Error in node websocket connection: {}", e); - break; - }, - Ok(msg) => msg - }; - - // Close message? Break and allow connection to be dropped. - if msg.is_close() { - break; - } - - // We ignore all but text messages from the frontend: - let text = match msg.to_str() { - Ok(s) => s, - Err(_) => continue - }; - - // Parse the message into a command we understand and send it to the aggregator: - let cmd = match FromFeedWebsocket::from_str(text) { - Ok(cmd) => cmd, - Err(e) => { - log::warn!("Ignoring invalid command '{}' from the frontend: {}", text, e); - continue - } - }; - if let Err(e) = tx_to_aggregator.send(cmd).await { - log::error!("Failed to send message to aggregator; closing feed: {}", e); - break; - } + // Handle the socket closing, or errors receiving the message. + if let Err(soketto::connection::Error::Closed) = msg_info { + break; + } + if let Err(e) = msg_info { + log::error!("Shutting down websocket connection: Failed to receive data: {}", e); + break; } - // AGGREGATOR -> FRONTEND (buffer messages to the UI) - msgs = rx_from_aggregator_chunks.next() => { - // End the loop when connection from aggregator ends: - let msgs = match msgs { - Some(msgs) => msgs, - None => break - }; + // We ignore all but valid UTF8 text messages from the frontend: + let text = match String::from_utf8(bytes) { + Ok(s) => s, + Err(_) => continue + }; - // There is only one message type at the mo; bytes to send - // to the websocket. collect them all up to dispatch in one shot. - let all_ws_msgs = msgs.into_iter().map(|msg| { - let bytes = match msg { - ToFeedWebsocket::Bytes(bytes) => bytes - }; - Ok(ws::Message::binary(&*bytes)) - }); - - if let Err(e) = websocket.send_all(&mut futures::stream::iter(all_ws_msgs)).await { - log::warn!("Closing feed websocket due to error: {}", e); - break; + // Parse the message into a command we understand and send it to the aggregator: + let cmd = match FromFeedWebsocket::from_str(&text) { + Ok(cmd) => cmd, + Err(e) => { + log::warn!("Ignoring invalid command '{}' from the frontend: {}", text, e); + continue } + }; + if let Err(e) = tx_to_aggregator.send(cmd).await { + log::error!("Failed to send message to aggregator; closing feed: {}", e); + break; } + } - debounce.await; - } + drop(send_closer_tx); // Kill the send task if this recv task ends + tx_to_aggregator + }); + + // Send messages to the feed: + let send_handle = tokio::spawn(async move { + loop { + let debounce = tokio::time::sleep_until(tokio::time::Instant::now() + std::time::Duration::from_millis(75)); + + let msgs = tokio::select! { + msgs = rx_from_aggregator_chunks.next() => msgs, + _ = &mut send_closer_rx => { break } + }; + + // End the loop when connection from aggregator ends: + let msgs = match msgs { + Some(msgs) => msgs, + None => break + }; + + // There is only one message type at the mo; bytes to send + // to the websocket. collect them all up to dispatch in one shot. + let all_msg_bytes = msgs.into_iter().map(|msg| { + match msg { + ToFeedWebsocket::Bytes(bytes) => bytes + } + }); + + for bytes in all_msg_bytes { + if let Err(e) = ws_send.send_binary(&bytes).await { + log::warn!("Closing feed websocket due to error sending data: {}", e); + break; + } + } + + if let Err(e) = ws_send.flush().await { + log::warn!("Closing feed websocket due to error flushing data: {}", e); + break; + } + + debounce.await; + } + + drop(recv_closer_tx); // Kill the recv task if this send task ends + ws_send + }); + + // If our send/recv tasks are stopped (if one of them dies, they both will), + // collect the bits we need to hand back from them: + let ws_send = send_handle.await.unwrap(); + let tx_to_aggregator = recv_handle.await.unwrap(); // loop ended; give socket back to parent: - (tx_to_aggregator, websocket) + (tx_to_aggregator, ws_send) }