cargo fmt

This commit is contained in:
James Wilson
2021-07-27 12:38:05 +01:00
parent a44f39cfaf
commit 2ca5f6a1fb
23 changed files with 707 additions and 511 deletions
+76 -59
View File
@@ -1,29 +1,27 @@
use std::net::SocketAddr;
use hyper::{ Server, Request, Response, Body };
use std::future::Future;
use tokio_util::compat::{Compat,TokioAsyncReadCompatExt};
use futures::io::{BufReader, BufWriter};
use hyper::server::conn::AddrStream;
use hyper::{Body, Request, Response, Server};
use std::future::Future;
use std::net::SocketAddr;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
/// A convenience function to start up a Hyper server and handle requests.
pub async fn start_server<H, F>(addr: SocketAddr, handler: H) -> Result<(), anyhow::Error>
where
H: Clone + Send + Sync + 'static + FnMut(SocketAddr, Request<Body>) -> F,
F: Send + 'static + Future<Output = Result<Response<Body>, anyhow::Error>>
F: Send + 'static + Future<Output = Result<Response<Body>, anyhow::Error>>,
{
let service =
hyper::service::make_service_fn(move |addr: &AddrStream| {
let mut handler = handler.clone();
let addr = addr.remote_addr();
async move { Ok::<_, hyper::Error>(hyper::service::service_fn(move |r| handler(addr, r))) }
});
let server = Server::bind(&addr).serve(service);
let service = hyper::service::make_service_fn(move |addr: &AddrStream| {
let mut handler = handler.clone();
let addr = addr.remote_addr();
async move { Ok::<_, hyper::Error>(hyper::service::service_fn(move |r| handler(addr, r))) }
});
let server = Server::bind(&addr).serve(service);
log::info!("listening on http://{}", server.local_addr());
server.await?;
Ok(())
log::info!("listening on http://{}", server.local_addr());
server.await?;
Ok(())
}
type WsStream = BufReader<BufWriter<Compat<hyper::upgrade::Upgraded>>>;
@@ -31,27 +29,40 @@ pub type WsSender = soketto::connection::Sender<WsStream>;
pub type WsReceiver = soketto::connection::Receiver<WsStream>;
/// A convenience function to upgrade a Hyper request into a Soketto Websocket.
pub fn upgrade_to_websocket<H,F>(req: Request<Body>, on_upgrade: H) -> hyper::Response<Body>
pub fn upgrade_to_websocket<H, F>(req: Request<Body>, on_upgrade: H) -> hyper::Response<Body>
where
H: 'static + Send + FnOnce(WsSender, WsReceiver) -> F,
F: Send + Future<Output = ()>
F: Send + Future<Output = ()>,
{
if !is_upgrade_request(&req) {
return basic_response(400, "Expecting WebSocket upgrade headers");
return basic_response(400, "Expecting WebSocket upgrade headers");
}
let key = match req.headers().get("Sec-WebSocket-Key") {
let key = match req.headers().get("Sec-WebSocket-Key") {
Some(key) => key,
None => return basic_response(400, "Upgrade to websocket connection failed; Sec-WebSocket-Key header not provided")
None => {
return basic_response(
400,
"Upgrade to websocket connection failed; Sec-WebSocket-Key header not provided",
)
}
};
if req.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") {
return basic_response(400, "Sec-WebSocket-Version header should have a value of 13");
}
if req
.headers()
.get("Sec-WebSocket-Version")
.map(|v| v.as_bytes())
!= Some(b"13")
{
return basic_response(
400,
"Sec-WebSocket-Version header should have a value of 13",
);
}
// Just a little ceremony we need to go to to return the correct response key:
let mut accept_key_buf = [0; 32];
let accept_key = generate_websocket_accept_key(key.as_bytes(), &mut accept_key_buf);
// Just a little ceremony we need to go to to return the correct response key:
let mut accept_key_buf = [0; 32];
let accept_key = generate_websocket_accept_key(key.as_bytes(), &mut accept_key_buf);
// Tell the client that we accept the upgrade-to-WS request:
let response = Response::builder()
@@ -74,13 +85,11 @@ where
};
// Start a Soketto server with it:
let server = soketto::handshake::Server::new(
BufReader::new(BufWriter::new(stream.compat()))
);
let server =
soketto::handshake::Server::new(BufReader::new(BufWriter::new(stream.compat())));
// Get hold of a way to send and receive messages:
let (sender, receiver)
= server.into_builder().finish();
let (sender, receiver) = server.into_builder().finish();
// Pass these to our when-upgraded handler:
on_upgrade(sender, receiver).await;
@@ -100,40 +109,48 @@ fn basic_response(code: u16, msg: impl AsRef<str>) -> Response<Body> {
/// Defined in RFC 6455. this is how we convert the Sec-WebSocket-Key in a request into a
/// Sec-WebSocket-Accept that we return in the response.
fn generate_websocket_accept_key<'a>(key: &[u8], buf: &'a mut [u8; 32]) -> &'a [u8] {
// Defined in RFC 6455, we append this to the key to generate the response:
const KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// Defined in RFC 6455, we append this to the key to generate the response:
const KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
use sha1::{Digest, Sha1};
let mut digest = Sha1::new();
digest.update(key);
digest.update(KEY);
let d = digest.finalize();
use sha1::{Digest, Sha1};
let mut digest = Sha1::new();
digest.update(key);
digest.update(KEY);
let d = digest.finalize();
let n = base64::encode_config_slice(&d, base64::STANDARD, buf);
&buf[..n]
let n = base64::encode_config_slice(&d, base64::STANDARD, buf);
&buf[..n]
}
/// Check if a request is a websocket upgrade request.
fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
header_contains_value(request.headers(), hyper::header::CONNECTION, b"upgrade")
&& header_contains_value(request.headers(), hyper::header::UPGRADE, b"websocket")
header_contains_value(request.headers(), hyper::header::CONNECTION, b"upgrade")
&& header_contains_value(request.headers(), hyper::header::UPGRADE, b"websocket")
}
/// Check if there is a header of the given name containing the wanted value.
fn header_contains_value(headers: &hyper::HeaderMap, header: hyper::header::HeaderName, value: &[u8]) -> bool {
pub fn trim(x: &[u8]) -> &[u8] {
let from = match x.iter().position(|x| !x.is_ascii_whitespace()) {
Some(i) => i,
None => return &x[0..0],
};
let to = x.iter().rposition(|x| !x.is_ascii_whitespace()).unwrap();
&x[from..=to]
}
fn header_contains_value(
headers: &hyper::HeaderMap,
header: hyper::header::HeaderName,
value: &[u8],
) -> bool {
pub fn trim(x: &[u8]) -> &[u8] {
let from = match x.iter().position(|x| !x.is_ascii_whitespace()) {
Some(i) => i,
None => return &x[0..0],
};
let to = x.iter().rposition(|x| !x.is_ascii_whitespace()).unwrap();
&x[from..=to]
}
for header in headers.get_all(header) {
if header.as_bytes().split(|&c| c == b',').any(|x| trim(x).eq_ignore_ascii_case(value)) {
return true;
}
}
false
}
for header in headers.get_all(header) {
if header
.as_bytes()
.split(|&c| c == b',')
.any(|x| trim(x).eq_ignore_ascii_case(value))
{
return true;
}
}
false
}
+2 -2
View File
@@ -1,11 +1,11 @@
pub mod http_utils;
pub mod id_type;
pub mod internal_messages;
pub mod node_message;
pub mod node_types;
pub mod ready_chunks_all;
pub mod time;
pub mod ws_client;
pub mod ready_chunks_all;
pub mod http_utils;
mod assign_id;
mod dense_map;
+3 -3
View File
@@ -9,12 +9,12 @@
//! Code is adapted from the futures implementation
//! (see [ready_chunks.rs](https://docs.rs/futures-util/0.3.15/src/futures_util/stream/stream/ready_chunks.rs.html)).
use futures::stream::Fuse;
use futures::StreamExt;
use core::mem;
use core::pin::Pin;
use futures::stream::Fuse;
use futures::stream::{FusedStream, Stream};
use futures::task::{Context, Poll};
use futures::StreamExt;
use pin_project_lite::pin_project;
pin_project! {
@@ -37,7 +37,7 @@ where
pub fn new(stream: St) -> Self {
Self {
stream: stream.fuse(),
items: Vec::new()
items: Vec::new(),
}
}
}
+8 -7
View File
@@ -5,21 +5,22 @@ use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncReadCompatExt;
use super::{
sender::{ Sender, SentMessage, SentMessageInternal },
receiver::{ Receiver, RecvMessage }
receiver::{Receiver, RecvMessage},
sender::{Sender, SentMessage, SentMessageInternal},
};
/// The send side of a Soketto WebSocket connection
pub type RawSender = soketto::connection::Sender<tokio_util::compat::Compat<tokio::net::TcpStream>>;
/// The receive side of a Soketto WebSocket connection
pub type RawReceiver = soketto::connection::Receiver<tokio_util::compat::Compat<tokio::net::TcpStream>>;
pub type RawReceiver =
soketto::connection::Receiver<tokio_util::compat::Compat<tokio::net::TcpStream>>;
/// A websocket connection. From this, we can either expose the raw connection
/// or expose a cancel-safe interface to it.
pub struct Connection {
tx: soketto::connection::Sender<tokio_util::compat::Compat<tokio::net::TcpStream>>,
rx: soketto::connection::Receiver<tokio_util::compat::Compat<tokio::net::TcpStream>>
rx: soketto::connection::Receiver<tokio_util::compat::Compat<tokio::net::TcpStream>>,
}
impl Connection {
@@ -109,7 +110,7 @@ impl Connection {
);
break;
}
},
}
SentMessageInternal::Message(SentMessage::StaticText(s)) => {
if let Err(e) = ws_to_connection.send_text(s).await {
log::error!(
@@ -127,7 +128,7 @@ impl Connection {
);
break;
}
},
}
SentMessageInternal::Close => {
if let Err(e) = ws_to_connection.close().await {
log::error!("Error attempting to close connection: {}", e);
@@ -185,6 +186,6 @@ pub async fn connect(uri: &http::Uri) -> Result<Connection, ConnectError> {
Ok(Connection {
tx: ws_to_connection,
rx: ws_from_connection
rx: ws_from_connection,
})
}
+5 -5
View File
@@ -1,10 +1,10 @@
/// Functionality to establish a connection
mod connect;
/// The channel based send interface
mod sender;
/// The channel based receive interface
mod receiver;
/// The channel based send interface
mod sender;
pub use connect::{ connect, ConnectError, Connection, RawSender, RawReceiver };
pub use sender::{ Sender, SentMessage, SendError };
pub use receiver::{ Receiver, RecvMessage, RecvError };
pub use connect::{connect, ConnectError, Connection, RawReceiver, RawSender};
pub use receiver::{Receiver, RecvError, RecvMessage};
pub use sender::{SendError, Sender, SentMessage};
+2 -2
View File
@@ -3,7 +3,7 @@ use futures::{Stream, StreamExt};
/// Receive messages out of a connection
pub struct Receiver {
pub (super) inner: mpsc::UnboundedReceiver<Result<RecvMessage, RecvError>>,
pub(super) inner: mpsc::UnboundedReceiver<Result<RecvMessage, RecvError>>,
}
#[derive(thiserror::Error, Debug)]
@@ -40,4 +40,4 @@ impl RecvMessage {
RecvMessage::Text(s) => s.len(),
}
}
}
}
+8 -5
View File
@@ -22,7 +22,7 @@ pub enum SentMessage {
/// Messages sent into the channel interface can be anything publically visible, or a close message.
#[derive(Debug, Clone)]
pub (super) enum SentMessageInternal {
pub(super) enum SentMessageInternal {
Message(SentMessage),
Close,
}
@@ -30,7 +30,7 @@ pub (super) enum SentMessageInternal {
/// Send messages into the connection
#[derive(Clone)]
pub struct Sender {
pub (super) inner: mpsc::UnboundedSender<SentMessageInternal>,
pub(super) inner: mpsc::UnboundedSender<SentMessageInternal>,
}
impl Sender {
@@ -56,7 +56,7 @@ impl Sender {
#[derive(thiserror::Error, Debug, Clone)]
pub enum SendError {
#[error("Failed to send message: {0}")]
ChannelError(#[from] mpsc::SendError)
ChannelError(#[from] mpsc::SendError),
}
impl Sink<SentMessage> for Sender {
@@ -67,7 +67,10 @@ impl Sink<SentMessage> for Sender {
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready_unpin(cx).map_err(|e| e.into())
}
fn start_send(mut self: std::pin::Pin<&mut Self>, item: SentMessage) -> Result<(), Self::Error> {
fn start_send(
mut self: std::pin::Pin<&mut Self>,
item: SentMessage,
) -> Result<(), Self::Error> {
self.inner
.start_send_unpin(SentMessageInternal::Message(item))
.map_err(|e| e.into())
@@ -84,4 +87,4 @@ impl Sink<SentMessage> for Sender {
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_close_unpin(cx).map_err(|e| e.into())
}
}
}