From b79759a41297388e60f3833b9eddf97ab94c79cc Mon Sep 17 00:00:00 2001 From: James Wilson Date: Mon, 2 Aug 2021 09:50:31 +0100 Subject: [PATCH] Tidy up close handling and allow either side to close --- backend/common/src/ws_client/connect.rs | 35 ++++++++++++------------ backend/common/src/ws_client/receiver.rs | 12 +++++++- backend/common/src/ws_client/sender.rs | 20 ++++---------- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/backend/common/src/ws_client/connect.rs b/backend/common/src/ws_client/connect.rs index a458278..ebf42fe 100644 --- a/backend/common/src/ws_client/connect.rs +++ b/backend/common/src/ws_client/connect.rs @@ -23,7 +23,7 @@ use super::on_close::OnClose; use super::{ receiver::{Receiver, RecvMessage}, - sender::{Sender, SentMessage, SentMessageInternal}, + sender::{Sender, SentMessage}, }; /// The send side of a Soketto WebSocket connection @@ -72,7 +72,7 @@ impl Connection { let tx_closed2 = tx_closed1.clone(); let mut rx_closed2 = tx_closed1.subscribe(); - // Receive messages from the socket and post them out: + // Receive messages from the socket: let (mut tx_to_external, rx_from_ws) = mpsc::unbounded(); tokio::spawn(async move { let mut send_to_external = true; @@ -113,32 +113,37 @@ impl Connection { // Our external channel may have closed or errored, but the socket hasn't // been closed, so keep receiving in order to allow the socket to continue to // function properly (we may be happy just sending messages to it), but stop - // trying to send messages out, + // trying to hand back messages we've received from the socket. log::warn!("Failed to send data out: {}", e); send_to_external = false; } } }); - // Receive messages externally to send to the socket. + // Send messages to the socket: let (tx_to_ws, mut rx_from_external) = mpsc::unbounded(); tokio::spawn(async move { loop { // Wait for messages, or bail entirely if asked to close. let msg = tokio::select! { msg = rx_from_external.next() => { msg }, - _ = rx_closed2.recv() => { break } + _ = rx_closed2.recv() => { + // attempt to gracefully end the connection. + let _ = ws_to_connection.close().await; + break + } }; - // No more messages; channel closed. End this loop. + // No more messages; channel closed. End this loop. Unlike the recv side which + // needs to keep receiving data for the WS connection to stay open, there's no + // reason to keep this side of the loop open if our channel is closed. let msg = match msg { None => break, Some(msg) => msg, }; - // Any errors we hit here will be hit faster in the `receive_data` match msg { - SentMessageInternal::Message(SentMessage::Text(s)) => { + SentMessage::Text(s) => { if let Err(e) = ws_to_connection.send_text_owned(s).await { log::error!( "Shutting down websocket connection: Failed to send text data: {}", @@ -147,7 +152,7 @@ impl Connection { break; } } - SentMessageInternal::Message(SentMessage::Binary(bytes)) => { + SentMessage::Binary(bytes) => { if let Err(e) = ws_to_connection.send_binary_mut(bytes).await { log::error!( "Shutting down websocket connection: Failed to send binary data: {}", @@ -156,7 +161,7 @@ impl Connection { break; } } - SentMessageInternal::Message(SentMessage::StaticText(s)) => { + SentMessage::StaticText(s) => { if let Err(e) = ws_to_connection.send_text(s).await { log::error!( "Shutting down websocket connection: Failed to send text data: {}", @@ -165,7 +170,7 @@ impl Connection { break; } } - SentMessageInternal::Message(SentMessage::StaticBinary(bytes)) => { + SentMessage::StaticBinary(bytes) => { if let Err(e) = ws_to_connection.send_binary(bytes).await { log::error!( "Shutting down websocket connection: Failed to send binary data: {}", @@ -174,12 +179,6 @@ impl Connection { break; } } - SentMessageInternal::Close => { - if let Err(e) = ws_to_connection.close().await { - log::error!("Error attempting to close connection: {}", e); - break; - } - } } if let Err(e) = ws_to_connection.flush().await { @@ -202,7 +201,7 @@ impl Connection { }, Receiver { inner: rx_from_ws, - _closer: on_close, + closer: on_close, }) } } diff --git a/backend/common/src/ws_client/receiver.rs b/backend/common/src/ws_client/receiver.rs index 9d1a680..cdd4912 100644 --- a/backend/common/src/ws_client/receiver.rs +++ b/backend/common/src/ws_client/receiver.rs @@ -22,7 +22,7 @@ use super::on_close::OnClose; /// Receive messages out of a connection pub struct Receiver { pub(super) inner: mpsc::UnboundedReceiver>, - pub(super) _closer: Arc, + pub(super) closer: Arc, } #[derive(thiserror::Error, Debug)] @@ -31,6 +31,16 @@ pub enum RecvError { InvalidUtf8(#[from] std::string::FromUtf8Error), #[error("Stream finished")] StreamFinished, + #[error("Failed to send close message")] + CloseError +} + +impl Receiver { + /// Ask the underlying Websocket connection to close. + pub async fn close(&mut self) -> Result<(), RecvError> { + self.closer.0.send(()).map_err(|_| RecvError::CloseError)?; + Ok(()) + } } impl Stream for Receiver { diff --git a/backend/common/src/ws_client/sender.rs b/backend/common/src/ws_client/sender.rs index 0c06028..71872a4 100644 --- a/backend/common/src/ws_client/sender.rs +++ b/backend/common/src/ws_client/sender.rs @@ -38,27 +38,17 @@ pub enum SentMessage { Binary(Vec), } -/// Messages sent into the channel interface can be anything publically visible, or a close message. -#[derive(Debug, Clone)] -pub(super) enum SentMessageInternal { - Message(SentMessage), - Close, -} - /// Send messages into the connection #[derive(Clone)] pub struct Sender { - pub(super) inner: mpsc::UnboundedSender, + pub(super) inner: mpsc::UnboundedSender, pub(super) closer: Arc, } impl Sender { /// Ask the underlying Websocket connection to close. pub async fn close(&mut self) -> Result<(), SendError> { - self.inner.send(SentMessageInternal::Close).await?; - // fire the "proper" close handler (this shouldn't really be necessary - // since the above will cascade closing to both sides anyway). - let _ = self.closer.0.send(()); + self.closer.0.send(()).map_err(|_| SendError::CloseError)?; Ok(()) } /// Returns whether this channel is closed. @@ -69,7 +59,7 @@ impl Sender { /// need to be awaited. pub fn unbounded_send(&self, msg: SentMessage) -> Result<(), SendError> { self.inner - .unbounded_send(SentMessageInternal::Message(msg)) + .unbounded_send(msg) .map_err(|e| e.into_send_error())?; Ok(()) } @@ -79,6 +69,8 @@ impl Sender { pub enum SendError { #[error("Failed to send message: {0}")] ChannelError(#[from] mpsc::SendError), + #[error("Failed to send close message")] + CloseError } impl Sink for Sender { @@ -94,7 +86,7 @@ impl Sink for Sender { item: SentMessage, ) -> Result<(), Self::Error> { self.inner - .start_send_unpin(SentMessageInternal::Message(item)) + .start_send_unpin(item) .map_err(|e| e.into()) } fn poll_flush(