add on_close to abstract and simplify ws_client socket closing

This commit is contained in:
James Wilson
2021-08-02 09:30:29 +01:00
parent 13afb96f93
commit 963ef576f3
5 changed files with 47 additions and 36 deletions
+10 -12
View File
@@ -19,6 +19,7 @@ use soketto::handshake::{Client, ServerResponse};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncReadCompatExt;
use std::sync::Arc;
use super::on_close::OnClose;
use super::{
receiver::{Receiver, RecvMessage},
@@ -67,10 +68,9 @@ impl Connection {
// and recv channels that we hand out are dropped. Notably, we allow either recv or
// send alone to be dropped and still keep the socket open (we may only care about
// one way communication).
let (tx_closed, mut rx_closed1) = tokio::sync::broadcast::channel::<()>(1);
let tx_closed1 = tx_closed.clone();
let tx_closed2 = tx_closed.clone();
let mut rx_closed2 = tx_closed.subscribe();
let (tx_closed1, mut rx_closed1) = tokio::sync::broadcast::channel::<()>(1);
let tx_closed2 = tx_closed1.clone();
let mut rx_closed2 = tx_closed1.subscribe();
// Receive messages from the socket and post them out:
let (mut tx_to_external, rx_from_ws) = mpsc::unbounded();
@@ -90,7 +90,7 @@ impl Connection {
// The socket had an error, so notify interested parties that we should
// shut the connection down and bail out of this receive loop.
log::error!("Shutting down websocket connection: Failed to receive data: {}", e);
let _ = tx_closed.send(());
let _ = tx_closed1.send(());
break;
}
Ok(data) => data,
@@ -193,18 +193,16 @@ impl Connection {
});
// Keep track of whether one of sender or received have
// been dropped. If both have, close the socket connection.
let counter = Arc::new(());
// been dropped. If both have, we close the socket connection.
let on_close = Arc::new(OnClose(tx_closed2));
(Sender {
inner: tx_to_ws,
closer: tx_closed1,
count: Arc::clone(&counter),
closer: Arc::clone(&on_close),
},
Receiver {
inner: rx_from_ws ,
closer: tx_closed2,
count: counter,
inner: rx_from_ws,
_closer: on_close,
})
}
}
+2
View File
@@ -20,6 +20,8 @@ mod connect;
mod receiver;
/// The channel based send interface
mod sender;
/// A close helper that we use in sender/receiver.
mod on_close;
pub use connect::{connect, ConnectError, Connection, RawReceiver, RawSender};
pub use receiver::{Receiver, RecvError, RecvMessage};
+26
View File
@@ -0,0 +1,26 @@
// Source code for the Substrate Telemetry Server.
// Copyright (C) 2021 Parity Technologies (UK) Ltd.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use tokio::sync::broadcast;
/// A small helper to fire the "close" channel when it's dropped.
pub struct OnClose(pub broadcast::Sender<()>);
impl Drop for OnClose {
fn drop(&mut self) {
let _ = self.0.send(());
}
}
+3 -12
View File
@@ -16,22 +16,13 @@
use futures::channel::mpsc;
use futures::{Stream, StreamExt};
use std::sync::Arc;
use super::on_close::OnClose;
/// Receive messages out of a connection
pub struct Receiver {
pub(super) inner: mpsc::UnboundedReceiver<Result<RecvMessage, RecvError>>,
pub(super) closer: tokio::sync::broadcast::Sender<()>,
pub(super) count: std::sync::Arc<()>,
}
impl Drop for Receiver {
fn drop(&mut self) {
// Close the socket connection if this is the last half
// of the channel (ie the sender has been dropped already).
if std::sync::Arc::strong_count(&self.count) == 1 {
let _ = self.closer.send(());
}
}
pub(super) _closer: Arc<OnClose>,
}
#[derive(thiserror::Error, Debug)]
+6 -12
View File
@@ -16,6 +16,8 @@
use futures::channel::mpsc;
use futures::{Sink, SinkExt};
use std::sync::Arc;
use super::on_close::OnClose;
/// A message that can be sent into the channel interface
#[derive(Debug, Clone)]
@@ -47,24 +49,16 @@ pub(super) enum SentMessageInternal {
#[derive(Clone)]
pub struct Sender {
pub(super) inner: mpsc::UnboundedSender<SentMessageInternal>,
pub(super) closer: tokio::sync::broadcast::Sender<()>,
pub(super) count: std::sync::Arc<()>,
}
impl Drop for Sender {
fn drop(&mut self) {
// Close the socket connection if this is the last half
// of the channel (ie the receiver has been dropped already).
if std::sync::Arc::strong_count(&self.count) == 1 {
let _ = self.closer.send(());
}
}
pub(super) closer: Arc<OnClose>,
}
impl Sender {
/// Ask the underlying Websocket connection to close.
pub async fn close(&mut self) -> Result<(), SendError> {
self.inner.send(SentMessageInternal::Close).await?;
// fire the "proper" close handler (this shouldn't really be necessary
// since the above will cascade closing to both sides anyway).
let _ = self.closer.0.send(());
Ok(())
}
/// Returns whether this channel is closed.