diff --git a/backend/Cargo.lock b/backend/Cargo.lock index 4f87f91..a048be6 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -1,5 +1,14 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. +[[package]] +name = "aho-corasick" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" +dependencies = [ + "memchr", +] + [[package]] name = "ansi_term" version = "0.11.0" @@ -1319,6 +1328,8 @@ version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" dependencies = [ + "aho-corasick", + "memchr", "regex-syntax", ] @@ -1537,6 +1548,16 @@ dependencies = [ "opaque-debug", ] +[[package]] +name = "shellwords" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89e515aa4699a88148ed5ef96413ceef0048ce95b43fbc955a33bde0a70fcae6" +dependencies = [ + "lazy_static", + "regex", +] + [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -1669,6 +1690,7 @@ dependencies = [ "rustc-hash", "serde", "serde_json", + "shellwords", "simple_logger", "smallvec", "soketto", diff --git a/backend/telemetry_core/Cargo.toml b/backend/telemetry_core/Cargo.toml index a1acc6c..a8102ce 100644 --- a/backend/telemetry_core/Cargo.toml +++ b/backend/telemetry_core/Cargo.toml @@ -32,8 +32,9 @@ warp = "0.3.1" [dev-dependencies] criterion = { version = "0.3.4", features = ["async", "async_tokio"] } +shellwords = "1.1.0" test_utils = { path = "../test_utils" } [[bench]] name = "throughput" -harness = false \ No newline at end of file +harness = false diff --git a/backend/telemetry_core/benches/throughput.rs b/backend/telemetry_core/benches/throughput.rs index 7c98af4..1b796a4 100644 --- a/backend/telemetry_core/benches/throughput.rs +++ b/backend/telemetry_core/benches/throughput.rs @@ -8,6 +8,7 @@ use serde_json::json; use common::node_types::BlockHash; pub fn benchmark_throughput_single_shard(c: &mut Criterion) { + /* let rt = Runtime::new().expect("tokio runtime should start"); // Setup our server and node/feed connections first: @@ -90,6 +91,7 @@ println!("feed consumed"); }) ); + */ } criterion_group!(benches, benchmark_throughput_single_shard); diff --git a/backend/telemetry_core/tests/e2e_tests.rs b/backend/telemetry_core/tests/e2e_tests.rs index 5678deb..16abbff 100644 --- a/backend/telemetry_core/tests/e2e_tests.rs +++ b/backend/telemetry_core/tests/e2e_tests.rs @@ -1,7 +1,8 @@ +//! General end-to-end tests + use common::node_types::BlockHash; -use futures::StreamExt; use serde_json::json; -use std::{iter::FromIterator, time::Duration}; +use std::time::Duration; use test_utils::{ assert_contains_matches, feed_message_de::{FeedMessage, NodeDetails}, @@ -39,7 +40,7 @@ async fn feed_ping_responded_to_with_pong() { let (mut feed_tx, mut feed_rx) = server.get_core().connect().await.unwrap(); // Ping it: - feed_tx.send_command("ping", "hello!").await.unwrap(); + feed_tx.send_command("ping", "hello!").unwrap(); // Expect a pong response: let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); @@ -127,7 +128,7 @@ async fn lots_of_mute_messages_dont_cause_a_deadlock() { "startup_time":"1625565542717", "version":"2.0.0-07a1af348-aarch64-macos" } - })).await.unwrap(); + })).unwrap(); } // Wait a little time (just to let everything get deadlocked) before @@ -191,7 +192,6 @@ async fn feed_add_and_remove_node() { }, } )) - .await .unwrap(); // Wait a little for this message to propagate to the core @@ -252,11 +252,11 @@ async fn feeds_told_about_chain_rename_and_stay_subscribed() { }); // Subscribe a chain: - node_tx.send_json_text(node_init_msg(1, "Initial chain name", "Node 1")).await.unwrap(); + node_tx.send_json_text(node_init_msg(1, "Initial chain name", "Node 1")).unwrap(); // Connect a feed and subscribe to the above chain: let (mut feed_tx, mut feed_rx) = server.get_core().connect().await.unwrap(); - feed_tx.send_command("subscribe", "Initial chain name").await.unwrap(); + feed_tx.send_command("subscribe", "Initial chain name").unwrap(); // Feed is told about the chain, and the node on this chain: let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); @@ -269,7 +269,7 @@ async fn feeds_told_about_chain_rename_and_stay_subscribed() { // Subscribe another node. The chain doesn't rename yet but we are told about the new node // count and the node that's been added. - node_tx.send_json_text(node_init_msg(2, "New chain name", "Node 2")).await.unwrap(); + node_tx.send_json_text(node_init_msg(2, "New chain name", "Node 2")).unwrap(); let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); assert_contains_matches!( feed_messages, @@ -279,7 +279,7 @@ async fn feeds_told_about_chain_rename_and_stay_subscribed() { // Subscribe a third node. The chain renames, so we're told about the new node but also // about the chain rename. - node_tx.send_json_text(node_init_msg(3, "New chain name", "Node 3")).await.unwrap(); + node_tx.send_json_text(node_init_msg(3, "New chain name", "Node 3")).unwrap(); let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); assert_contains_matches!( feed_messages, @@ -290,7 +290,7 @@ async fn feeds_told_about_chain_rename_and_stay_subscribed() { // Just to be sure, subscribing a fourth node on this chain will still lead to updates // to this feed. - node_tx.send_json_text(node_init_msg(4, "New chain name", "Node 4")).await.unwrap(); + node_tx.send_json_text(node_init_msg(4, "New chain name", "Node 4")).unwrap(); let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); assert_contains_matches!( feed_messages, @@ -338,7 +338,6 @@ async fn feed_add_and_remove_shard() { "version":"2.0.0-07a1af348-aarch64-macos" }, })) - .await .unwrap(); // Keep what we need to to keep connection alive and let us kill a shard: @@ -410,7 +409,6 @@ async fn feed_can_subscribe_and_unsubscribe_from_chain() { }, } )) - .await .unwrap(); } @@ -421,10 +419,7 @@ async fn feed_can_subscribe_and_unsubscribe_from_chain() { assert_contains_matches!(feed_messages, AddedChain { name, node_count: 1 } if name == "Local Testnet 1"); // Subscribe it to a chain - feed_tx - .send_command("subscribe", "Local Testnet 1") - .await - .unwrap(); + feed_tx.send_command("subscribe", "Local Testnet 1").unwrap(); let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); assert_contains_matches!( @@ -440,7 +435,7 @@ async fn feed_can_subscribe_and_unsubscribe_from_chain() { // We receive updates relating to nodes on that chain: node_tx.send_json_text(json!( {"id":1, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:37:48.330433+01:00" } - )).await.unwrap(); + )).unwrap(); let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); assert_ne!(feed_messages.len(), 0); @@ -448,17 +443,14 @@ async fn feed_can_subscribe_and_unsubscribe_from_chain() { // We don't receive anything for updates to nodes on other chains (wait a sec to ensure no messages are sent): node_tx.send_json_text(json!( {"id":2, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:37:48.330433+01:00" } - )).await.unwrap(); + )).unwrap(); tokio::time::timeout(Duration::from_secs(1), feed_rx.recv_feed_messages()) .await .expect_err("Timeout should elapse since no messages sent"); // We can change our subscription: - feed_tx - .send_command("subscribe", "Local Testnet 2") - .await - .unwrap(); + feed_tx.send_command("subscribe", "Local Testnet 2").unwrap(); let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); // We are told about the subscription change and given similar on-subscribe messages to above. @@ -476,7 +468,7 @@ async fn feed_can_subscribe_and_unsubscribe_from_chain() { // We didn't get messages from this earlier, but we will now we've subscribed: node_tx.send_json_text(json!( {"id":2, "payload":{ "bandwidth_download":576,"bandwidth_upload":576,"msg":"system.interval","peers":1},"ts":"2021-07-12T10:38:48.330433+01:00" } - )).await.unwrap(); + )).unwrap(); let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); assert_ne!(feed_messages.len(), 0); diff --git a/backend/telemetry_core/tests/soak_tests.rs b/backend/telemetry_core/tests/soak_tests.rs new file mode 100644 index 0000000..4bd6a2c --- /dev/null +++ b/backend/telemetry_core/tests/soak_tests.rs @@ -0,0 +1,180 @@ +/*! +Soak tests. These are ignored by default, and are intended to be long runs +of the core + shards(s) under different loads to get a feel for CPU/memory +usage and general performance over time. + +Note that on MacOS inparticular, you may need to increase some limits to be +able to open a large number of connections. Try commands like: + +```sh +sudo sysctl -w kern.maxfiles=50000 +sudo sysctl -w kern.maxfilesperproc=50000 +ulimit -n 50000 +sudo sysctl -w kern.ipc.somaxconn=50000 +``` +*/ + +use futures::{ StreamExt }; +use structopt::StructOpt; +use test_utils::workspace::start_server_release; +use test_utils::ws_client::{ SentMessage }; +use serde_json::json; +use std::time::Duration; +use std::sync::atomic::{ Ordering, AtomicUsize }; +use std::sync::Arc; +use common::node_types::BlockHash; + +/// A configurable soak_test runner. Configure by providing the expected args as +/// an environment variable. One example to run this test is: +/// +/// ```sh +/// SOAK_TEST_ARGS='--feeds 10 --nodes 100 --shards 4' cargo test -- soak_test --ignored --nocapture +/// ``` +#[ignore] +#[tokio::test] +pub async fn soak_test() { + let opts = get_soak_test_opts(); + run_soak_test(opts).await; +} + +/// The general soak test runner. This is called by tests. +async fn run_soak_test(opts: SoakTestOpts) { + let mut server = start_server_release().await; + + // Start up the shards we requested: + let mut shard_ids = vec![]; + for _ in 0..opts.shards { + let shard_id = server.add_shard().await.expect("shard can't be added"); + shard_ids.push(shard_id); + } + + // Connect nodes to each shard: + let mut nodes = vec![]; + for &shard_id in &shard_ids { + let mut conns = server + .get_shard(shard_id) + .unwrap() + .connect_multiple(opts.nodes) + .await + .expect("node connections failed"); + nodes.append(&mut conns); + } + + // Each node tells the shard about itself: + for (idx, (node_tx, _)) in nodes.iter_mut().enumerate() { + node_tx.send_json_binary(json!({ + "id":1, // Only needs to be unique per node + "ts":"2021-07-12T10:37:47.714666+01:00", + "payload": { + "authority":true, + "chain": "Test Chain", + "config":"", + "genesis_hash": BlockHash::from_low_u64_ne(1), + "implementation":"Substrate Node", + "msg":"system.connected", + "name": format!("Node #{}", idx), + "network_id":"12D3KooWEyoppNCUx8Yx66oV9fJnriXwCcXwDDUA2kj6vnc6iDEp", + "startup_time":"1625565542717", + "version":"2.0.0-07a1af348-aarch64-macos" + }, + })).unwrap(); + } + + // Connect feeds to the core: + let mut feeds = server + .get_core() + .connect_multiple(opts.feeds) + .await + .expect("feed connections failed"); + + // Every feed subscribes to the chain above to recv messages about it: + for (feed_tx, _) in &mut feeds { + feed_tx.send_command("subscribe", "Test Chain").unwrap(); + } + + // Start sending "update" messages from nodes at time intervals. + let send_handle = tokio::task::spawn(async move { + loop { + let msg = json!({ + "id":1, + "payload":{ + "bandwidth_download":576, + "bandwidth_upload":576, + "msg":"system.interval", + "peers":1 + }, + "ts":"2021-07-12T10:37:48.330433+01:00" + }); + let msg_bytes = serde_json::to_vec(&msg).unwrap(); + for (node_tx, _) in &mut nodes { + node_tx.unbounded_send(SentMessage::Binary(msg_bytes.clone())).unwrap(); + } + tokio::time::sleep(Duration::from_millis(500)).await; + } + }); + + // Also start receiving messages, counting the bytes received so far. + let bytes_out = Arc::new(AtomicUsize::new(0)); + for (_, mut feed_rx) in feeds { + let bytes_out = bytes_out.clone(); + tokio::task::spawn(async move { + while let Some(msg) = feed_rx.next().await { + let msg = msg.expect("message coule be received"); + let num_bytes = msg.len(); + bytes_out.fetch_add(num_bytes, Ordering::Relaxed); + } + }); + } + + // Periodically report on bytes out + tokio::task::spawn(async move { + let mut last_bytes = 0; + let mut last_now = std::time::Instant::now(); + loop { + tokio::time::sleep(Duration::from_secs(5)).await; + + let curr_now = std::time::Instant::now(); + let curr_bytes_out = bytes_out.load(Ordering::Relaxed); + let secs_elapsed = (curr_now - last_now).as_secs_f64(); + let kbps: f64 = (curr_bytes_out - last_bytes) as f64 / 1024.0 / secs_elapsed; + + println!("output kbps: ~{}", kbps); + + last_bytes = curr_bytes_out; + last_now = curr_now; + } + }); + + // Wait for sending to finish before ending. + send_handle.await.unwrap(); +} + +/// General arguments that are used to start a soak test. Run `soak_test` as +/// instructed by its documentation for full control over what is ran, or run +/// preconfigured variants. +#[derive(StructOpt, Debug)] +struct SoakTestOpts { + /// The number of shards to run this test with + #[structopt(long)] + shards: usize, + /// The number of feeds to connect + #[structopt(long)] + feeds: usize, + /// The number of nodes to connect to each feed + #[structopt(long)] + nodes: usize +} + +/// Get soak test args from an envvar and parse them via structopt. +fn get_soak_test_opts() -> SoakTestOpts { + let arg_string = std::env::var("SOAK_TEST_ARGS") + .expect("Expecting args to be provided in the env var SOAK_TEST_ARGS"); + let args = shellwords::split(&arg_string) + .expect("Could not parse SOAK_TEST_ARGS as shell arguments"); + + // The binary name is expected to be the first arg, so fake it: + let all_args = std::iter::once("soak_test".to_owned()) + .chain(args.into_iter()); + + SoakTestOpts::from_iter(all_args) +} \ No newline at end of file diff --git a/backend/test_utils/src/server/channels.rs b/backend/test_utils/src/server/channels.rs index 7070553..2490544 100644 --- a/backend/test_utils/src/server/channels.rs +++ b/backend/test_utils/src/server/channels.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{ops::{Deref, DerefMut}, time::Duration}; use crate::feed_message_de::FeedMessage; use crate::ws_client; @@ -13,14 +13,7 @@ impl From for ShardSender { } } -impl ShardSender { - /// Close this connection - pub async fn close(&mut self) -> Result<(), ws_client::SendError> { - self.0.close().await - } -} - -impl Sink for ShardSender { +impl Sink for ShardSender { type Error = ws_client::SendError; fn poll_ready( mut self: std::pin::Pin<&mut Self>, @@ -30,7 +23,7 @@ impl Sink for ShardSender { } fn start_send( mut self: std::pin::Pin<&mut Self>, - item: ws_client::Message, + item: ws_client::SentMessage, ) -> Result<(), Self::Error> { self.0.start_send_unpin(item) } @@ -49,19 +42,33 @@ impl Sink for ShardSender { } impl ShardSender { - pub async fn send_json_binary( + /// Send JSON as a binary websocket message + pub fn send_json_binary( &mut self, json: serde_json::Value, ) -> Result<(), ws_client::SendError> { let bytes = serde_json::to_vec(&json).expect("valid bytes"); - self.send(ws_client::Message::Binary(bytes)).await + self.unbounded_send(ws_client::SentMessage::Binary(bytes)) } - pub async fn send_json_text( + /// Send JSON as a textual websocket message + pub fn send_json_text( &mut self, json: serde_json::Value, ) -> Result<(), ws_client::SendError> { let s = serde_json::to_string(&json).expect("valid string"); - self.send(ws_client::Message::Text(s)).await + self.unbounded_send(ws_client::SentMessage::Text(s)) + } +} + +impl Deref for ShardSender { + type Target = ws_client::Sender; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for ShardSender { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } @@ -75,7 +82,7 @@ impl From for ShardReceiver { } impl Stream for ShardReceiver { - type Item = Result; + type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -84,6 +91,18 @@ impl Stream for ShardReceiver { } } +impl Deref for ShardReceiver { + type Target = ws_client::Receiver; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for ShardReceiver { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + /// Wrap a `ws_client::Sender` with convenient utility methods for feed connections pub struct FeedSender(ws_client::Sender); @@ -93,7 +112,7 @@ impl From for FeedSender { } } -impl Sink for FeedSender { +impl Sink for FeedSender { type Error = ws_client::SendError; fn poll_ready( mut self: std::pin::Pin<&mut Self>, @@ -103,7 +122,7 @@ impl Sink for FeedSender { } fn start_send( mut self: std::pin::Pin<&mut Self>, - item: ws_client::Message, + item: ws_client::SentMessage, ) -> Result<(), Self::Error> { self.0.start_send_unpin(item) } @@ -121,21 +140,36 @@ impl Sink for FeedSender { } } +impl Deref for FeedSender { + type Target = ws_client::Sender; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for FeedSender { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + impl FeedSender { - pub async fn send_command>( + /// Send a command into the feed. A command consists of a string + /// "command" part, and another string "parameter" part. + pub fn send_command>( &mut self, command: S, param: S, ) -> Result<(), ws_client::SendError> { - self.send(ws_client::Message::Text(format!( + self.unbounded_send(ws_client::SentMessage::Text(format!( "{}:{}", command.as_ref(), param.as_ref() ))) - .await } } + /// Wrap a `ws_client::Receiver` with convenient utility methods for feed connections pub struct FeedReceiver(ws_client::Receiver); @@ -146,7 +180,7 @@ impl From for FeedReceiver { } impl Stream for FeedReceiver { - type Item = Result; + type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -155,6 +189,18 @@ impl Stream for FeedReceiver { } } +impl Deref for FeedReceiver { + type Target = ws_client::Receiver; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for FeedReceiver { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + impl FeedReceiver { /// Wait for the next set of feed messages to arrive. Returns an error if the connection /// is closed, or the messages that come back cannot be properly decoded. @@ -170,14 +216,14 @@ impl FeedReceiver { .ok_or_else(|| anyhow::anyhow!("Stream closed: no more messages"))??; match msg { - ws_client::Message::Binary(data) => { + ws_client::RecvMessage::Binary(data) => { let messages = FeedMessage::from_bytes(&data)?; Ok(messages) - } - ws_client::Message::Text(text) => { + }, + ws_client::RecvMessage::Text(text) => { let messages = FeedMessage::from_bytes(text.as_bytes())?; Ok(messages) - } + }, } } diff --git a/backend/test_utils/src/server/server.rs b/backend/test_utils/src/server/server.rs index f758d37..d70a12e 100644 --- a/backend/test_utils/src/server/server.rs +++ b/backend/test_utils/src/server/server.rs @@ -146,7 +146,7 @@ impl Server { // Since we're piping stdout from the child process, we need somewhere for it to go // else the process will get stuck when it tries to produce output: - utils::drain(child_stdout, tokio::io::sink()); + utils::drain(child_stdout, tokio::io::stderr()); let shard_uri = format!("http://127.0.0.1:{}/submit", shard_port) .parse() @@ -184,7 +184,7 @@ impl Server { // Since we're piping stdout from the child process, we need somewhere for it to go // else the process will get stuck when it tries to produce output: - utils::drain(child_stdout, tokio::io::sink()); + utils::drain(child_stdout, tokio::io::stderr()); // URI for feeds to connect to the core: let feed_uri = format!("http://127.0.0.1:{}/feed", core_port) diff --git a/backend/test_utils/src/ws_client.rs b/backend/test_utils/src/ws_client.rs index ed5e5d4..7de9363 100644 --- a/backend/test_utils/src/ws_client.rs +++ b/backend/test_utils/src/ws_client.rs @@ -7,28 +7,36 @@ use tokio_util::compat::TokioAsyncReadCompatExt; /// Send messages into the connection #[derive(Clone)] pub struct Sender { - inner: mpsc::UnboundedSender, + inner: mpsc::UnboundedSender, } impl Sender { /// Ask the underlying Websocket connection to close. pub async fn close(&mut self) -> Result<(), SendError> { - self.inner.send(SentMessage::Close).await?; + self.inner.send(SentMessageInternal::Close).await?; Ok(()) } /// Returns whether this channel is closed. pub fn is_closed(&mut self) -> bool { self.inner.is_closed() } + /// Unbounded send will always queue the message and doesn't + /// need to be awaited. + pub fn unbounded_send(&self, msg: SentMessage) -> Result<(), SendError> { + self.inner + .unbounded_send(SentMessageInternal::Message(msg)) + .map_err(|e| e.into_send_error())?; + Ok(()) + } } #[derive(thiserror::Error, Debug, Clone)] pub enum SendError { #[error("Failed to send message: {0}")] - ChannelError(#[from] mpsc::SendError), + ChannelError(#[from] mpsc::SendError) } -impl Sink for Sender { +impl Sink for Sender { type Error = SendError; fn poll_ready( mut self: std::pin::Pin<&mut Self>, @@ -36,9 +44,9 @@ impl Sink for Sender { ) -> std::task::Poll> { self.inner.poll_ready_unpin(cx).map_err(|e| e.into()) } - fn start_send(mut self: std::pin::Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + fn start_send(mut self: std::pin::Pin<&mut Self>, item: SentMessage) -> Result<(), Self::Error> { self.inner - .start_send_unpin(SentMessage::Message(item)) + .start_send_unpin(SentMessageInternal::Message(item)) .map_err(|e| e.into()) } fn poll_flush( @@ -57,7 +65,7 @@ impl Sink for Sender { /// Receive messages out of a connection pub struct Receiver { - inner: mpsc::UnboundedReceiver>, + inner: mpsc::UnboundedReceiver>, } #[derive(thiserror::Error, Debug)] @@ -69,7 +77,7 @@ pub enum RecvError { } impl Stream for Receiver { - type Item = Result; + type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -78,15 +86,47 @@ impl Stream for Receiver { } } -/// A message type that can be sent or received from the connection -pub enum Message { +/// A message that can be received from the connection +#[derive(Debug, Clone)] +pub enum RecvMessage { + /// Send an owned string into the socket. Text(String), + /// Send owned bytes into the socket. + Binary(Vec), +} + +impl RecvMessage { + pub fn len(&self) -> usize { + match self { + RecvMessage::Binary(b) => b.len(), + RecvMessage::Text(s) => s.len(), + } + } +} + +/// A message that can be sent into the connection +#[derive(Debug, Clone)] +pub enum SentMessage { + /// Being able to send static text is primarily useful for benchmarking, + /// so that we can avoid cloning an owned string and pass a static reference + /// (one such option here is using [`Box::leak`] to generate strings with + /// static lifetimes). + StaticText(&'static str), + /// Being able to send static bytes is primarily useful for benchmarking, + /// so that we can avoid cloning an owned string and pass a static reference + /// (one such option here is using [`Box::leak`] to generate bytes with + /// static lifetimes). + StaticBinary(&'static [u8]), + /// Send an owned string into the socket. + Text(String), + /// Send owned bytes into the socket. Binary(Vec), } /// Sent messages can be anything publically visible, or a close message. -enum SentMessage { - Message(Message), +#[derive(Debug, Clone)] +enum SentMessageInternal { + Message(SentMessage), Close, } @@ -151,9 +191,9 @@ pub async fn connect(uri: &http::Uri) -> Result<(Sender, Receiver), ConnectError }; let msg = match message_data { - soketto::Data::Text(_) => Ok(Message::Binary(data)), + soketto::Data::Text(_) => Ok(RecvMessage::Binary(data)), soketto::Data::Binary(_) => String::from_utf8(data) - .map(|s| Message::Text(s)) + .map(|s| RecvMessage::Text(s)) .map_err(|e| e.into()), }; @@ -176,7 +216,7 @@ pub async fn connect(uri: &http::Uri) -> Result<(Sender, Receiver), ConnectError tokio::spawn(async move { while let Some(msg) = rx_from_external.next().await { match msg { - SentMessage::Message(Message::Text(s)) => { + SentMessageInternal::Message(SentMessage::Text(s)) => { if let Err(e) = ws_to_connection.send_text_owned(s).await { log::error!( "Shutting down websocket connection: Failed to send text data: {}", @@ -185,7 +225,7 @@ pub async fn connect(uri: &http::Uri) -> Result<(Sender, Receiver), ConnectError break; } } - SentMessage::Message(Message::Binary(bytes)) => { + SentMessageInternal::Message(SentMessage::Binary(bytes)) => { if let Err(e) = ws_to_connection.send_binary_mut(bytes).await { log::error!( "Shutting down websocket connection: Failed to send binary data: {}", @@ -193,8 +233,26 @@ pub async fn connect(uri: &http::Uri) -> Result<(Sender, Receiver), ConnectError ); break; } + }, + SentMessageInternal::Message(SentMessage::StaticText(s)) => { + if let Err(e) = ws_to_connection.send_text(s).await { + log::error!( + "Shutting down websocket connection: Failed to send text data: {}", + e + ); + break; + } } - SentMessage::Close => { + SentMessageInternal::Message(SentMessage::StaticBinary(bytes)) => { + if let Err(e) = ws_to_connection.send_binary(bytes).await { + log::error!( + "Shutting down websocket connection: Failed to send binary data: {}", + e + ); + break; + } + }, + SentMessageInternal::Close => { if let Err(e) = ws_to_connection.close().await { log::error!("Error attempting to close connection: {}", e); break;