diff --git a/backend/Cargo.lock b/backend/Cargo.lock index 4673e61..da1c272 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -223,7 +223,9 @@ dependencies = [ "soketto", "thiserror", "tokio", + "tokio-rustls", "tokio-util", + "webpki-roots", ] [[package]] @@ -1339,6 +1341,21 @@ dependencies = [ "winreg", ] +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin", + "untrusted", + "web-sys", + "winapi", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -1360,6 +1377,18 @@ dependencies = [ "semver", ] +[[package]] +name = "rustls" +version = "0.20.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033" +dependencies = [ + "log", + "ring", + "sct", + "webpki", +] + [[package]] name = "ryu" version = "1.0.5" @@ -1391,6 +1420,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.3.1" @@ -1555,6 +1594,12 @@ dependencies = [ "sha-1", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "spinning_top" version = "0.2.4" @@ -1828,6 +1873,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" +dependencies = [ + "rustls", + "tokio", + "webpki", +] + [[package]] name = "tokio-util" version = "0.6.7" @@ -1930,6 +1986,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "url" version = "2.2.2" @@ -2065,6 +2127,25 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "webpki-roots" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1c760f0d366a6c24a02ed7816e23e691f5d92291f94d15e836006fd11b04daf" +dependencies = [ + "webpki", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/backend/common/Cargo.toml b/backend/common/Cargo.toml index ccb03fa..8455675 100644 --- a/backend/common/Cargo.toml +++ b/backend/common/Cargo.toml @@ -29,6 +29,8 @@ thiserror = "1.0.24" tokio = { version = "1.8.2", features = ["full"] } tokio-util = { version = "0.6", features = ["compat"] } arrayvec = { version = "0.7.1", features = ["serde"] } +tokio-rustls = "0.23.4" +webpki-roots = "0.22.4" [dev-dependencies] bincode = "1.3.3" diff --git a/backend/common/src/ws_client/connect.rs b/backend/common/src/ws_client/connect.rs index 193b3ee..616469c 100644 --- a/backend/common/src/ws_client/connect.rs +++ b/backend/common/src/ws_client/connect.rs @@ -16,8 +16,12 @@ use super::on_close::OnClose; use futures::{channel, StreamExt}; use soketto::handshake::{Client, ServerResponse}; +use std::io; use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; +use tokio_rustls::rustls::{OwnedTrustAnchor, ServerName}; +use tokio_rustls::{rustls, TlsConnector}; use tokio_util::compat::TokioAsyncReadCompatExt; use super::{ @@ -25,12 +29,16 @@ use super::{ sender::{Sender, SentMessage}, }; +pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + Send {} +impl AsyncReadWrite for T {} + /// The send side of a Soketto WebSocket connection -pub type RawSender = soketto::connection::Sender>; +pub type RawSender = + soketto::connection::Sender>>; /// The receive side of a Soketto WebSocket connection pub type RawReceiver = - soketto::connection::Receiver>; + soketto::connection::Receiver>>; /// A websocket connection. From this, we can either expose the raw connection /// or expose a cancel-safe interface to it. @@ -228,11 +236,17 @@ pub enum ConnectError { /// Establish a websocket connection that you can send and receive messages from. pub async fn connect(uri: &http::Uri) -> Result { let host = uri.host().unwrap_or("127.0.0.1"); - let port = uri.port_u16().unwrap_or(80); + let scheme = uri.scheme_str().unwrap_or("ws"); + let mut port = 80; + if scheme == "https" || scheme == "wss" { + port = 443 + } let path = uri.path(); - + let port = uri.port_u16().unwrap_or(port); let socket = TcpStream::connect((host, port)).await?; socket.set_nodelay(true).expect("socket set_nodelay failed"); + // wrap TCP stream with TLS if schema is https or wss + let socket = may_connect_tls(socket, host, scheme == "https" || scheme == "wss").await?; // Establish a WS connection: let mut client = Client::new(socket.compat(), host, &path); @@ -251,3 +265,30 @@ pub async fn connect(uri: &http::Uri) -> Result { rx: ws_from_connection, }) } + +async fn may_connect_tls( + socket: TcpStream, + host: &str, + use_https: bool, +) -> io::Result> { + if !use_https { + return Ok(Box::new(socket)); + }; + let mut root_cert_store = rustls::RootCertStore::empty(); + root_cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + let connector = TlsConnector::from(Arc::new(config)); + let domain = ServerName::try_from(host) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dns name"))?; + let socket = connector.connect(domain, socket).await?; + Ok(Box::new(socket)) +}