Adapt shard to use hyper+soketto instead

This commit is contained in:
James Wilson
2021-07-23 17:07:29 +01:00
parent 480abd1e49
commit cf0e424803
8 changed files with 86 additions and 361 deletions
+1 -1
View File
@@ -49,7 +49,7 @@ pub enum FromWebsocket {
/// Tell the aggregator about a new node.
Add {
message_id: node_message::NodeMessageId,
ip: Option<std::net::IpAddr>,
ip: std::net::IpAddr,
node: common::node_types::NodeDetails,
genesis_hash: BlockHash,
},
+54 -59
View File
@@ -11,11 +11,10 @@ use aggregator::{Aggregator, FromWebsocket};
use common::node_message;
use futures::{channel::mpsc, SinkExt, StreamExt};
use http::Uri;
use real_ip::real_ip;
use simple_logger::SimpleLogger;
use structopt::StructOpt;
use warp::filters::ws;
use warp::Filter;
use hyper::{ Response, Method };
use common::http_utils;
const VERSION: &str = env!("CARGO_PKG_VERSION");
const AUTHORS: &str = env!("CARGO_PKG_AUTHORS");
@@ -63,44 +62,51 @@ async fn main() {
/// Declare our routes and start the server.
async fn start_server(opts: Opts) -> anyhow::Result<()> {
let aggregator = Aggregator::spawn(opts.core_url).await?;
let server = http_utils::start_server(opts.socket, move |addr, req| {
let aggregator = aggregator.clone();
async move {
match (req.method(), req.uri().path().trim_end_matches('/')) {
// Check that the server is up and running:
(&Method::GET, "/health") => {
Ok(Response::new("OK".into()))
},
// Nodes send messages here:
(&Method::GET, "/submit") => {
let real_addr = real_ip::real_ip(addr, req.headers());
Ok(http_utils::upgrade_to_websocket(req, move |ws_send, ws_recv| async move {
let tx_to_aggregator = aggregator.subscribe_node();
let (mut tx_to_aggregator, mut ws_send)
= handle_node_websocket_connection(real_addr, ws_send, ws_recv, tx_to_aggregator).await;
log::info!("Closing /submit connection from {:?}", addr);
// Tell the aggregator that this connection has closed, so it can tidy up.
let _ = tx_to_aggregator.send(FromWebsocket::Disconnected).await;
let _ = ws_send.close().await;
}))
},
// 404 for anything else:
_ => {
Ok(Response::builder()
.status(404)
.body("Not found".into())
.unwrap())
}
}
}
});
// Handle requests to /health by returning OK.
let health_route = warp::path("health").map(|| "OK");
// Handle websocket requests to /submit.
let ws_route = warp::path("submit").and(warp::ws()).and(real_ip()).map(
move |ws: ws::Ws, addr: Option<IpAddr>| {
// Send messages from the websocket connection to this sink
// to have them pass to the aggregator.
let tx_to_aggregator = aggregator.subscribe_node();
log::info!("Opening /submit connection from {:?}", addr);
ws.on_upgrade(move |websocket| async move {
let (mut tx_to_aggregator, websocket) =
handle_node_websocket_connection(websocket, tx_to_aggregator, addr).await;
log::info!("Closing /submit connection from {:?}", addr);
// Tell the aggregator that this connection has closed, so it can tidy up.
let _ = tx_to_aggregator.send(FromWebsocket::Disconnected).await;
// Note: IF we want to close with a status code and reason, we need to construct
// a ws::Message using `ws::Message::close_with`, rather than using this method:
let _ = websocket.close().await;
})
},
);
// Merge the routes and start our server:
let routes = ws_route.or(health_route);
warp::serve(routes).run(opts.socket).await;
server.await?;
Ok(())
}
/// This takes care of handling messages from an established socket connection.
async fn handle_node_websocket_connection<S>(
mut websocket: ws::WebSocket,
real_addr: IpAddr,
ws_send: http_utils::WsSender,
mut ws_recv: http_utils::WsReceiver,
mut tx_to_aggregator: S,
addr: Option<IpAddr>,
) -> (S, ws::WebSocket)
) -> (S, http_utils::WsSender)
where
S: futures::Sink<FromWebsocket, Error = anyhow::Error> + Unpin,
S: futures::Sink<FromWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
{
// This could be a oneshot channel, but it's useful to be able to clone
// messages, and we can't clone oneshot channel senders.
@@ -112,7 +118,7 @@ where
};
if let Err(e) = tx_to_aggregator.send(init_msg).await {
log::error!("Error sending message to aggregator: {}", e);
return (tx_to_aggregator, websocket);
return (tx_to_aggregator, ws_send);
}
// Now we've "initialized", wait for messages from the node. Messages will
@@ -120,38 +126,27 @@ where
// of messages with some message ID will be sent (a node could have more
// than one of these), or updates linked to a specific message_id.
loop {
let mut bytes = Vec::new();
tokio::select! {
// The close channel has fired, so end the loop:
// The close channel has fired, so end the loop. `ws_recv.receive_data` is
// *not* cancel safe, but since we're closing the connection we don't care.
_ = close_connection_rx.next() => {
log::info!("connection to {:?} being closed by aggregator", addr);
log::info!("connection to {:?} being closed by aggregator", real_addr);
break
},
// A message was received; handle it:
msg = websocket.next() => {
let msg = match msg {
Some(msg) => msg,
None => { log::warn!("Websocket connection from {:?} closed", addr); break }
};
// If we see any errors, log them and end our loop:
let msg = match msg {
Err(e) => { log::error!("Error in node websocket connection: {}", e); break },
Ok(msg) => msg,
};
// Close message? Break to close connection.
if msg.is_close() {
msg_info = ws_recv.receive_data(&mut bytes) => {
// Handle the socket closing, or errors receiving the message.
if let Err(soketto::connection::Error::Closed) = msg_info {
break;
}
if let Err(e) = msg_info {
log::error!("Shutting down websocket connection: Failed to receive data: {}", e);
break;
}
// If the message isn't something we want to handle, just ignore it.
// This includes system messages like "pings" and such, so don't log anything.
if !msg.is_binary() && !msg.is_text() {
continue;
}
// Deserialize from JSON, warning in debug mode if deserialization fails:
let bytes = msg.as_bytes();
let node_message: json_message::NodeMessage = match serde_json::from_slice(bytes) {
let node_message: json_message::NodeMessage = match serde_json::from_slice(&bytes) {
Ok(node_message) => node_message,
#[cfg(debug)]
Err(e) => {
@@ -177,7 +172,7 @@ where
if let node_message::Payload::SystemConnected(info) = payload {
let _ = tx_to_aggregator.send(FromWebsocket::Add {
message_id,
ip: addr,
ip: real_addr,
node: info.node,
genesis_hash: info.genesis_hash,
}).await;
@@ -193,5 +188,5 @@ where
}
// Return what we need to close the connection gracefully:
(tx_to_aggregator, websocket)
(tx_to_aggregator, ws_send)
}
+16 -19
View File
@@ -1,10 +1,7 @@
use std::net::{IpAddr, SocketAddr};
use warp::filters::addr;
use warp::filters::header;
use warp::Filter;
/**
A warp filter to extract the "real" IP address of the connection by looking at headers
Extract the "real" IP address of the connection by looking at headers
set by proxies (this is inspired by Actix Web's implementation of the feature).
First, check for the standardised "Forwarded" header. This looks something like:
@@ -21,28 +18,28 @@ appending one to the end. So, take the first of these if it exists.
If still no luck, look for the X-Real-IP header, which we expect to contain a single IP address.
If that _still_ doesn't work, fall back to the socket address of the connection.
Return `None` if all of this fails to yield an address.
*/
pub fn real_ip() -> impl warp::Filter<Extract = (Option<IpAddr>,), Error = warp::Rejection> + Clone
{
header::optional("forwarded")
.and(header::optional("x-forwarded-for"))
.and(header::optional("x-real-ip"))
.and(addr::remote())
.map(pick_best_ip_from_options)
pub fn real_ip(addr: SocketAddr, headers: &hyper::HeaderMap) -> IpAddr {
let forwarded = headers.get("forwarded").and_then(header_as_str);
let forwarded_for = headers.get("x-forwarded-for").and_then(header_as_str);
let real_ip = headers.get("x-real-ip").and_then(header_as_str);
pick_best_ip_from_options(forwarded, forwarded_for, real_ip, addr)
}
fn header_as_str(value: &hyper::header::HeaderValue) -> Option<&str> {
std::str::from_utf8(value.as_bytes()).ok()
}
fn pick_best_ip_from_options(
// Forwarded header value (if present)
forwarded: Option<String>,
forwarded: Option<&str>,
// X-Forwarded-For header value (if present)
forwarded_for: Option<String>,
forwarded_for: Option<&str>,
// X-Real-IP header value (if present)
real_ip: Option<String>,
real_ip: Option<&str>,
// socket address (if known)
addr: Option<SocketAddr>,
) -> Option<IpAddr> {
addr: SocketAddr,
) -> IpAddr {
let realip = forwarded
.as_ref()
.and_then(|val| get_first_addr_from_forwarded_header(val))
@@ -65,7 +62,7 @@ fn pick_best_ip_from_options(
.ok()
})
// Fall back to local IP address if the above fails
.or(addr.map(|a| a.ip()));
.unwrap_or(addr.ip());
realip
}