From 963ef576f3aa803ab51cdab54f6b5de8a64e2f13 Mon Sep 17 00:00:00 2001 From: James Wilson Date: Mon, 2 Aug 2021 09:30:29 +0100 Subject: [PATCH] add on_close to abstract and simplify ws_client socket closing --- backend/common/src/ws_client/connect.rs | 22 +++++++++----------- backend/common/src/ws_client/mod.rs | 2 ++ backend/common/src/ws_client/on_close.rs | 26 ++++++++++++++++++++++++ backend/common/src/ws_client/receiver.rs | 15 +++----------- backend/common/src/ws_client/sender.rs | 18 ++++++---------- 5 files changed, 47 insertions(+), 36 deletions(-) create mode 100644 backend/common/src/ws_client/on_close.rs diff --git a/backend/common/src/ws_client/connect.rs b/backend/common/src/ws_client/connect.rs index 04df96b..a458278 100644 --- a/backend/common/src/ws_client/connect.rs +++ b/backend/common/src/ws_client/connect.rs @@ -19,6 +19,7 @@ use soketto::handshake::{Client, ServerResponse}; use tokio::net::TcpStream; use tokio_util::compat::TokioAsyncReadCompatExt; use std::sync::Arc; +use super::on_close::OnClose; use super::{ receiver::{Receiver, RecvMessage}, @@ -67,10 +68,9 @@ impl Connection { // and recv channels that we hand out are dropped. Notably, we allow either recv or // send alone to be dropped and still keep the socket open (we may only care about // one way communication). - let (tx_closed, mut rx_closed1) = tokio::sync::broadcast::channel::<()>(1); - let tx_closed1 = tx_closed.clone(); - let tx_closed2 = tx_closed.clone(); - let mut rx_closed2 = tx_closed.subscribe(); + let (tx_closed1, mut rx_closed1) = tokio::sync::broadcast::channel::<()>(1); + let tx_closed2 = tx_closed1.clone(); + let mut rx_closed2 = tx_closed1.subscribe(); // Receive messages from the socket and post them out: let (mut tx_to_external, rx_from_ws) = mpsc::unbounded(); @@ -90,7 +90,7 @@ impl Connection { // The socket had an error, so notify interested parties that we should // shut the connection down and bail out of this receive loop. log::error!("Shutting down websocket connection: Failed to receive data: {}", e); - let _ = tx_closed.send(()); + let _ = tx_closed1.send(()); break; } Ok(data) => data, @@ -193,18 +193,16 @@ impl Connection { }); // Keep track of whether one of sender or received have - // been dropped. If both have, close the socket connection. - let counter = Arc::new(()); + // been dropped. If both have, we close the socket connection. + let on_close = Arc::new(OnClose(tx_closed2)); (Sender { inner: tx_to_ws, - closer: tx_closed1, - count: Arc::clone(&counter), + closer: Arc::clone(&on_close), }, Receiver { - inner: rx_from_ws , - closer: tx_closed2, - count: counter, + inner: rx_from_ws, + _closer: on_close, }) } } diff --git a/backend/common/src/ws_client/mod.rs b/backend/common/src/ws_client/mod.rs index a6654ae..ea12913 100644 --- a/backend/common/src/ws_client/mod.rs +++ b/backend/common/src/ws_client/mod.rs @@ -20,6 +20,8 @@ mod connect; mod receiver; /// The channel based send interface mod sender; +/// A close helper that we use in sender/receiver. +mod on_close; pub use connect::{connect, ConnectError, Connection, RawReceiver, RawSender}; pub use receiver::{Receiver, RecvError, RecvMessage}; diff --git a/backend/common/src/ws_client/on_close.rs b/backend/common/src/ws_client/on_close.rs new file mode 100644 index 0000000..0c03ce0 --- /dev/null +++ b/backend/common/src/ws_client/on_close.rs @@ -0,0 +1,26 @@ +// Source code for the Substrate Telemetry Server. +// Copyright (C) 2021 Parity Technologies (UK) Ltd. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +use tokio::sync::broadcast; + +/// A small helper to fire the "close" channel when it's dropped. +pub struct OnClose(pub broadcast::Sender<()>); + +impl Drop for OnClose { + fn drop(&mut self) { + let _ = self.0.send(()); + } +} \ No newline at end of file diff --git a/backend/common/src/ws_client/receiver.rs b/backend/common/src/ws_client/receiver.rs index 93c7c3f..9d1a680 100644 --- a/backend/common/src/ws_client/receiver.rs +++ b/backend/common/src/ws_client/receiver.rs @@ -16,22 +16,13 @@ use futures::channel::mpsc; use futures::{Stream, StreamExt}; +use std::sync::Arc; +use super::on_close::OnClose; /// Receive messages out of a connection pub struct Receiver { pub(super) inner: mpsc::UnboundedReceiver>, - pub(super) closer: tokio::sync::broadcast::Sender<()>, - pub(super) count: std::sync::Arc<()>, -} - -impl Drop for Receiver { - fn drop(&mut self) { - // Close the socket connection if this is the last half - // of the channel (ie the sender has been dropped already). - if std::sync::Arc::strong_count(&self.count) == 1 { - let _ = self.closer.send(()); - } - } + pub(super) _closer: Arc, } #[derive(thiserror::Error, Debug)] diff --git a/backend/common/src/ws_client/sender.rs b/backend/common/src/ws_client/sender.rs index 711536b..0c06028 100644 --- a/backend/common/src/ws_client/sender.rs +++ b/backend/common/src/ws_client/sender.rs @@ -16,6 +16,8 @@ use futures::channel::mpsc; use futures::{Sink, SinkExt}; +use std::sync::Arc; +use super::on_close::OnClose; /// A message that can be sent into the channel interface #[derive(Debug, Clone)] @@ -47,24 +49,16 @@ pub(super) enum SentMessageInternal { #[derive(Clone)] pub struct Sender { pub(super) inner: mpsc::UnboundedSender, - pub(super) closer: tokio::sync::broadcast::Sender<()>, - pub(super) count: std::sync::Arc<()>, -} - -impl Drop for Sender { - fn drop(&mut self) { - // Close the socket connection if this is the last half - // of the channel (ie the receiver has been dropped already). - if std::sync::Arc::strong_count(&self.count) == 1 { - let _ = self.closer.send(()); - } - } + pub(super) closer: Arc, } impl Sender { /// Ask the underlying Websocket connection to close. pub async fn close(&mut self) -> Result<(), SendError> { self.inner.send(SentMessageInternal::Close).await?; + // fire the "proper" close handler (this shouldn't really be necessary + // since the above will cascade closing to both sides anyway). + let _ = self.closer.0.send(()); Ok(()) } /// Returns whether this channel is closed.