Be more fussy with Soketto channel wrapper; Soketto might do otherwise, but we only explicitly close the connection when both ends of the channel handed out are closed (or we encounter a connection error)

This commit is contained in:
James Wilson
2021-07-30 21:26:52 +01:00
parent 18918e6d3e
commit 13afb96f93
3 changed files with 79 additions and 25 deletions
+55 -25
View File
@@ -18,6 +18,7 @@ use futures::{SinkExt, StreamExt};
use soketto::handshake::{Client, ServerResponse};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncReadCompatExt;
use std::sync::Arc;
use super::{
receiver::{Receiver, RecvMessage},
@@ -61,25 +62,46 @@ impl Connection {
pub fn into_channels(self) -> (Sender, Receiver) {
let (mut ws_to_connection, mut ws_from_connection) = (self.tx, self.rx);
// Shut everything down when we're told to close, which will be either when
// we hit an error trying to receive data on the socket, or when both the send
// and recv channels that we hand out are dropped. Notably, we allow either recv or
// send alone to be dropped and still keep the socket open (we may only care about
// one way communication).
let (tx_closed, mut rx_closed1) = tokio::sync::broadcast::channel::<()>(1);
let tx_closed1 = tx_closed.clone();
let tx_closed2 = tx_closed.clone();
let mut rx_closed2 = tx_closed.subscribe();
// Receive messages from the socket and post them out:
let (mut tx_to_external, rx_from_ws) = mpsc::unbounded();
let (tx_has_closed, mut rx_has_closed) = futures::channel::oneshot::channel();
tokio::spawn(async move {
let mut data = Vec::with_capacity(128);
let mut send_to_external = true;
loop {
// Clear the buffer and wait for the next message to arrive:
data.clear();
let message_data = match ws_from_connection.receive_data(&mut data).await {
let mut data = Vec::new();
// Wait for messages, or bail entirely if asked to close.
let message_data = tokio::select! {
msg_data = ws_from_connection.receive_data(&mut data) => { msg_data },
_ = rx_closed1.recv() => { break }
};
let message_data = match message_data {
Err(e) => {
log::error!(
"Shutting down websocket connection: Failed to receive data: {}",
e
);
// The socket had an error, so notify interested parties that we should
// shut the connection down and bail out of this receive loop.
log::error!("Shutting down websocket connection: Failed to receive data: {}", e);
let _ = tx_closed.send(());
break;
}
Ok(data) => data,
};
// if we hit an error sending, we keep receiving messages and reacting
// to recv issues, but we stop trying to send them anywhere.
if !send_to_external {
continue;
}
let msg = match message_data {
soketto::Data::Binary(_) => Ok(RecvMessage::Binary(data)),
soketto::Data::Text(_) => String::from_utf8(data)
@@ -87,39 +109,34 @@ impl Connection {
.map_err(|e| e.into()),
};
data = Vec::with_capacity(128);
if let Err(e) = tx_to_external.send(msg).await {
log::error!(
"Shutting down websocket connection: Failed to send data out: {}",
e
);
break;
// Our external channel may have closed or errored, but the socket hasn't
// been closed, so keep receiving in order to allow the socket to continue to
// function properly (we may be happy just sending messages to it), but stop
// trying to send messages out,
log::warn!("Failed to send data out: {}", e);
send_to_external = false;
}
}
// If the receive loop ends, make sure the other half closes and let the socket
// connection be dropped. While we might want to drop half of the channel and keep
// the connection open, we need to keep receiving from it in order to ackknowledge
// control messages, so if this loops ends, we had better just shut it all down.
let _ = tx_has_closed.send(());
});
// Receive messages externally to send to the socket.
let (tx_to_ws, mut rx_from_external) = mpsc::unbounded();
tokio::spawn(async move {
loop {
// Wait for messages, or bail entirely if asked to close.
let msg = tokio::select! {
msg = rx_from_external.next() => { msg },
// Websocket connection closed? Don't wait for incoming message; break immediately.
_ = &mut rx_has_closed => { break },
_ = rx_closed2.recv() => { break }
};
// No more messages; channel closed. End this loop.
let msg = match msg {
None => break,
Some(msg) => msg,
};
// Any errors we hit here will be hit faster in the `receive_data`
match msg {
SentMessageInternal::Message(SentMessage::Text(s)) => {
if let Err(e) = ws_to_connection.send_text_owned(s).await {
@@ -175,7 +192,20 @@ impl Connection {
}
});
(Sender { inner: tx_to_ws }, Receiver { inner: rx_from_ws })
// Keep track of whether one of sender or received have
// been dropped. If both have, close the socket connection.
let counter = Arc::new(());
(Sender {
inner: tx_to_ws,
closer: tx_closed1,
count: Arc::clone(&counter),
},
Receiver {
inner: rx_from_ws ,
closer: tx_closed2,
count: counter,
})
}
}
+12
View File
@@ -20,6 +20,18 @@ use futures::{Stream, StreamExt};
/// Receive messages out of a connection
pub struct Receiver {
pub(super) inner: mpsc::UnboundedReceiver<Result<RecvMessage, RecvError>>,
pub(super) closer: tokio::sync::broadcast::Sender<()>,
pub(super) count: std::sync::Arc<()>,
}
impl Drop for Receiver {
fn drop(&mut self) {
// Close the socket connection if this is the last half
// of the channel (ie the sender has been dropped already).
if std::sync::Arc::strong_count(&self.count) == 1 {
let _ = self.closer.send(());
}
}
}
#[derive(thiserror::Error, Debug)]
+12
View File
@@ -47,6 +47,18 @@ pub(super) enum SentMessageInternal {
#[derive(Clone)]
pub struct Sender {
pub(super) inner: mpsc::UnboundedSender<SentMessageInternal>,
pub(super) closer: tokio::sync::broadcast::Sender<()>,
pub(super) count: std::sync::Arc<()>,
}
impl Drop for Sender {
fn drop(&mut self) {
// Close the socket connection if this is the last half
// of the channel (ie the receiver has been dropped already).
if std::sync::Arc::strong_count(&self.count) == 1 {
let _ = self.closer.send(());
}
}
}
impl Sender {