diff --git a/backend/common/src/ws_client.rs b/backend/common/src/ws_client.rs deleted file mode 100644 index 7de9363..0000000 --- a/backend/common/src/ws_client.rs +++ /dev/null @@ -1,274 +0,0 @@ -use futures::channel::mpsc; -use futures::{Sink, SinkExt, Stream, StreamExt}; -use soketto::handshake::{Client, ServerResponse}; -use tokio::net::TcpStream; -use tokio_util::compat::TokioAsyncReadCompatExt; - -/// Send messages into the connection -#[derive(Clone)] -pub struct Sender { - inner: mpsc::UnboundedSender, -} - -impl Sender { - /// Ask the underlying Websocket connection to close. - pub async fn close(&mut self) -> Result<(), SendError> { - self.inner.send(SentMessageInternal::Close).await?; - Ok(()) - } - /// Returns whether this channel is closed. - pub fn is_closed(&mut self) -> bool { - self.inner.is_closed() - } - /// Unbounded send will always queue the message and doesn't - /// need to be awaited. - pub fn unbounded_send(&self, msg: SentMessage) -> Result<(), SendError> { - self.inner - .unbounded_send(SentMessageInternal::Message(msg)) - .map_err(|e| e.into_send_error())?; - Ok(()) - } -} - -#[derive(thiserror::Error, Debug, Clone)] -pub enum SendError { - #[error("Failed to send message: {0}")] - ChannelError(#[from] mpsc::SendError) -} - -impl Sink for Sender { - type Error = SendError; - fn poll_ready( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.inner.poll_ready_unpin(cx).map_err(|e| e.into()) - } - fn start_send(mut self: std::pin::Pin<&mut Self>, item: SentMessage) -> Result<(), Self::Error> { - self.inner - .start_send_unpin(SentMessageInternal::Message(item)) - .map_err(|e| e.into()) - } - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.inner.poll_flush_unpin(cx).map_err(|e| e.into()) - } - fn poll_close( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.inner.poll_close_unpin(cx).map_err(|e| e.into()) - } -} - -/// Receive messages out of a connection -pub struct Receiver { - inner: mpsc::UnboundedReceiver>, -} - -#[derive(thiserror::Error, Debug)] -pub enum RecvError { - #[error("Text message contains invalid UTF8: {0}")] - InvalidUtf8(#[from] std::string::FromUtf8Error), - #[error("Stream finished")] - StreamFinished, -} - -impl Stream for Receiver { - type Item = Result; - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.inner.poll_next_unpin(cx).map_err(|e| e.into()) - } -} - -/// A message that can be received from the connection -#[derive(Debug, Clone)] -pub enum RecvMessage { - /// Send an owned string into the socket. - Text(String), - /// Send owned bytes into the socket. - Binary(Vec), -} - -impl RecvMessage { - pub fn len(&self) -> usize { - match self { - RecvMessage::Binary(b) => b.len(), - RecvMessage::Text(s) => s.len(), - } - } -} - -/// A message that can be sent into the connection -#[derive(Debug, Clone)] -pub enum SentMessage { - /// Being able to send static text is primarily useful for benchmarking, - /// so that we can avoid cloning an owned string and pass a static reference - /// (one such option here is using [`Box::leak`] to generate strings with - /// static lifetimes). - StaticText(&'static str), - /// Being able to send static bytes is primarily useful for benchmarking, - /// so that we can avoid cloning an owned string and pass a static reference - /// (one such option here is using [`Box::leak`] to generate bytes with - /// static lifetimes). - StaticBinary(&'static [u8]), - /// Send an owned string into the socket. - Text(String), - /// Send owned bytes into the socket. - Binary(Vec), -} - -/// Sent messages can be anything publically visible, or a close message. -#[derive(Debug, Clone)] -enum SentMessageInternal { - Message(SentMessage), - Close, -} - -#[derive(thiserror::Error, Debug)] -pub enum ConnectError { - #[error("IO error: {0}")] - Io(#[from] std::io::Error), - #[error("Handshake error: {0}")] - Handshake(#[from] soketto::handshake::Error), - #[error("Redirect not supported (status code: {status_code})")] - ConnectionFailedRedirect { status_code: u16 }, - #[error("Connection rejected (status code: {status_code})")] - ConnectionFailedRejected { status_code: u16 }, -} - -/// Establish a websocket connection that you can send and receive messages from. -/// A thin wrapper around Soketto that provides cancel-safe send/receive handles. -/// -/// This must be called within the context of a tokio runtime. -pub async fn connect(uri: &http::Uri) -> Result<(Sender, Receiver), ConnectError> { - let host = uri.host().unwrap_or("127.0.0.1"); - let port = uri.port_u16().unwrap_or(80); - let path = uri.path(); - - let socket = TcpStream::connect((host, port)).await?; - socket.set_nodelay(true).expect("socket set_nodelay failed"); - - // Establish a WS connection: - let mut client = Client::new(socket.compat(), host, &path); - let (mut ws_to_connection, mut ws_from_connection) = match client.handshake().await? { - ServerResponse::Accepted { .. } => client.into_builder().finish(), - ServerResponse::Redirect { status_code, .. } => { - return Err(ConnectError::ConnectionFailedRedirect { status_code }) - } - ServerResponse::Rejected { status_code } => { - return Err(ConnectError::ConnectionFailedRejected { status_code }) - } - }; - - // Soketto sending/receiving isn't cancel safe, so we wrap the message stuff into spawned - // tasks and use channels (which are cancel safe) to send/recv messages atomically.. - - // Receive messages from the socket and post them out: - let (mut tx_to_external, rx_from_ws) = mpsc::unbounded(); - tokio::spawn(async move { - let mut data = Vec::with_capacity(128); - loop { - // Clear the buffer and wait for the next message to arrive: - data.clear(); - - let message_data = match ws_from_connection.receive_data(&mut data).await { - Err(e) => { - // Couldn't receive data may mean all senders are gone, so log - // the error and shut this down: - log::error!( - "Shutting down websocket connection: Failed to receive data: {}", - e - ); - break; - } - Ok(data) => data, - }; - - let msg = match message_data { - soketto::Data::Text(_) => Ok(RecvMessage::Binary(data)), - soketto::Data::Binary(_) => String::from_utf8(data) - .map(|s| RecvMessage::Text(s)) - .map_err(|e| e.into()), - }; - - data = Vec::with_capacity(128); - - if let Err(e) = tx_to_external.send(msg).await { - // Failure to send likely means that the recv has been dropped, - // so let's drop this loop too. - log::error!( - "Shutting down websocket connection: Failed to send data out: {}", - e - ); - break; - } - } - }); - - // Receive messages externally to send to the socket. - let (tx_to_ws, mut rx_from_external) = mpsc::unbounded(); - tokio::spawn(async move { - while let Some(msg) = rx_from_external.next().await { - match msg { - SentMessageInternal::Message(SentMessage::Text(s)) => { - if let Err(e) = ws_to_connection.send_text_owned(s).await { - log::error!( - "Shutting down websocket connection: Failed to send text data: {}", - e - ); - break; - } - } - SentMessageInternal::Message(SentMessage::Binary(bytes)) => { - if let Err(e) = ws_to_connection.send_binary_mut(bytes).await { - log::error!( - "Shutting down websocket connection: Failed to send binary data: {}", - e - ); - break; - } - }, - SentMessageInternal::Message(SentMessage::StaticText(s)) => { - if let Err(e) = ws_to_connection.send_text(s).await { - log::error!( - "Shutting down websocket connection: Failed to send text data: {}", - e - ); - break; - } - } - SentMessageInternal::Message(SentMessage::StaticBinary(bytes)) => { - if let Err(e) = ws_to_connection.send_binary(bytes).await { - log::error!( - "Shutting down websocket connection: Failed to send binary data: {}", - e - ); - break; - } - }, - SentMessageInternal::Close => { - if let Err(e) = ws_to_connection.close().await { - log::error!("Error attempting to close connection: {}", e); - break; - } - } - } - - if let Err(e) = ws_to_connection.flush().await { - log::error!( - "Shutting down websocket connection: Failed to flush data: {}", - e - ); - break; - } - } - }); - - Ok((Sender { inner: tx_to_ws }, Receiver { inner: rx_from_ws })) -} diff --git a/backend/common/src/ws_client/connect.rs b/backend/common/src/ws_client/connect.rs new file mode 100644 index 0000000..9cf3f1a --- /dev/null +++ b/backend/common/src/ws_client/connect.rs @@ -0,0 +1,185 @@ +use futures::channel::mpsc; +use futures::{SinkExt, StreamExt}; +use soketto::handshake::{Client, ServerResponse}; +use tokio::net::TcpStream; +use tokio_util::compat::TokioAsyncReadCompatExt; + +use super::{ + sender::{ Sender, SentMessage, SentMessageInternal }, + receiver::{ Receiver, RecvMessage } +}; + +/// The send side of a Soketto WebSocket connection +pub type RawSender = soketto::connection::Sender>; + +/// The receive side of a Soketto WebSocket connection +pub type RawReceiver = soketto::connection::Receiver>; + +/// A websocket connection. From this, we can either expose the raw connection +/// or expose a cancel-safe interface to it. +pub struct Connection { + tx: soketto::connection::Sender>, + rx: soketto::connection::Receiver> +} + +impl Connection { + /// Get hold of the raw send/receive interface for this connection. + /// These are not cancel-safe, but can be more performant than the + /// cancel-safe channel based interface. + pub fn into_raw(self) -> (RawSender, RawReceiver) { + (self.tx, self.rx) + } + + /// Get hold of send and receive channels for this connection. + /// These channels are cancel-safe. + /// + /// # Panics + /// + /// This will panic if not called within the context of a tokio runtime. + /// + pub fn into_channels(self) -> (Sender, Receiver) { + let (mut ws_to_connection, mut ws_from_connection) = (self.tx, self.rx); + + // Receive messages from the socket and post them out: + let (mut tx_to_external, rx_from_ws) = mpsc::unbounded(); + tokio::spawn(async move { + let mut data = Vec::with_capacity(128); + loop { + // Clear the buffer and wait for the next message to arrive: + data.clear(); + + let message_data = match ws_from_connection.receive_data(&mut data).await { + Err(e) => { + // Couldn't receive data may mean all senders are gone, so log + // the error and shut this down: + log::error!( + "Shutting down websocket connection: Failed to receive data: {}", + e + ); + break; + } + Ok(data) => data, + }; + + let msg = match message_data { + soketto::Data::Text(_) => Ok(RecvMessage::Binary(data)), + soketto::Data::Binary(_) => String::from_utf8(data) + .map(|s| RecvMessage::Text(s)) + .map_err(|e| e.into()), + }; + + data = Vec::with_capacity(128); + + if let Err(e) = tx_to_external.send(msg).await { + // Failure to send likely means that the recv has been dropped, + // so let's drop this loop too. + log::error!( + "Shutting down websocket connection: Failed to send data out: {}", + e + ); + break; + } + } + }); + + // Receive messages externally to send to the socket. + let (tx_to_ws, mut rx_from_external) = mpsc::unbounded(); + tokio::spawn(async move { + while let Some(msg) = rx_from_external.next().await { + match msg { + SentMessageInternal::Message(SentMessage::Text(s)) => { + if let Err(e) = ws_to_connection.send_text_owned(s).await { + log::error!( + "Shutting down websocket connection: Failed to send text data: {}", + e + ); + break; + } + } + SentMessageInternal::Message(SentMessage::Binary(bytes)) => { + if let Err(e) = ws_to_connection.send_binary_mut(bytes).await { + log::error!( + "Shutting down websocket connection: Failed to send binary data: {}", + e + ); + break; + } + }, + SentMessageInternal::Message(SentMessage::StaticText(s)) => { + if let Err(e) = ws_to_connection.send_text(s).await { + log::error!( + "Shutting down websocket connection: Failed to send text data: {}", + e + ); + break; + } + } + SentMessageInternal::Message(SentMessage::StaticBinary(bytes)) => { + if let Err(e) = ws_to_connection.send_binary(bytes).await { + log::error!( + "Shutting down websocket connection: Failed to send binary data: {}", + e + ); + break; + } + }, + SentMessageInternal::Close => { + if let Err(e) = ws_to_connection.close().await { + log::error!("Error attempting to close connection: {}", e); + break; + } + } + } + + if let Err(e) = ws_to_connection.flush().await { + log::error!( + "Shutting down websocket connection: Failed to flush data: {}", + e + ); + break; + } + } + }); + + (Sender { inner: tx_to_ws }, Receiver { inner: rx_from_ws }) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum ConnectError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Handshake error: {0}")] + Handshake(#[from] soketto::handshake::Error), + #[error("Redirect not supported (status code: {status_code})")] + ConnectionFailedRedirect { status_code: u16 }, + #[error("Connection rejected (status code: {status_code})")] + ConnectionFailedRejected { status_code: u16 }, +} + +/// Establish a websocket connection that you can send and receive messages from. +pub async fn connect(uri: &http::Uri) -> Result { + let host = uri.host().unwrap_or("127.0.0.1"); + let port = uri.port_u16().unwrap_or(80); + let path = uri.path(); + + let socket = TcpStream::connect((host, port)).await?; + socket.set_nodelay(true).expect("socket set_nodelay failed"); + + // Establish a WS connection: + let mut client = Client::new(socket.compat(), host, &path); + let (ws_to_connection, ws_from_connection) = match client.handshake().await? { + ServerResponse::Accepted { .. } => client.into_builder().finish(), + ServerResponse::Redirect { status_code, .. } => { + return Err(ConnectError::ConnectionFailedRedirect { status_code }) + } + ServerResponse::Rejected { status_code } => { + return Err(ConnectError::ConnectionFailedRejected { status_code }) + } + }; + + Ok(Connection { + tx: ws_to_connection, + rx: ws_from_connection + }) +} diff --git a/backend/common/src/ws_client/mod.rs b/backend/common/src/ws_client/mod.rs new file mode 100644 index 0000000..a5d1b66 --- /dev/null +++ b/backend/common/src/ws_client/mod.rs @@ -0,0 +1,10 @@ +/// Functionality to establish a connection +mod connect; +/// The channel based send interface +mod sender; +/// The channel based receive interface +mod receiver; + +pub use connect::{ connect, ConnectError, Connection, RawSender, RawReceiver }; +pub use sender::{ Sender, SentMessage, SendError }; +pub use receiver::{ Receiver, RecvMessage, RecvError }; \ No newline at end of file diff --git a/backend/common/src/ws_client/receiver.rs b/backend/common/src/ws_client/receiver.rs new file mode 100644 index 0000000..1d362a3 --- /dev/null +++ b/backend/common/src/ws_client/receiver.rs @@ -0,0 +1,43 @@ +use futures::channel::mpsc; +use futures::{Stream, StreamExt}; + +/// Receive messages out of a connection +pub struct Receiver { + pub (super) inner: mpsc::UnboundedReceiver>, +} + +#[derive(thiserror::Error, Debug)] +pub enum RecvError { + #[error("Text message contains invalid UTF8: {0}")] + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("Stream finished")] + StreamFinished, +} + +impl Stream for Receiver { + type Item = Result; + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_next_unpin(cx).map_err(|e| e.into()) + } +} + +/// A message that can be received from the channel interface +#[derive(Debug, Clone)] +pub enum RecvMessage { + /// Send an owned string into the socket. + Text(String), + /// Send owned bytes into the socket. + Binary(Vec), +} + +impl RecvMessage { + pub fn len(&self) -> usize { + match self { + RecvMessage::Binary(b) => b.len(), + RecvMessage::Text(s) => s.len(), + } + } +} \ No newline at end of file diff --git a/backend/common/src/ws_client/sender.rs b/backend/common/src/ws_client/sender.rs new file mode 100644 index 0000000..09a7c86 --- /dev/null +++ b/backend/common/src/ws_client/sender.rs @@ -0,0 +1,87 @@ +use futures::channel::mpsc; +use futures::{Sink, SinkExt}; + +/// A message that can be sent into the channel interface +#[derive(Debug, Clone)] +pub enum SentMessage { + /// Being able to send static text is primarily useful for benchmarking, + /// so that we can avoid cloning an owned string and pass a static reference + /// (one such option here is using [`Box::leak`] to generate strings with + /// static lifetimes). + StaticText(&'static str), + /// Being able to send static bytes is primarily useful for benchmarking, + /// so that we can avoid cloning an owned string and pass a static reference + /// (one such option here is using [`Box::leak`] to generate bytes with + /// static lifetimes). + StaticBinary(&'static [u8]), + /// Send an owned string into the socket. + Text(String), + /// Send owned bytes into the socket. + Binary(Vec), +} + +/// Messages sent into the channel interface can be anything publically visible, or a close message. +#[derive(Debug, Clone)] +pub (super) enum SentMessageInternal { + Message(SentMessage), + Close, +} + +/// Send messages into the connection +#[derive(Clone)] +pub struct Sender { + pub (super) inner: mpsc::UnboundedSender, +} + +impl Sender { + /// Ask the underlying Websocket connection to close. + pub async fn close(&mut self) -> Result<(), SendError> { + self.inner.send(SentMessageInternal::Close).await?; + Ok(()) + } + /// Returns whether this channel is closed. + pub fn is_closed(&mut self) -> bool { + self.inner.is_closed() + } + /// Unbounded send will always queue the message and doesn't + /// need to be awaited. + pub fn unbounded_send(&self, msg: SentMessage) -> Result<(), SendError> { + self.inner + .unbounded_send(SentMessageInternal::Message(msg)) + .map_err(|e| e.into_send_error())?; + Ok(()) + } +} + +#[derive(thiserror::Error, Debug, Clone)] +pub enum SendError { + #[error("Failed to send message: {0}")] + ChannelError(#[from] mpsc::SendError) +} + +impl Sink for Sender { + type Error = SendError; + fn poll_ready( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready_unpin(cx).map_err(|e| e.into()) + } + fn start_send(mut self: std::pin::Pin<&mut Self>, item: SentMessage) -> Result<(), Self::Error> { + self.inner + .start_send_unpin(SentMessageInternal::Message(item)) + .map_err(|e| e.into()) + } + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_flush_unpin(cx).map_err(|e| e.into()) + } + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_close_unpin(cx).map_err(|e| e.into()) + } +} \ No newline at end of file diff --git a/backend/telemetry_shard/src/connection.rs b/backend/telemetry_shard/src/connection.rs index edece48..d058ccd 100644 --- a/backend/telemetry_shard/src/connection.rs +++ b/backend/telemetry_shard/src/connection.rs @@ -41,7 +41,8 @@ where // to/from the core. If the external channels break, we end for good. If the internal // channels break, we loop around and try connecting again. match ws_client::connect(&telemetry_uri).await { - Ok((tx_to_core, mut rx_from_core)) => { + Ok(connection) => { + let (tx_to_core, mut rx_from_core) = connection.into_channels(); is_connected = true; let mut tx_out = tx_out.clone(); diff --git a/backend/test_utils/src/server/server.rs b/backend/test_utils/src/server/server.rs index 97d94eb..e4436a6 100644 --- a/backend/test_utils/src/server/server.rs +++ b/backend/test_utils/src/server/server.rs @@ -339,18 +339,26 @@ impl Process { } } +/// Establish a raw WebSocket connection (not cancel-safe) +async fn connect_to_uri_raw(uri: &http::Uri) -> Result<(ws_client::RawSender, ws_client::RawReceiver), Error> { + ws_client::connect(uri) + .await + .map(|c| c.into_raw()) + .map_err(|e| e.into()) +} + impl, Recv: From> Process<(Send, Recv)> { /// Establish a connection to the process - async fn connect_to_uri(&self, uri: &http::Uri) -> Result<(Send, Recv), Error> { + async fn connect_to_uri(uri: &http::Uri) -> Result<(Send, Recv), Error> { ws_client::connect(uri) .await + .map(|c| c.into_channels()) .map(|(s, r)| (s.into(), r.into())) .map_err(|e| e.into()) } /// Establish multiple connections to the process async fn connect_multiple_to_uri( - &self, uri: &http::Uri, num_connections: usize, ) -> Result, Error> { @@ -362,30 +370,42 @@ impl, Recv: From> Process<(Se } impl ShardProcess { + /// Establish a raw connection to the process + pub async fn connect_node_raw(&self) -> Result<(ws_client::RawSender, ws_client::RawReceiver), Error> { + let uri = format!("http://{}/submit", self.host).parse()?; + connect_to_uri_raw(&uri).await + } + /// Establish a connection to the process pub async fn connect_node(&self) -> Result<(channels::ShardSender, channels::ShardReceiver), Error> { let uri = format!("http://{}/submit", self.host).parse()?; - self.connect_to_uri(&uri).await + Process::connect_to_uri(&uri).await } /// Establish multiple connections to the process pub async fn connect_multiple_nodes(&self, num_connections: usize) -> Result, Error> { let uri = format!("http://{}/submit", self.host).parse()?; - self.connect_multiple_to_uri(&uri, num_connections).await + Process::connect_multiple_to_uri(&uri, num_connections).await } } impl CoreProcess { + /// Establish a raw connection to the process + pub async fn connect_feed_raw(&self) -> Result<(ws_client::RawSender, ws_client::RawReceiver), Error> { + let uri = format!("http://{}/feed", self.host).parse()?; + connect_to_uri_raw(&uri).await + } + /// Establish a connection to the process pub async fn connect_feed(&self) -> Result<(channels::FeedSender, channels::FeedReceiver), Error> { let uri = format!("http://{}/feed", self.host).parse()?; - self.connect_to_uri(&uri).await + Process::connect_to_uri(&uri).await } /// Establish multiple connections to the process pub async fn connect_multiple_feeds(&self, num_connections: usize) -> Result, Error> { let uri = format!("http://{}/feed", self.host).parse()?; - self.connect_multiple_to_uri(&uri, num_connections).await + Process::connect_multiple_to_uri(&uri, num_connections).await } } diff --git a/backend/test_utils/src/server/utils.rs b/backend/test_utils/src/server/utils.rs index 522d490..e2ff5dd 100644 --- a/backend/test_utils/src/server/utils.rs +++ b/backend/test_utils/src/server/utils.rs @@ -79,7 +79,7 @@ pub async fn connect_multiple_to_uri( // (Side note: on Ubuntu the concurrency seemed to be no issue up to at least 10k connections). let mut sockets = vec![]; for _ in 0..num_connections { - sockets.push(ws_client::connect(uri).await?); + sockets.push(ws_client::connect(uri).await?.into_channels()); } Ok(sockets) }