From 13afb96f9323ae5468cde0a4a6d60cb9430a9916 Mon Sep 17 00:00:00 2001 From: James Wilson Date: Fri, 30 Jul 2021 21:26:52 +0100 Subject: [PATCH] Be more fussy with Soketto channel wrapper; Soketto might do otherwise, but we only explicitly close the connection when both ends of the channel handed out are closed (or we encounter a connection error) --- backend/common/src/ws_client/connect.rs | 80 ++++++++++++++++-------- backend/common/src/ws_client/receiver.rs | 12 ++++ backend/common/src/ws_client/sender.rs | 12 ++++ 3 files changed, 79 insertions(+), 25 deletions(-) diff --git a/backend/common/src/ws_client/connect.rs b/backend/common/src/ws_client/connect.rs index 2eaf2fc..04df96b 100644 --- a/backend/common/src/ws_client/connect.rs +++ b/backend/common/src/ws_client/connect.rs @@ -18,6 +18,7 @@ use futures::{SinkExt, StreamExt}; use soketto::handshake::{Client, ServerResponse}; use tokio::net::TcpStream; use tokio_util::compat::TokioAsyncReadCompatExt; +use std::sync::Arc; use super::{ receiver::{Receiver, RecvMessage}, @@ -61,25 +62,46 @@ impl Connection { pub fn into_channels(self) -> (Sender, Receiver) { let (mut ws_to_connection, mut ws_from_connection) = (self.tx, self.rx); + // Shut everything down when we're told to close, which will be either when + // we hit an error trying to receive data on the socket, or when both the send + // and recv channels that we hand out are dropped. Notably, we allow either recv or + // send alone to be dropped and still keep the socket open (we may only care about + // one way communication). + let (tx_closed, mut rx_closed1) = tokio::sync::broadcast::channel::<()>(1); + let tx_closed1 = tx_closed.clone(); + let tx_closed2 = tx_closed.clone(); + let mut rx_closed2 = tx_closed.subscribe(); + // Receive messages from the socket and post them out: let (mut tx_to_external, rx_from_ws) = mpsc::unbounded(); - let (tx_has_closed, mut rx_has_closed) = futures::channel::oneshot::channel(); tokio::spawn(async move { - let mut data = Vec::with_capacity(128); + let mut send_to_external = true; loop { - // Clear the buffer and wait for the next message to arrive: - data.clear(); - let message_data = match ws_from_connection.receive_data(&mut data).await { + let mut data = Vec::new(); + + // Wait for messages, or bail entirely if asked to close. + let message_data = tokio::select! { + msg_data = ws_from_connection.receive_data(&mut data) => { msg_data }, + _ = rx_closed1.recv() => { break } + }; + + let message_data = match message_data { Err(e) => { - log::error!( - "Shutting down websocket connection: Failed to receive data: {}", - e - ); + // The socket had an error, so notify interested parties that we should + // shut the connection down and bail out of this receive loop. + log::error!("Shutting down websocket connection: Failed to receive data: {}", e); + let _ = tx_closed.send(()); break; } Ok(data) => data, }; + // if we hit an error sending, we keep receiving messages and reacting + // to recv issues, but we stop trying to send them anywhere. + if !send_to_external { + continue; + } + let msg = match message_data { soketto::Data::Binary(_) => Ok(RecvMessage::Binary(data)), soketto::Data::Text(_) => String::from_utf8(data) @@ -87,39 +109,34 @@ impl Connection { .map_err(|e| e.into()), }; - data = Vec::with_capacity(128); - if let Err(e) = tx_to_external.send(msg).await { - log::error!( - "Shutting down websocket connection: Failed to send data out: {}", - e - ); - break; + // Our external channel may have closed or errored, but the socket hasn't + // been closed, so keep receiving in order to allow the socket to continue to + // function properly (we may be happy just sending messages to it), but stop + // trying to send messages out, + log::warn!("Failed to send data out: {}", e); + send_to_external = false; } } - - // If the receive loop ends, make sure the other half closes and let the socket - // connection be dropped. While we might want to drop half of the channel and keep - // the connection open, we need to keep receiving from it in order to ackknowledge - // control messages, so if this loops ends, we had better just shut it all down. - let _ = tx_has_closed.send(()); }); // Receive messages externally to send to the socket. let (tx_to_ws, mut rx_from_external) = mpsc::unbounded(); tokio::spawn(async move { loop { + // Wait for messages, or bail entirely if asked to close. let msg = tokio::select! { msg = rx_from_external.next() => { msg }, - // Websocket connection closed? Don't wait for incoming message; break immediately. - _ = &mut rx_has_closed => { break }, + _ = rx_closed2.recv() => { break } }; + // No more messages; channel closed. End this loop. let msg = match msg { None => break, Some(msg) => msg, }; + // Any errors we hit here will be hit faster in the `receive_data` match msg { SentMessageInternal::Message(SentMessage::Text(s)) => { if let Err(e) = ws_to_connection.send_text_owned(s).await { @@ -175,7 +192,20 @@ impl Connection { } }); - (Sender { inner: tx_to_ws }, Receiver { inner: rx_from_ws }) + // Keep track of whether one of sender or received have + // been dropped. If both have, close the socket connection. + let counter = Arc::new(()); + + (Sender { + inner: tx_to_ws, + closer: tx_closed1, + count: Arc::clone(&counter), + }, + Receiver { + inner: rx_from_ws , + closer: tx_closed2, + count: counter, + }) } } diff --git a/backend/common/src/ws_client/receiver.rs b/backend/common/src/ws_client/receiver.rs index f4c80e0..93c7c3f 100644 --- a/backend/common/src/ws_client/receiver.rs +++ b/backend/common/src/ws_client/receiver.rs @@ -20,6 +20,18 @@ use futures::{Stream, StreamExt}; /// Receive messages out of a connection pub struct Receiver { pub(super) inner: mpsc::UnboundedReceiver>, + pub(super) closer: tokio::sync::broadcast::Sender<()>, + pub(super) count: std::sync::Arc<()>, +} + +impl Drop for Receiver { + fn drop(&mut self) { + // Close the socket connection if this is the last half + // of the channel (ie the sender has been dropped already). + if std::sync::Arc::strong_count(&self.count) == 1 { + let _ = self.closer.send(()); + } + } } #[derive(thiserror::Error, Debug)] diff --git a/backend/common/src/ws_client/sender.rs b/backend/common/src/ws_client/sender.rs index b0f47a4..711536b 100644 --- a/backend/common/src/ws_client/sender.rs +++ b/backend/common/src/ws_client/sender.rs @@ -47,6 +47,18 @@ pub(super) enum SentMessageInternal { #[derive(Clone)] pub struct Sender { pub(super) inner: mpsc::UnboundedSender, + pub(super) closer: tokio::sync::broadcast::Sender<()>, + pub(super) count: std::sync::Arc<()>, +} + +impl Drop for Sender { + fn drop(&mut self) { + // Close the socket connection if this is the last half + // of the channel (ie the receiver has been dropped already). + if std::sync::Arc::strong_count(&self.count) == 1 { + let _ = self.closer.send(()); + } + } } impl Sender {