diff --git a/backend/Cargo.lock b/backend/Cargo.lock index 7c52f7c..5167e55 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -1259,29 +1259,6 @@ dependencies = [ "opaque-debug", ] -[[package]] -name = "shard" -version = "0.1.0" -dependencies = [ - "anyhow", - "bincode", - "common", - "futures", - "hex", - "http", - "log", - "primitive-types", - "serde", - "serde_json", - "simple_logger", - "soketto", - "structopt", - "thiserror", - "tokio", - "tokio-util", - "warp", -] - [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -1395,7 +1372,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] -name = "telemetry" +name = "telemetry_core" version = "0.1.0" dependencies = [ "anyhow", @@ -1417,6 +1394,30 @@ dependencies = [ "smallvec", "soketto", "structopt", + "test_utils", + "thiserror", + "tokio", + "tokio-util", + "warp", +] + +[[package]] +name = "telemetry_shard" +version = "0.1.0" +dependencies = [ + "anyhow", + "bincode", + "common", + "futures", + "hex", + "http", + "log", + "primitive-types", + "serde", + "serde_json", + "simple_logger", + "soketto", + "structopt", "thiserror", "tokio", "tokio-util", @@ -1437,6 +1438,22 @@ dependencies = [ "winapi", ] +[[package]] +name = "test_utils" +version = "0.1.0" +dependencies = [ + "anyhow", + "common", + "futures", + "http", + "log", + "serde_json", + "soketto", + "thiserror", + "tokio", + "tokio-util", +] + [[package]] name = "textwrap" version = "0.11.0" @@ -1493,9 +1510,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c79ba603c337335df6ba6dd6afc38c38a7d5e1b0c871678439ea973cd62a118e" +checksum = "570c2eb13b3ab38208130eccd41be92520388791207fde783bda7c1e8ace28d4" dependencies = [ "autocfg", "bytes", diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 167ab3d..7061f61 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -1,8 +1,9 @@ [workspace] members = [ "common", - "telemetry", - "shard" + "telemetry_core", + "telemetry_shard", + "test_utils" ] [profile.dev] diff --git a/backend/README.md b/backend/README.md index d82c6eb..df35545 100644 --- a/backend/README.md +++ b/backend/README.md @@ -2,8 +2,8 @@ This folder contains the rust crates and documentation specific to the telemetry backend. A description of the folders: -- [telemetry](./telemetry): The Telemetry Core. This aggregates data received from shards and allows UI feeds to connect and receive this information. -- [shard](./shard): A Shard. It's expected that multiple of these will run. Nodes will connect to Shard instances and send JSON telemetry to them, and Shard instances will each connect to the Telemetry Core and relay on relevant data to it. +- [telemetry_core](./telemetry_core): The Telemetry Core. This aggregates data received from shards and allows UI feeds to connect and receive this information. +- [telemetry_shard](./telemetry_shard): A Shard. It's expected that multiple of these will run. Nodes will connect to Shard instances and send JSON telemetry to them, and Shard instances will each connect to the Telemetry Core and relay on relevant data to it. - [common](./common): common code shared between the telemetry shard and core - [docs](./docs): Material supporting the documentation lives here diff --git a/backend/common/src/dense_map.rs b/backend/common/src/dense_map.rs index edc267f..a631cff 100644 --- a/backend/common/src/dense_map.rs +++ b/backend/common/src/dense_map.rs @@ -88,6 +88,13 @@ where .filter_map(|(id, item)| Some((id.into(), item.as_mut()?))) } + pub fn into_iter(self) -> impl Iterator { + self.items + .into_iter() + .enumerate() + .filter_map(|(id, item)| Some((id.into(), item?))) + } + pub fn len(&self) -> usize { self.items.len() - self.retired.len() } diff --git a/backend/common/src/node_types.rs b/backend/common/src/node_types.rs index d4eecfd..4cf3c49 100644 --- a/backend/common/src/node_types.rs +++ b/backend/common/src/node_types.rs @@ -10,6 +10,7 @@ pub type BlockNumber = u64; pub type Timestamp = u64; pub use primitive_types::H256 as BlockHash; +/// #[derive(Serialize, Deserialize, Debug, Clone)] pub struct NodeDetails { pub chain: Box, @@ -21,18 +22,58 @@ pub struct NodeDetails { pub startup_time: Option>, } -#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)] + +/// +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub struct NodeStats { pub peers: u64, pub txcount: u64, } +impl Serialize for NodeStats { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut tup = serializer.serialize_tuple(2)?; + tup.serialize_element(&self.peers)?; + tup.serialize_element(&self.txcount)?; + tup.end() + } +} + +impl <'de> Deserialize<'de> for NodeStats { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let (peers, txcount) = <(u64, u64)>::deserialize(deserializer)?; + Ok(NodeStats { peers, txcount }) + } +} + + +/// #[derive(Default)] pub struct NodeIO { pub used_state_cache_size: MeanList, } -#[derive(Deserialize, Serialize, Debug, Clone, Copy)] +impl Serialize for NodeIO { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut tup = serializer.serialize_tuple(1)?; + // This is "one-way": we can't deserialize again from this to a MeanList: + tup.serialize_element(self.used_state_cache_size.slice())?; + tup.end() + } +} + + +/// +#[derive(Deserialize, Serialize, Debug, Clone, Copy, PartialEq)] pub struct Block { pub hash: BlockHash, pub height: BlockNumber, @@ -47,7 +88,67 @@ impl Block { } } -#[derive(Debug, Clone, Copy)] + +/// +#[derive(Default)] +pub struct NodeHardware { + /// Upload uses means + pub upload: MeanList, + /// Download uses means + pub download: MeanList, + /// Stampchange uses means + pub chart_stamps: MeanList, +} + +impl Serialize for NodeHardware { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut tup = serializer.serialize_tuple(3)?; + // These are "one-way": we can't deserialize again from them to MeanLists: + tup.serialize_element(self.upload.slice())?; + tup.serialize_element(self.download.slice())?; + tup.serialize_element(self.chart_stamps.slice())?; + tup.end() + } +} + + +/// +#[derive(Debug, Clone, PartialEq)] +pub struct NodeLocation { + pub latitude: f32, + pub longitude: f32, + pub city: Box, +} + +impl Serialize for NodeLocation { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut tup = serializer.serialize_tuple(3)?; + tup.serialize_element(&self.latitude)?; + tup.serialize_element(&self.longitude)?; + tup.serialize_element(&&*self.city)?; + tup.end() + } +} + +impl <'de> Deserialize<'de> for NodeLocation { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let (latitude, longitude, city) = <(f32, f32, Box)>::deserialize(deserializer)?; + Ok(NodeLocation { latitude, longitude, city }) + } +} + + +/// +#[derive(Debug, Clone, Copy, PartialEq)] pub struct BlockDetails { pub block: Block, pub block_time: u64, @@ -66,46 +167,6 @@ impl Default for BlockDetails { } } -#[derive(Default)] -pub struct NodeHardware { - /// Upload uses means - pub upload: MeanList, - /// Download uses means - pub download: MeanList, - /// Stampchange uses means - pub chart_stamps: MeanList, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct NodeLocation { - pub latitude: f32, - pub longitude: f32, - pub city: Box, -} - -impl Serialize for NodeStats { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut tup = serializer.serialize_tuple(2)?; - tup.serialize_element(&self.peers)?; - tup.serialize_element(&self.txcount)?; - tup.end() - } -} - -impl Serialize for NodeIO { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut tup = serializer.serialize_tuple(1)?; - tup.serialize_element(self.used_state_cache_size.slice())?; - tup.end() - } -} - impl Serialize for BlockDetails { fn serialize(&self, serializer: S) -> Result where @@ -121,28 +182,18 @@ impl Serialize for BlockDetails { } } -impl Serialize for NodeLocation { - fn serialize(&self, serializer: S) -> Result +impl <'de> Deserialize<'de> for BlockDetails { + fn deserialize(deserializer: D) -> Result where - S: Serializer, + D: serde::Deserializer<'de>, { - let mut tup = serializer.serialize_tuple(3)?; - tup.serialize_element(&self.latitude)?; - tup.serialize_element(&self.longitude)?; - tup.serialize_element(&&*self.city)?; - tup.end() + let tup = <(u64, BlockHash, u64, u64, Option)>::deserialize(deserializer)?; + Ok(BlockDetails { + block: Block { height: tup.0, hash: tup.1 }, + block_time: tup.2, + block_timestamp: tup.3, + propagation_time: tup.4 + }) } } -impl Serialize for NodeHardware { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut tup = serializer.serialize_tuple(3)?; - tup.serialize_element(self.upload.slice())?; - tup.serialize_element(self.download.slice())?; - tup.serialize_element(self.chart_stamps.slice())?; - tup.end() - } -} diff --git a/backend/telemetry/Cargo.toml b/backend/telemetry_core/Cargo.toml similarity index 90% rename from backend/telemetry/Cargo.toml rename to backend/telemetry_core/Cargo.toml index 72d6aea..c78a741 100644 --- a/backend/telemetry/Cargo.toml +++ b/backend/telemetry_core/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "telemetry" +name = "telemetry_core" version = "0.1.0" authors = ["Parity Technologies Ltd. "] edition = "2018" @@ -29,3 +29,6 @@ thiserror = "1.0.25" tokio = { version = "1.7.0", features = ["full"] } tokio-util = { version = "0.6", features = ["compat"] } warp = "0.3.1" + +[dev-dependencies] +test_utils = { path = "../test_utils" } \ No newline at end of file diff --git a/backend/telemetry/src/aggregator/aggregator.rs b/backend/telemetry_core/src/aggregator/aggregator.rs similarity index 100% rename from backend/telemetry/src/aggregator/aggregator.rs rename to backend/telemetry_core/src/aggregator/aggregator.rs diff --git a/backend/telemetry/src/aggregator/inner_loop.rs b/backend/telemetry_core/src/aggregator/inner_loop.rs similarity index 100% rename from backend/telemetry/src/aggregator/inner_loop.rs rename to backend/telemetry_core/src/aggregator/inner_loop.rs diff --git a/backend/telemetry/src/aggregator/mod.rs b/backend/telemetry_core/src/aggregator/mod.rs similarity index 100% rename from backend/telemetry/src/aggregator/mod.rs rename to backend/telemetry_core/src/aggregator/mod.rs diff --git a/backend/telemetry/src/feed_message.rs b/backend/telemetry_core/src/feed_message.rs similarity index 99% rename from backend/telemetry/src/feed_message.rs rename to backend/telemetry_core/src/feed_message.rs index d6986ca..5512768 100644 --- a/backend/telemetry/src/feed_message.rs +++ b/backend/telemetry_core/src/feed_message.rs @@ -231,4 +231,4 @@ impl FeedMessageWrite for AddedNode<'_> { &node.startup_time(), )); } -} +} \ No newline at end of file diff --git a/backend/telemetry/src/find_location.rs b/backend/telemetry_core/src/find_location.rs similarity index 100% rename from backend/telemetry/src/find_location.rs rename to backend/telemetry_core/src/find_location.rs diff --git a/backend/telemetry/src/main.rs b/backend/telemetry_core/src/main.rs similarity index 100% rename from backend/telemetry/src/main.rs rename to backend/telemetry_core/src/main.rs diff --git a/backend/telemetry/src/state/chain.rs b/backend/telemetry_core/src/state/chain.rs similarity index 100% rename from backend/telemetry/src/state/chain.rs rename to backend/telemetry_core/src/state/chain.rs diff --git a/backend/telemetry/src/state/mod.rs b/backend/telemetry_core/src/state/mod.rs similarity index 100% rename from backend/telemetry/src/state/mod.rs rename to backend/telemetry_core/src/state/mod.rs diff --git a/backend/telemetry/src/state/node.rs b/backend/telemetry_core/src/state/node.rs similarity index 100% rename from backend/telemetry/src/state/node.rs rename to backend/telemetry_core/src/state/node.rs diff --git a/backend/telemetry/src/state/state.rs b/backend/telemetry_core/src/state/state.rs similarity index 100% rename from backend/telemetry/src/state/state.rs rename to backend/telemetry_core/src/state/state.rs diff --git a/backend/telemetry_core/tests/basic_tests.rs b/backend/telemetry_core/tests/basic_tests.rs new file mode 100644 index 0000000..b094be5 --- /dev/null +++ b/backend/telemetry_core/tests/basic_tests.rs @@ -0,0 +1,27 @@ +use test_utils::{feed_message_de::FeedMessage, server::Server}; +// use serde_json::json; + +#[tokio::test] +async fn can_ping_feed() { + + let server = Server::start_default() + .await + .expect("server could start"); + + // Connect to the feed: + let (mut feed_tx, mut feed_rx) = server.get_core().connect().await.unwrap(); + + // Expect a version response of 31: + let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); + assert_eq!(feed_messages, vec![FeedMessage::Version(31)], "expecting version"); + + // Ping it: + feed_tx.send_command("ping", "hello!").await.unwrap(); + + // Expect a pong response: + let feed_messages = feed_rx.recv_feed_messages().await.unwrap(); + assert_eq!(feed_messages, vec![FeedMessage::Pong { msg: "hello!".to_owned() }], "expecting pong"); + + // Tidy up: + server.shutdown().await; +} \ No newline at end of file diff --git a/backend/shard/Cargo.toml b/backend/telemetry_shard/Cargo.toml similarity index 96% rename from backend/shard/Cargo.toml rename to backend/telemetry_shard/Cargo.toml index e3a64db..ffc28f7 100644 --- a/backend/shard/Cargo.toml +++ b/backend/telemetry_shard/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "shard" +name = "telemetry_shard" version = "0.1.0" authors = ["Parity Technologies Ltd. "] edition = "2018" diff --git a/backend/shard/src/aggregator.rs b/backend/telemetry_shard/src/aggregator.rs similarity index 100% rename from backend/shard/src/aggregator.rs rename to backend/telemetry_shard/src/aggregator.rs diff --git a/backend/shard/src/connection.rs b/backend/telemetry_shard/src/connection.rs similarity index 100% rename from backend/shard/src/connection.rs rename to backend/telemetry_shard/src/connection.rs diff --git a/backend/shard/src/json_message/hash.rs b/backend/telemetry_shard/src/json_message/hash.rs similarity index 100% rename from backend/shard/src/json_message/hash.rs rename to backend/telemetry_shard/src/json_message/hash.rs diff --git a/backend/shard/src/json_message/mod.rs b/backend/telemetry_shard/src/json_message/mod.rs similarity index 100% rename from backend/shard/src/json_message/mod.rs rename to backend/telemetry_shard/src/json_message/mod.rs diff --git a/backend/shard/src/json_message/node_message.rs b/backend/telemetry_shard/src/json_message/node_message.rs similarity index 100% rename from backend/shard/src/json_message/node_message.rs rename to backend/telemetry_shard/src/json_message/node_message.rs diff --git a/backend/shard/src/main.rs b/backend/telemetry_shard/src/main.rs similarity index 100% rename from backend/shard/src/main.rs rename to backend/telemetry_shard/src/main.rs diff --git a/backend/shard/src/real_ip.rs b/backend/telemetry_shard/src/real_ip.rs similarity index 100% rename from backend/shard/src/real_ip.rs rename to backend/telemetry_shard/src/real_ip.rs diff --git a/backend/test_utils/Cargo.toml b/backend/test_utils/Cargo.toml new file mode 100644 index 0000000..02de8d0 --- /dev/null +++ b/backend/test_utils/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "test_utils" +version = "0.1.0" +authors = ["James Wilson "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "1.0.41" +futures = "0.3.15" +http = "0.2.4" +log = "0.4.14" +serde_json = "1.0.64" +soketto = "0.6.0" +thiserror = "1.0.25" +tokio = { version = "1.7.1", features = ["full"] } +tokio-util = "0.6.7" +common = { path = "../common" } diff --git a/backend/test_utils/src/feed_message_de.rs b/backend/test_utils/src/feed_message_de.rs new file mode 100644 index 0000000..28f78ac --- /dev/null +++ b/backend/test_utils/src/feed_message_de.rs @@ -0,0 +1,277 @@ +use common::node_types::{BlockDetails, BlockHash, BlockNumber, NodeLocation, NodeStats, Timestamp}; +use serde_json::value::RawValue; + +#[derive(Debug, PartialEq)] +pub enum FeedMessage { + Version(usize), + BestBlock { + block_number: BlockNumber, + timestamp: Timestamp, + avg_block_time: Option, + }, + BestFinalized { + block_number: BlockNumber, + block_hash: BlockHash, + }, + AddedNode { + node_id: usize, + node: NodeDetails, + stats: NodeStats, + // io: NodeIO, // can't losslessly deserialize + // hardware: NodeHardware, // can't losslessly deserialize + block_details: BlockDetails, + location: NodeLocation, + startup_time: Option, + }, + RemovedNode { + node_id: usize, + }, + LocatedNode { + node_id: usize, + lat: f32, + long: f32, + city: String, + }, + ImportedBlock { + node_id: usize, + block_details: BlockDetails, + }, + FinalizedBlock { + node_id: usize, + block_number: BlockNumber, + block_hash: BlockHash, + }, + NodeStatsUpdate { + node_id: usize, + stats: NodeStats, + }, + Hardware { + node_id: usize, + // hardware: NodeHardware, // Can't losslessly deserialize + }, + TimeSync { + time: Timestamp, + }, + AddedChain { + name: String, + node_count: usize, + }, + RemovedChain { + name: String, + }, + SubscribedTo { + name: String, + }, + UnsubscribedFrom { + name: String, + }, + Pong { + msg: String, + }, + AfgFinalized { + address: String, + block_number: BlockNumber, + block_hash: BlockHash, + }, + AfgReceivedPrevote { + address: String, + block_number: BlockNumber, + block_hash: BlockHash, + voter: Option + }, + AfgReceivedPrecommit { + address: String, + block_number: BlockNumber, + block_hash: BlockHash, + voter: Option, + }, + AfgAuthoritySet { // Not used currently; not sure what "address" params are: + a1: String, + a2: String, + a3: String, + block_number: BlockNumber, + block_hash: BlockHash, + }, + StaleNode { + node_id: usize, + }, + NodeIOUpdate { + node_id: usize, + // details: NodeIO, // can't losslessly deserialize + }, + /// A "special" case when we don't know how to decode an action: + UnknownValue { + action: u8, + value: String + } +} + +#[derive(Debug, PartialEq)] +pub struct NodeDetails { + pub name: String, + pub implementation: String, + pub version: String, + pub validator: Option, + pub network_id: Option, +} + +impl FeedMessage { + /// Decode a slice of bytes into a vector of feed messages + pub fn from_bytes(bytes: &[u8]) -> Result, anyhow::Error> { + let v: Vec<&RawValue> = serde_json::from_slice(bytes)?; + + let mut feed_messages = vec![]; + for raw_keyval in v.windows(2) { + let raw_key = raw_keyval[0]; + let raw_val = raw_keyval[1]; + feed_messages.push(FeedMessage::decode(raw_key, raw_val)?); + } + + Ok(feed_messages) + } + + // Deserialize the feed message to a value based on the "action" key + fn decode(raw_key: &RawValue, raw_val: &RawValue) -> Result { + let action: u8 = serde_json::from_str(raw_key.get())?; + let feed_message = match action { + // Version: + 0 => { + let version = serde_json::from_str(raw_val.get())?; + FeedMessage::Version(version) + }, + // BestBlock + 1 => { + let (block_number, timestamp, avg_block_time) = serde_json::from_str(raw_val.get())?; + FeedMessage::BestBlock { block_number, timestamp, avg_block_time } + }, + // BestFinalized + 2 => { + let (block_number, block_hash) = serde_json::from_str(raw_val.get())?; + FeedMessage::BestFinalized { block_number, block_hash } + } + // AddNode + 3 => { + let ( + node_id, + ( name, implementation, version, validator, network_id ), + stats, + io, + hardware, + block_details, + location, + startup_time, + ) = serde_json::from_str(raw_val.get())?; + + // Give these two types but don't use the results: + let (_,_): (&RawValue, &RawValue) = (io, hardware); + + FeedMessage::AddedNode { + node_id, + node: NodeDetails { name, implementation, version, validator, network_id }, + stats, + block_details, + location, + startup_time, + } + }, + // RemoveNode + 4 => { + let node_id = serde_json::from_str(raw_val.get())?; + FeedMessage::RemovedNode { node_id } + }, + // LocatedNode + 5 => { + let (node_id, lat, long, city) = serde_json::from_str(raw_val.get())?; + FeedMessage::LocatedNode { node_id, lat, long, city } + }, + // ImportedBlock + 6 => { + let (node_id, block_details) = serde_json::from_str(raw_val.get())?; + FeedMessage::ImportedBlock { node_id, block_details } + }, + // FinalizedBlock + 7 => { + let (node_id, block_number, block_hash) = serde_json::from_str(raw_val.get())?; + FeedMessage::FinalizedBlock { node_id, block_number, block_hash } + }, + // NodeStatsUpdate + 8 => { + let (node_id, stats) = serde_json::from_str(raw_val.get())?; + FeedMessage::NodeStatsUpdate { node_id, stats } + }, + // Hardware + 9 => { + let (node_id, _hardware): (_, &RawValue) = serde_json::from_str(raw_val.get())?; + FeedMessage::Hardware { node_id } + }, + // TimeSync + 10 => { + let time = serde_json::from_str(raw_val.get())?; + FeedMessage::TimeSync { time } + }, + // AddedChain + 11 => { + let (name, node_count) = serde_json::from_str(raw_val.get())?; + FeedMessage::AddedChain { name, node_count } + }, + // RemovedChain + 12 => { + let name = serde_json::from_str(raw_val.get())?; + FeedMessage::RemovedChain { name } + }, + // SubscribedTo + 13 => { + let name = serde_json::from_str(raw_val.get())?; + FeedMessage::SubscribedTo { name } + }, + // UnsubscribedFrom + 14 => { + let name = serde_json::from_str(raw_val.get())?; + FeedMessage::UnsubscribedFrom { name } + }, + // Pong + 15 => { + let msg = serde_json::from_str(raw_val.get())?; + FeedMessage::Pong { msg } + }, + // AfgFinalized + 16 => { + let (address, block_number, block_hash) = serde_json::from_str(raw_val.get())?; + FeedMessage::AfgFinalized { address, block_number, block_hash } + }, + // AfgReceivedPrevote + 17 => { + let (address, block_number, block_hash, voter) = serde_json::from_str(raw_val.get())?; + FeedMessage::AfgReceivedPrevote { address, block_number, block_hash, voter } + }, + // AfgReceivedPrecommit + 18 => { + let (address, block_number, block_hash, voter) = serde_json::from_str(raw_val.get())?; + FeedMessage::AfgReceivedPrecommit { address, block_number, block_hash, voter } + }, + // AfgAuthoritySet + 19 => { + let (a1, a2, a3, block_number, block_hash) = serde_json::from_str(raw_val.get())?; + FeedMessage::AfgAuthoritySet { a1, a2, a3, block_number, block_hash } + }, + // StaleNode + 20 => { + let node_id = serde_json::from_str(raw_val.get())?; + FeedMessage::StaleNode { node_id } + }, + // NodeIOUpdate + 21 => { + // ignore NodeIO for now: + let (node_id, _node_io): (_, &RawValue) = serde_json::from_str(raw_val.get())?; + FeedMessage::NodeIOUpdate { node_id } + }, + // A catchall for messages we don't know/care about yet: + _ => { + let value = raw_val.to_string(); + FeedMessage::UnknownValue { action, value } + }, + }; + + Ok(feed_message) + } +} \ No newline at end of file diff --git a/backend/test_utils/src/lib.rs b/backend/test_utils/src/lib.rs new file mode 100644 index 0000000..fc44e9b --- /dev/null +++ b/backend/test_utils/src/lib.rs @@ -0,0 +1,10 @@ +/// Create/connect to a server consisting of shards and a core process that we can interact with. +pub mod server; + +/// Test support for deserializing feed messages from the feed processes. This basically +/// is the slightly-lossy inverse of the custom serialization we do to feed messages. +pub mod feed_message_de; + +/// A wrapper around soketto to simplify the process of establishing connections +/// and sending messages. Provides cancel-safe message channels. +pub mod ws_client; diff --git a/backend/test_utils/src/server/channels.rs b/backend/test_utils/src/server/channels.rs new file mode 100644 index 0000000..11815b4 --- /dev/null +++ b/backend/test_utils/src/server/channels.rs @@ -0,0 +1,117 @@ +use crate::ws_client; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use crate::feed_message_de::FeedMessage; + +/// Wrap a `ws_client::Sender` with convenient utility methods for shard connections +pub struct ShardSender(ws_client::Sender); + +impl From for ShardSender { + fn from(c: ws_client::Sender) -> Self { ShardSender(c) } +} + +impl Sink for ShardSender { + type Error = ws_client::SendError; + fn poll_ready(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.0.poll_ready_unpin(cx) + } + fn start_send(mut self: std::pin::Pin<&mut Self>, item: ws_client::Message) -> Result<(), Self::Error> { + self.0.start_send_unpin(item) + } + fn poll_flush(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.0.poll_flush_unpin(cx) + } + fn poll_close(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.0.poll_close_unpin(cx) + } +} + +impl ShardSender { + pub async fn send_json_binary(&mut self, json: serde_json::Value) -> Result<(), ws_client::SendError> { + let bytes = serde_json::to_vec(&json).expect("valid bytes"); + self.send(ws_client::Message::Binary(bytes)).await + } + pub async fn send_json_text(&mut self, json: serde_json::Value) -> Result<(), ws_client::SendError> { + let s = serde_json::to_string(&json).expect("valid string"); + self.send(ws_client::Message::Text(s)).await + } +} + +/// Wrap a `ws_client::Receiver` with convenient utility methods for shard connections +pub struct ShardReceiver(ws_client::Receiver); + +impl From for ShardReceiver { + fn from(c: ws_client::Receiver) -> Self { ShardReceiver(c) } +} + +impl Stream for ShardReceiver { + type Item = Result; + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.0.poll_next_unpin(cx) + } +} + + +/// Wrap a `ws_client::Sender` with convenient utility methods for feed connections +pub struct FeedSender(ws_client::Sender); + +impl From for FeedSender { + fn from(c: ws_client::Sender) -> Self { FeedSender(c) } +} + +impl Sink for FeedSender { + type Error = ws_client::SendError; + fn poll_ready(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.0.poll_ready_unpin(cx) + } + fn start_send(mut self: std::pin::Pin<&mut Self>, item: ws_client::Message) -> Result<(), Self::Error> { + self.0.start_send_unpin(item) + } + fn poll_flush(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.0.poll_flush_unpin(cx) + } + fn poll_close(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.0.poll_close_unpin(cx) + } +} + +impl FeedSender { + pub async fn send_command>(&mut self, command: S, param: S) -> Result<(), ws_client::SendError> { + self.send(ws_client::Message::Text(format!("{}:{}", command.as_ref(), param.as_ref()))).await + } +} + +/// Wrap a `ws_client::Receiver` with convenient utility methods for feed connections +pub struct FeedReceiver(ws_client::Receiver); + +impl From for FeedReceiver { + fn from(c: ws_client::Receiver) -> Self { FeedReceiver(c) } +} + +impl Stream for FeedReceiver { + type Item = Result; + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.0.poll_next_unpin(cx).map_err(|e| e.into()) + } +} + +impl FeedReceiver { + /// Wait for the next set of feed messages to arrive. Returns an error if the connection + /// is closed, or the messages that come back cannot be properly decoded. + pub async fn recv_feed_messages(&mut self) -> Result, anyhow::Error> { + let msg = self.0 + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Stream closed: no more messages"))??; + + match msg { + ws_client::Message::Binary(data) => { + let messages = FeedMessage::from_bytes(&data)?; + Ok(messages) + }, + ws_client::Message::Text(text) => { + let messages = FeedMessage::from_bytes(text.as_bytes())?; + Ok(messages) + } + } + } +} \ No newline at end of file diff --git a/backend/test_utils/src/server/default_commands.rs b/backend/test_utils/src/server/default_commands.rs new file mode 100644 index 0000000..da5d928 --- /dev/null +++ b/backend/test_utils/src/server/default_commands.rs @@ -0,0 +1,30 @@ +use super::Command; +use std::path::PathBuf; + +pub fn default_telemetry_shard_command() -> Result { + default_telemetry_command("telemetry_shard") +} + +pub fn default_telemetry_core_command() -> Result { + default_telemetry_command("telemetry_core") +} + +fn default_telemetry_command(bin: &'static str) -> Result { + let mut workspace_dir = try_find_workspace_dir()?; + workspace_dir.push("Cargo.toml"); + Ok(Command::new("cargo") + .arg("run") + .arg("--bin") + .arg(bin) + .arg("--manifest-path") + .arg(workspace_dir) + .arg("--")) +} + +/// A _very_ naive way to find the workspace ("backend") directory +/// from the current path (which is assumed to be inside it). +fn try_find_workspace_dir() -> Result { + let mut dir = std::env::current_dir()?; + while !dir.ends_with("backend") && dir.pop() {} + Ok(dir) +} \ No newline at end of file diff --git a/backend/test_utils/src/server/mod.rs b/backend/test_utils/src/server/mod.rs new file mode 100644 index 0000000..6d92098 --- /dev/null +++ b/backend/test_utils/src/server/mod.rs @@ -0,0 +1,6 @@ +mod utils; +mod server; +mod default_commands; + +pub mod channels; +pub use server::*; \ No newline at end of file diff --git a/backend/test_utils/src/server/server.rs b/backend/test_utils/src/server/server.rs new file mode 100644 index 0000000..e49579a --- /dev/null +++ b/backend/test_utils/src/server/server.rs @@ -0,0 +1,324 @@ +use std::ffi::OsString; +use std::marker::PhantomData; +use crate::ws_client; +use tokio::process::{ self, Command as TokioCommand }; +use super::{ channels, utils }; +use common::{ id_type, DenseMap }; + +id_type! { + /// The ID of a running process. Cannot be constructed externally. + pub struct ProcessId(usize); +} + +pub struct StartOpts { + /// Optional command to run to start a shard (instead of `telemetry_shard`). + /// The `--listen` and `--log` arguments will be appended within and shouldn't be provided. + pub shard_command: Option, + /// Optional command to run to start a telemetry core process (instead of `telemetry_core`). + /// The `--listen` and `--log` arguments will be appended within and shouldn't be provided. + pub core_command: Option +} + +impl Default for StartOpts { + fn default() -> Self { + StartOpts { + shard_command: None, + core_command: None + } + } +} + +pub struct ConnectToExistingOpts { + /// Details for connections to `telemetry_shard` /submit endpoints + pub shard_uris: Vec, + /// Details for connections to `telemetry_core` /feed endpoints + pub feed_uri: http::Uri, +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Can't establsih connection: {0}")] + ConnectionError(#[from] ws_client::ConnectError), + #[error("Can't establsih connection: {0}")] + JoinError(#[from] tokio::task::JoinError), + #[error("Can't establsih connection: {0}")] + IoError(#[from] std::io::Error), + #[error("Could not obtain port for process: {0}")] + ErrorObtainingPort(anyhow::Error), + #[error("Whoops; attempt to kill a process we didn't start (and so have no handle to)")] + CannotKillNoHandle, + #[error("Whoops; attempt to add a shard to a server we didn't start (and so have no handle to)")] + CannotAddShardNoHandle, +} + +/// This provides back connections (or groups of connections) that are +/// hooked up to the running processes and ready to send/receive messages. +pub struct Server { + /// URI to connect a shard to core: + core_shard_submit_uri: Option, + /// Command to run to start a new shard: + shard_command: Option, + /// Shard processes that we can connect to + shards: DenseMap, + /// Core process that we can connect to + core: CoreProcess, +} + +impl Server { + pub fn get_core(&self) -> &CoreProcess { + &self.core + } + + pub fn get_shard(&self, id: ProcessId) -> Option<&ShardProcess> { + self.shards.get(id) + } + + pub fn iter_shards(&self) -> impl Iterator { + self.shards.iter().map(|(_,v)| v) + } + + pub async fn kill_shard(&mut self, id: ProcessId) -> bool { + let shard = match self.shards.remove(id) { + Some(shard) => shard, + None => return false + }; + + // With this, killing will complete even if the promise returned is cancelled + // (it should regardless, but just to play it safe..) + let _ = tokio::spawn(async move { + let _ = shard.kill().await; + }).await; + + true + } + + /// Kill everything and tidy up + pub async fn shutdown(self) { + // Spawn so we don't need to await cleanup if we don't care. + // Run all kill futs simultaneously. + let handle = tokio::spawn(async move { + let shard_kill_futs = self.shards + .into_iter() + .map(|(_,s)| s.kill()); + + let _ = tokio::join!( + futures::future::join_all(shard_kill_futs), + self.core.kill() + ); + }); + + // You can wait for cleanup but aren't obliged to: + let _ = handle.await; + } + + /// Connect a new shard and return a process that you can interact with: + pub async fn add_shard(&mut self) -> Result { + let core_uri = match &self.core_shard_submit_uri { + Some(uri) => uri, + None => return Err(Error::CannotAddShardNoHandle) + }; + + let mut shard_cmd: TokioCommand = match &self.shard_command { + Some(cmd) => cmd.clone(), + None => super::default_commands::default_telemetry_shard_command()? + }.into(); + + shard_cmd + .arg("--listen") + .arg("127.0.0.1:0") // 0 to have a port picked by the kernel + .arg("--log") + .arg("info") + .arg("--core") + .arg(core_uri.to_string()) + .kill_on_drop(true) + .stdout(std::process::Stdio::piped()) + .stdin(std::process::Stdio::piped()); + + let mut shard_process = shard_cmd.spawn()?; + let mut child_stdout = shard_process.stdout.take().expect("shard stdout"); + let shard_port = utils::get_port(&mut child_stdout) + .await + .map_err(|e| Error::ErrorObtainingPort(e))?; + + // Since we're piping stdout from the child process, we need somewhere for it to go + // else the process will get stuck when it tries to produce output: + utils::drain(child_stdout, tokio::io::stdout()); + + let shard_uri = format!("http://127.0.0.1:{}/submit", shard_port) + .parse() + .expect("valid submit URI"); + + let pid = self.shards.add_with(|id| Process { + id, + handle: Some(shard_process), + uri: shard_uri, + _channel_type: PhantomData + }); + + Ok(pid) + } + + /// Start a telemetry_core process with default opts. From here, we can add/remove shards as needed. + pub async fn start_default() -> Result { + Server::start(StartOpts::default()).await + } + + /// Start a telemetry_core process. From here, we can add/remove shards as needed. + pub async fn start(opts: StartOpts) -> Result { + + let mut core_cmd: TokioCommand = match opts.core_command { + Some(cmd) => cmd, + None => super::default_commands::default_telemetry_core_command()? + }.into(); + + let mut child = core_cmd + .arg("--listen") + .arg("127.0.0.1:0") // 0 to have a port picked by the kernel + .arg("--log") + .arg("info") + .kill_on_drop(true) + .stdout(std::process::Stdio::piped()) + .stdin(std::process::Stdio::piped()) + .spawn()?; + + // Find out the port that this is running on + let mut child_stdout = child.stdout.take().expect("core stdout"); + let core_port = utils::get_port(&mut child_stdout) + .await + .map_err(|e| Error::ErrorObtainingPort(e))?; + + // Since we're piping stdout from the child process, we need somewhere for it to go + // else the process will get stuck when it tries to produce output: + utils::drain(child_stdout, tokio::io::stdout()); + + // URI for feeds to connect to the core: + let feed_uri = format!("http://127.0.0.1:{}/feed", core_port) + .parse() + .expect("valid feed URI"); + + Ok(Server { + shard_command: opts.shard_command, + core_shard_submit_uri: Some(format!("http://127.0.0.1:{}/shard_submit", core_port) + .parse() + .expect("valid shard_submit URI")), + shards: DenseMap::new(), + core: Process { + id: ProcessId(0), + handle: Some(child), + uri: feed_uri, + _channel_type: PhantomData, + } + }) + } + + /// Establshes the requested connections to existing processes. + pub fn connect_to_existing(opts: ConnectToExistingOpts) -> Server { + let mut shards = DenseMap::new(); + for shard_uri in opts.shard_uris { + shards.add_with(|id| Process { + id, + uri: shard_uri, + handle: None, + _channel_type: PhantomData, + }); + } + + Server { + shard_command: None, + // We can't add shards if starting in this mode: + core_shard_submit_uri: None, + shards, + core: Process { + id: ProcessId(0), + uri: opts.feed_uri, + handle: None, + _channel_type: PhantomData, + } + } + } +} + + +/// This represents a running process that we can connect to, which +/// may be either a `telemetry_shard` or `telemetry_core`. +pub struct Process { + id: ProcessId, + /// If we started the processes ourselves, we'll have a handle to + /// them which we can use to kill them. Else, we may not. + handle: Option, + /// The URI that we can use to connect to the process socket. + uri: http::Uri, + /// The kind of the process (lets us add methods specific to shard/core). + _channel_type: PhantomData +} + +/// A shard process with shard-specific methods. +pub type ShardProcess = Process<(channels::ShardSender, channels::ShardReceiver)>; + +/// A core process with core-specific methods. +pub type CoreProcess = Process<(channels::FeedSender, channels::FeedReceiver)>; + +impl Process { + /// Get the ID of this process + pub fn id(&self) -> ProcessId { + self.id + } + + /// Kill the process and wait for this to complete + /// Not public: Klling done via Server. + async fn kill(self) -> Result<(), Error> { + match self.handle { + Some(mut handle) => Ok(handle.kill().await?), + None => Err(Error::CannotKillNoHandle) + } + } +} + +impl , Recv: From> Process<(Send, Recv)> { + /// Establish a connection to the process + pub async fn connect(&self) -> Result<(Send, Recv), Error> { + ws_client::connect(&self.uri) + .await + .map(|(s,r)| (s.into(), r.into())) + .map_err(|e| e.into()) + } + + /// Establish multiple connections to the process + pub async fn connect_multiple(&self, num_connections: usize) -> Result, Error> { + utils::connect_multiple_to_uri(&self.uri, num_connections) + .await + .map(|v| v.into_iter().map(|(s,r)| (s.into(), r.into())).collect()) + .map_err(|e| e.into()) + } +} + +/// This defines a command to run. This exists because [`tokio::process::Command`] +/// cannot be cloned, but we need to be able to clone our command to spawn multiple +/// processes with it. +#[derive(Clone, Debug)] +pub struct Command { + command: OsString, + args: Vec +} + +impl Command { + pub fn new>(command: S) -> Command { + Command { + command: command.into(), + args: Vec::new() + } + } + + pub fn arg>(mut self, arg: S) -> Command { + self.args.push(arg.into()); + self + } +} + +impl Into for Command { + fn into(self) -> TokioCommand { + let mut cmd = TokioCommand::new(self.command); + cmd.args(self.args); + cmd + } +} \ No newline at end of file diff --git a/backend/test_utils/src/server/utils.rs b/backend/test_utils/src/server/utils.rs new file mode 100644 index 0000000..8dde9d8 --- /dev/null +++ b/backend/test_utils/src/server/utils.rs @@ -0,0 +1,65 @@ +use crate::ws_client; +use tokio::io::BufReader; +use tokio::io::{ AsyncRead, AsyncWrite, AsyncBufReadExt }; +use tokio::time::Duration; +use anyhow::{ anyhow, Context }; + +/// Reads from the stdout of the shard/core process to extract the port that was assigned to it, +/// with the side benefit that we'll wait for it to start listening before returning. We do this +/// because we want to allow the kernel to assign ports and so don't specify a port as an arg. +pub async fn get_port(reader: R) -> Result { + let reader = BufReader::new(reader); + let mut reader_lines = reader.lines(); + + loop { + let line = tokio::time::timeout( + // This has to accomodate pauses during compilation if the cmd is "cargo run --": + Duration::from_secs(30), + reader_lines.next_line() + ).await; + + let line = match line { + // timeout expired; couldn't get port: + Err(e) => return Err(anyhow!("Timeout expired waiting to discover port: {}", e)), + // Something went wrong reading line; bail: + Ok(Err(e)) => return Err(anyhow!("Could not read line from stdout: {}", e)), + // No more output; process ended? bail: + Ok(Ok(None)) => return Err(anyhow!("No more output from stdout; has the process ended?")), + // All OK, and a line is given back; phew! + Ok(Ok(Some(line))) => line + }; + + let (_, port_str) = match line.rsplit_once("listening on http://127.0.0.1:") { + Some(m) => m, + None => continue + }; + + return port_str + .trim() + .parse() + .with_context(|| format!("Could not parse output to port: {}", port_str)); + } +} + +/// Establish multiple connections to a URI and return them all. +pub async fn connect_multiple_to_uri(uri: &http::Uri, num_connections: usize) -> Result, ws_client::ConnectError> { + let connect_futs = (0..num_connections) + .map(|_| ws_client::connect(uri)); + let sockets: Result,_> = futures::future::join_all(connect_futs) + .await + .into_iter() + .collect(); + sockets +} + +/// Drain output from a reader to stdout. After acquiring port details from spawned processes, +/// they expect their stdout to be continue to be consumed, and so we do this here. +pub fn drain(mut reader: R, mut writer: W) +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static +{ + tokio::spawn(async move { + let _ = tokio::io::copy(&mut reader, &mut writer).await; + }); +} \ No newline at end of file diff --git a/backend/test_utils/src/ws_client.rs b/backend/test_utils/src/ws_client.rs new file mode 100644 index 0000000..ea40bf2 --- /dev/null +++ b/backend/test_utils/src/ws_client.rs @@ -0,0 +1,193 @@ +use futures::channel::{ mpsc }; +use soketto::handshake::{Client, ServerResponse}; +use tokio_util::compat::{ TokioAsyncReadCompatExt }; +use tokio::net::TcpStream; +use futures::{Sink, SinkExt, Stream, StreamExt}; + +/// Send messages into the connection +#[derive(Clone)] +pub struct Sender { + inner: mpsc::UnboundedSender +} + +impl Sender { + pub async fn close(&mut self) -> Result<(),SendError> { + self.inner.send(SentMessage::Close).await?; + Ok(()) + } +} + +#[derive(thiserror::Error, Debug, Clone)] +pub enum SendError { + #[error("Failed to send message: {0}")] + ChannelError(#[from] mpsc::SendError) +} + +impl Sink for Sender { + type Error = SendError; + fn poll_ready(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.inner.poll_ready_unpin(cx).map_err(|e| e.into()) + } + fn start_send(mut self: std::pin::Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + self.inner.start_send_unpin(SentMessage::Message(item)).map_err(|e| e.into()) + } + fn poll_flush(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.inner.poll_flush_unpin(cx).map_err(|e| e.into()) + } + fn poll_close(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.inner.poll_close_unpin(cx).map_err(|e| e.into()) + } +} + +/// Receive messages out of a connection +pub struct Receiver { + inner: mpsc::UnboundedReceiver> +} + +#[derive(thiserror::Error, Debug)] +pub enum RecvError { + #[error("Text message contains invalid UTF8: {0}")] + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("Stream finished")] + StreamFinished +} + +impl Stream for Receiver { + type Item = Result; + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.inner.poll_next_unpin(cx).map_err(|e| e.into()) + } +} + +/// A message type that can be sent or received from the connection +pub enum Message { + Text(String), + Binary(Vec) +} + +/// Sent messages can be anything publically visible, or a close message. +enum SentMessage { + Message(Message), + Close +} + +#[derive(thiserror::Error, Debug)] +pub enum ConnectError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Handshake error: {0}")] + Handshake(#[from] soketto::handshake::Error), + #[error("Redirect not supported (status code: {status_code})")] + ConnectionFailedRedirect { status_code: u16 }, + #[error("Connection rejected (status code: {status_code})")] + ConnectionFailedRejected { status_code: u16 }, +} + +/// Establish a websocket connection that you can send and receive messages from. +/// A thin wrapper around Soketto that provides cancel-safe send/receive handles. +/// +/// This must be called within the context of a tokio runtime. +pub async fn connect(uri: &http::Uri) -> Result<(Sender, Receiver), ConnectError> { + let host = uri.host().unwrap_or("127.0.0.1"); + let port = uri.port_u16().unwrap_or(80); + let path = uri.path(); + + let socket = TcpStream::connect((host, port)).await?; + socket.set_nodelay(true).unwrap(); + + // Establish a WS connection: + let mut client = Client::new(socket.compat(), host, &path); + let (mut ws_to_connection, mut ws_from_connection) = match client.handshake().await? { + ServerResponse::Accepted { .. } => { + client.into_builder().finish() + }, + ServerResponse::Redirect { status_code, .. } => { + return Err(ConnectError::ConnectionFailedRedirect { status_code }) + }, + ServerResponse::Rejected { status_code } => { + return Err(ConnectError::ConnectionFailedRejected { status_code }) + } + }; + + // Soketto sending/receiving isn't cancel safe, so we wrap the message stuff into spawned + // tasks and use channels (which are cancel safe) to send/recv messages atomically.. + + // Receive messages from the socket and post them out: + let (mut tx_to_external, rx_from_ws) = mpsc::unbounded(); + tokio::spawn(async move { + let mut data = Vec::with_capacity(128); + loop { + // Clear the buffer and wait for the next message to arrive: + data.clear(); + + let message_data = match ws_from_connection.receive_data(&mut data).await { + Err(e) => { + // Couldn't receive data may mean all senders are gone, so log + // the error and shut this down: + log::error!("Shutting down websocket connection: Failed to receive data: {}", e); + break; + }, + Ok(data) => { + data + } + }; + + let msg = match message_data { + soketto::Data::Text(_) => { + Ok(Message::Binary(data)) + }, + soketto::Data::Binary(_) => { + String::from_utf8(data) + .map(|s| Message::Text(s)) + .map_err(|e| e.into()) + }, + }; + + data = Vec::with_capacity(128); + + if let Err(e) = tx_to_external.send(msg).await { + // Failure to send likely means that the recv has been dropped, + // so let's drop this loop too. + log::error!("Shutting down websocket connection: Failed to send data out: {}", e); + break; + } + } + }); + + // Receive messages externally to send to the socket. + let (tx_to_ws, mut rx_from_external) = mpsc::unbounded(); + tokio::spawn(async move { + while let Some(msg) = rx_from_external.next().await { + match msg { + SentMessage::Message(Message::Text(s)) => { + if let Err(e) = ws_to_connection.send_text_owned(s).await { + log::error!("Shutting down websocket connection: Failed to send text data: {}", e); + break; + } + }, + SentMessage::Message(Message::Binary(bytes)) => { + if let Err(e) = ws_to_connection.send_binary_mut(bytes).await { + log::error!("Shutting down websocket connection: Failed to send binary data: {}", e); + break; + } + }, + SentMessage::Close => { + if let Err(e) = ws_to_connection.close().await { + log::error!("Error attempting to close connection: {}", e); + break; + } + } + } + + if let Err(e) = ws_to_connection.flush().await { + log::error!("Shutting down websocket connection: Failed to flush data: {}", e); + break; + } + } + }); + + Ok(( + Sender { inner: tx_to_ws }, + Receiver { inner: rx_from_ws } + )) +}