add on_close to abstract and simplify ws_client socket closing

This commit is contained in:
James Wilson
2021-08-02 09:30:29 +01:00
parent 13afb96f93
commit 963ef576f3
5 changed files with 47 additions and 36 deletions
+10 -12
View File
@@ -19,6 +19,7 @@ use soketto::handshake::{Client, ServerResponse};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncReadCompatExt; use tokio_util::compat::TokioAsyncReadCompatExt;
use std::sync::Arc; use std::sync::Arc;
use super::on_close::OnClose;
use super::{ use super::{
receiver::{Receiver, RecvMessage}, receiver::{Receiver, RecvMessage},
@@ -67,10 +68,9 @@ impl Connection {
// and recv channels that we hand out are dropped. Notably, we allow either recv or // and recv channels that we hand out are dropped. Notably, we allow either recv or
// send alone to be dropped and still keep the socket open (we may only care about // send alone to be dropped and still keep the socket open (we may only care about
// one way communication). // one way communication).
let (tx_closed, mut rx_closed1) = tokio::sync::broadcast::channel::<()>(1); let (tx_closed1, mut rx_closed1) = tokio::sync::broadcast::channel::<()>(1);
let tx_closed1 = tx_closed.clone(); let tx_closed2 = tx_closed1.clone();
let tx_closed2 = tx_closed.clone(); let mut rx_closed2 = tx_closed1.subscribe();
let mut rx_closed2 = tx_closed.subscribe();
// Receive messages from the socket and post them out: // Receive messages from the socket and post them out:
let (mut tx_to_external, rx_from_ws) = mpsc::unbounded(); let (mut tx_to_external, rx_from_ws) = mpsc::unbounded();
@@ -90,7 +90,7 @@ impl Connection {
// The socket had an error, so notify interested parties that we should // The socket had an error, so notify interested parties that we should
// shut the connection down and bail out of this receive loop. // shut the connection down and bail out of this receive loop.
log::error!("Shutting down websocket connection: Failed to receive data: {}", e); log::error!("Shutting down websocket connection: Failed to receive data: {}", e);
let _ = tx_closed.send(()); let _ = tx_closed1.send(());
break; break;
} }
Ok(data) => data, Ok(data) => data,
@@ -193,18 +193,16 @@ impl Connection {
}); });
// Keep track of whether one of sender or received have // Keep track of whether one of sender or received have
// been dropped. If both have, close the socket connection. // been dropped. If both have, we close the socket connection.
let counter = Arc::new(()); let on_close = Arc::new(OnClose(tx_closed2));
(Sender { (Sender {
inner: tx_to_ws, inner: tx_to_ws,
closer: tx_closed1, closer: Arc::clone(&on_close),
count: Arc::clone(&counter),
}, },
Receiver { Receiver {
inner: rx_from_ws , inner: rx_from_ws,
closer: tx_closed2, _closer: on_close,
count: counter,
}) })
} }
} }
+2
View File
@@ -20,6 +20,8 @@ mod connect;
mod receiver; mod receiver;
/// The channel based send interface /// The channel based send interface
mod sender; mod sender;
/// A close helper that we use in sender/receiver.
mod on_close;
pub use connect::{connect, ConnectError, Connection, RawReceiver, RawSender}; pub use connect::{connect, ConnectError, Connection, RawReceiver, RawSender};
pub use receiver::{Receiver, RecvError, RecvMessage}; pub use receiver::{Receiver, RecvError, RecvMessage};
+26
View File
@@ -0,0 +1,26 @@
// Source code for the Substrate Telemetry Server.
// Copyright (C) 2021 Parity Technologies (UK) Ltd.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use tokio::sync::broadcast;
/// A small helper to fire the "close" channel when it's dropped.
pub struct OnClose(pub broadcast::Sender<()>);
impl Drop for OnClose {
fn drop(&mut self) {
let _ = self.0.send(());
}
}
+3 -12
View File
@@ -16,22 +16,13 @@
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use std::sync::Arc;
use super::on_close::OnClose;
/// Receive messages out of a connection /// Receive messages out of a connection
pub struct Receiver { pub struct Receiver {
pub(super) inner: mpsc::UnboundedReceiver<Result<RecvMessage, RecvError>>, pub(super) inner: mpsc::UnboundedReceiver<Result<RecvMessage, RecvError>>,
pub(super) closer: tokio::sync::broadcast::Sender<()>, pub(super) _closer: Arc<OnClose>,
pub(super) count: std::sync::Arc<()>,
}
impl Drop for Receiver {
fn drop(&mut self) {
// Close the socket connection if this is the last half
// of the channel (ie the sender has been dropped already).
if std::sync::Arc::strong_count(&self.count) == 1 {
let _ = self.closer.send(());
}
}
} }
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
+6 -12
View File
@@ -16,6 +16,8 @@
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::{Sink, SinkExt}; use futures::{Sink, SinkExt};
use std::sync::Arc;
use super::on_close::OnClose;
/// A message that can be sent into the channel interface /// A message that can be sent into the channel interface
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -47,24 +49,16 @@ pub(super) enum SentMessageInternal {
#[derive(Clone)] #[derive(Clone)]
pub struct Sender { pub struct Sender {
pub(super) inner: mpsc::UnboundedSender<SentMessageInternal>, pub(super) inner: mpsc::UnboundedSender<SentMessageInternal>,
pub(super) closer: tokio::sync::broadcast::Sender<()>, pub(super) closer: Arc<OnClose>,
pub(super) count: std::sync::Arc<()>,
}
impl Drop for Sender {
fn drop(&mut self) {
// Close the socket connection if this is the last half
// of the channel (ie the receiver has been dropped already).
if std::sync::Arc::strong_count(&self.count) == 1 {
let _ = self.closer.send(());
}
}
} }
impl Sender { impl Sender {
/// Ask the underlying Websocket connection to close. /// Ask the underlying Websocket connection to close.
pub async fn close(&mut self) -> Result<(), SendError> { pub async fn close(&mut self) -> Result<(), SendError> {
self.inner.send(SentMessageInternal::Close).await?; self.inner.send(SentMessageInternal::Close).await?;
// fire the "proper" close handler (this shouldn't really be necessary
// since the above will cascade closing to both sides anyway).
let _ = self.closer.0.send(());
Ok(()) Ok(())
} }
/// Returns whether this channel is closed. /// Returns whether this channel is closed.