support https and wss core url in shard's ws client (#488)

* support https and wss core url in shard's ws client

* fmt

* small fix
This commit is contained in:
Renjie Cai
2022-08-05 00:01:32 +08:00
committed by GitHub
parent 24fd27e0bd
commit 65ad10b715
3 changed files with 128 additions and 4 deletions
+45 -4
View File
@@ -16,8 +16,12 @@
use super::on_close::OnClose;
use futures::{channel, StreamExt};
use soketto::handshake::{Client, ServerResponse};
use std::io;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_rustls::rustls::{OwnedTrustAnchor, ServerName};
use tokio_rustls::{rustls, TlsConnector};
use tokio_util::compat::TokioAsyncReadCompatExt;
use super::{
@@ -25,12 +29,16 @@ use super::{
sender::{Sender, SentMessage},
};
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadWrite for T {}
/// The send side of a Soketto WebSocket connection
pub type RawSender = soketto::connection::Sender<tokio_util::compat::Compat<tokio::net::TcpStream>>;
pub type RawSender =
soketto::connection::Sender<tokio_util::compat::Compat<Box<dyn AsyncReadWrite>>>;
/// The receive side of a Soketto WebSocket connection
pub type RawReceiver =
soketto::connection::Receiver<tokio_util::compat::Compat<tokio::net::TcpStream>>;
soketto::connection::Receiver<tokio_util::compat::Compat<Box<dyn AsyncReadWrite>>>;
/// A websocket connection. From this, we can either expose the raw connection
/// or expose a cancel-safe interface to it.
@@ -228,11 +236,17 @@ pub enum ConnectError {
/// Establish a websocket connection that you can send and receive messages from.
pub async fn connect(uri: &http::Uri) -> Result<Connection, ConnectError> {
let host = uri.host().unwrap_or("127.0.0.1");
let port = uri.port_u16().unwrap_or(80);
let scheme = uri.scheme_str().unwrap_or("ws");
let mut port = 80;
if scheme == "https" || scheme == "wss" {
port = 443
}
let path = uri.path();
let port = uri.port_u16().unwrap_or(port);
let socket = TcpStream::connect((host, port)).await?;
socket.set_nodelay(true).expect("socket set_nodelay failed");
// wrap TCP stream with TLS if schema is https or wss
let socket = may_connect_tls(socket, host, scheme == "https" || scheme == "wss").await?;
// Establish a WS connection:
let mut client = Client::new(socket.compat(), host, &path);
@@ -251,3 +265,30 @@ pub async fn connect(uri: &http::Uri) -> Result<Connection, ConnectError> {
rx: ws_from_connection,
})
}
async fn may_connect_tls(
socket: TcpStream,
host: &str,
use_https: bool,
) -> io::Result<Box<dyn AsyncReadWrite>> {
if !use_https {
return Ok(Box::new(socket));
};
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(config));
let domain = ServerName::try_from(host)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dns name"))?;
let socket = connector.connect(domain, socket).await?;
Ok(Box::new(socket))
}