rpc server with HTTP/WS on the same socket (#12663)

* jsonrpsee v0.16

add backwards compatibility

run old http server on http only

* cargo fmt

* update jsonrpsee 0.16.1

* less verbose cors log

* fix nit in log: WS -> HTTP

* revert needless changes in Cargo.lock

* remove unused features in tower

* fix nits; add client-core feature

* jsonrpsee v0.16.2
This commit is contained in:
Niklas Adolfsson
2022-12-12 11:32:55 +01:00
committed by GitHub
parent 84a383f035
commit 84303ca75d
33 changed files with 364 additions and 326 deletions
+83 -59
View File
@@ -21,17 +21,21 @@
#![warn(missing_docs)]
use jsonrpsee::{
http_server::{AccessControlBuilder, HttpServerBuilder, HttpServerHandle},
ws_server::{WsServerBuilder, WsServerHandle},
server::{
middleware::proxy_get_request::ProxyGetRequestLayer, AllowHosts, ServerBuilder,
ServerHandle,
},
RpcModule,
};
use std::{error::Error as StdError, net::SocketAddr};
pub use crate::middleware::{RpcMetrics, RpcMiddleware};
pub use crate::middleware::RpcMetrics;
use http::header::HeaderValue;
pub use jsonrpsee::core::{
id_providers::{RandomIntegerIdProvider, RandomStringIdProvider},
traits::IdProvider,
};
use tower_http::cors::{AllowOrigin, CorsLayer};
const MEGABYTE: usize = 1024 * 1024;
@@ -46,12 +50,11 @@ const WS_MAX_SUBS_PER_CONN: usize = 1024;
pub mod middleware;
/// Type alias for http server
pub type HttpServer = HttpServerHandle;
/// Type alias for ws server
pub type WsServer = WsServerHandle;
/// Type alias JSON-RPC server
pub type Server = ServerHandle;
/// WebSocket specific settings on the server.
/// Server config.
#[derive(Debug, Clone)]
pub struct WsConfig {
/// Maximum connections.
pub max_connections: Option<usize>,
@@ -67,8 +70,8 @@ impl WsConfig {
// Deconstructs the config to get the finalized inner values.
//
// `Payload size` or `max subs per connection` bigger than u32::MAX will be truncated.
fn deconstruct(self) -> (u32, u32, u64, u32) {
let max_conns = self.max_connections.unwrap_or(WS_MAX_CONNECTIONS) as u64;
fn deconstruct(self) -> (u32, u32, u32, u32) {
let max_conns = self.max_connections.unwrap_or(WS_MAX_CONNECTIONS) as u32;
let max_payload_in_mb = payload_size_or_default(self.max_payload_in_mb) as u32;
let max_payload_out_mb = payload_size_or_default(self.max_payload_out_mb) as u32;
let max_subs_per_conn = self.max_subs_per_conn.unwrap_or(WS_MAX_SUBS_PER_CONN) as u32;
@@ -86,31 +89,27 @@ pub async fn start_http<M: Send + Sync + 'static>(
metrics: Option<RpcMetrics>,
rpc_api: RpcModule<M>,
rt: tokio::runtime::Handle,
) -> Result<HttpServerHandle, Box<dyn StdError + Send + Sync>> {
let max_payload_in = payload_size_or_default(max_payload_in_mb);
let max_payload_out = payload_size_or_default(max_payload_out_mb);
) -> Result<ServerHandle, Box<dyn StdError + Send + Sync>> {
let max_payload_in = payload_size_or_default(max_payload_in_mb) as u32;
let max_payload_out = payload_size_or_default(max_payload_out_mb) as u32;
let host_filter = hosts_filter(cors.is_some(), &addrs);
let mut acl = AccessControlBuilder::new();
let middleware = tower::ServiceBuilder::new()
// Proxy `GET /health` requests to internal `system_health` method.
.layer(ProxyGetRequestLayer::new("/health", "system_health")?)
.layer(try_into_cors(cors)?);
if let Some(cors) = cors {
// Whitelist listening address.
// NOTE: set_allowed_hosts will whitelist both ports but only one will used.
acl = acl.set_allowed_hosts(format_allowed_hosts(&addrs[..]))?;
acl = acl.set_allowed_origins(cors)?;
};
let builder = HttpServerBuilder::new()
.max_request_body_size(max_payload_in as u32)
.max_response_body_size(max_payload_out as u32)
.set_access_control(acl.build())
.health_api("/health", "system_health")?
.custom_tokio_runtime(rt);
let builder = ServerBuilder::new()
.max_request_body_size(max_payload_in)
.max_response_body_size(max_payload_out)
.set_host_filtering(host_filter)
.set_middleware(middleware)
.custom_tokio_runtime(rt)
.http_only();
let rpc_api = build_rpc_api(rpc_api);
let (handle, addr) = if let Some(metrics) = metrics {
let middleware = RpcMiddleware::new(metrics, "http".into());
let builder = builder.set_middleware(middleware);
let server = builder.build(&addrs[..]).await?;
let server = builder.set_logger(metrics).build(&addrs[..]).await?;
let addr = server.local_addr();
(server.start(rpc_api)?, addr)
} else {
@@ -120,16 +119,16 @@ pub async fn start_http<M: Send + Sync + 'static>(
};
log::info!(
"Running JSON-RPC HTTP server: addr={}, allowed origins={:?}",
"Running JSON-RPC HTTP server: addr={}, allowed origins={}",
addr.map_or_else(|_| "unknown".to_string(), |a| a.to_string()),
cors
format_cors(cors)
);
Ok(handle)
}
/// Start WS server listening on given address.
pub async fn start_ws<M: Send + Sync + 'static>(
/// Start a JSON-RPC server listening on given address that supports both HTTP and WS.
pub async fn start<M: Send + Sync + 'static>(
addrs: [SocketAddr; 2],
cors: Option<&Vec<String>>,
ws_config: WsConfig,
@@ -137,27 +136,26 @@ pub async fn start_ws<M: Send + Sync + 'static>(
rpc_api: RpcModule<M>,
rt: tokio::runtime::Handle,
id_provider: Option<Box<dyn IdProvider>>,
) -> Result<WsServerHandle, Box<dyn StdError + Send + Sync>> {
) -> Result<ServerHandle, Box<dyn StdError + Send + Sync>> {
let (max_payload_in, max_payload_out, max_connections, max_subs_per_conn) =
ws_config.deconstruct();
let mut acl = AccessControlBuilder::new();
let host_filter = hosts_filter(cors.is_some(), &addrs);
if let Some(cors) = cors {
// Whitelist listening address.
// NOTE: set_allowed_hosts will whitelist both ports but only one will used.
acl = acl.set_allowed_hosts(format_allowed_hosts(&addrs[..]))?;
acl = acl.set_allowed_origins(cors)?;
};
let middleware = tower::ServiceBuilder::new()
// Proxy `GET /health` requests to internal `system_health` method.
.layer(ProxyGetRequestLayer::new("/health", "system_health")?)
.layer(try_into_cors(cors)?);
let mut builder = WsServerBuilder::new()
let mut builder = ServerBuilder::new()
.max_request_body_size(max_payload_in)
.max_response_body_size(max_payload_out)
.max_connections(max_connections)
.max_subscriptions_per_connection(max_subs_per_conn)
.ping_interval(std::time::Duration::from_secs(30))
.custom_tokio_runtime(rt)
.set_access_control(acl.build());
.set_host_filtering(host_filter)
.set_middleware(middleware)
.custom_tokio_runtime(rt);
if let Some(provider) = id_provider {
builder = builder.set_id_provider(provider);
@@ -167,9 +165,7 @@ pub async fn start_ws<M: Send + Sync + 'static>(
let rpc_api = build_rpc_api(rpc_api);
let (handle, addr) = if let Some(metrics) = metrics {
let middleware = RpcMiddleware::new(metrics, "ws".into());
let builder = builder.set_middleware(middleware);
let server = builder.build(&addrs[..]).await?;
let server = builder.set_logger(metrics).build(&addrs[..]).await?;
let addr = server.local_addr();
(server.start(rpc_api)?, addr)
} else {
@@ -179,23 +175,14 @@ pub async fn start_ws<M: Send + Sync + 'static>(
};
log::info!(
"Running JSON-RPC WS server: addr={}, allowed origins={:?}",
"Running JSON-RPC WS server: addr={}, allowed origins={}",
addr.map_or_else(|_| "unknown".to_string(), |a| a.to_string()),
cors
format_cors(cors)
);
Ok(handle)
}
fn format_allowed_hosts(addrs: &[SocketAddr]) -> Vec<String> {
let mut hosts = Vec::with_capacity(addrs.len() * 2);
for addr in addrs {
hosts.push(format!("localhost:{}", addr.port()));
hosts.push(format!("127.0.0.1:{}", addr.port()));
}
hosts
}
fn build_rpc_api<M: Send + Sync + 'static>(mut rpc_api: RpcModule<M>) -> RpcModule<M> {
let mut available_methods = rpc_api.method_names().collect::<Vec<_>>();
available_methods.sort();
@@ -214,3 +201,40 @@ fn build_rpc_api<M: Send + Sync + 'static>(mut rpc_api: RpcModule<M>) -> RpcModu
fn payload_size_or_default(size_mb: Option<usize>) -> usize {
size_mb.map_or(RPC_MAX_PAYLOAD_DEFAULT, |mb| mb.saturating_mul(MEGABYTE))
}
fn hosts_filter(enabled: bool, addrs: &[SocketAddr]) -> AllowHosts {
if enabled {
// NOTE The listening addresses are whitelisted by default.
let mut hosts = Vec::with_capacity(addrs.len() * 2);
for addr in addrs {
hosts.push(format!("localhost:{}", addr.port()).into());
hosts.push(format!("127.0.0.1:{}", addr.port()).into());
}
AllowHosts::Only(hosts)
} else {
AllowHosts::Any
}
}
fn try_into_cors(
maybe_cors: Option<&Vec<String>>,
) -> Result<CorsLayer, Box<dyn StdError + Send + Sync>> {
if let Some(cors) = maybe_cors {
let mut list = Vec::new();
for origin in cors {
list.push(HeaderValue::from_str(origin)?);
}
Ok(CorsLayer::new().allow_origin(AllowOrigin::list(list)))
} else {
// allow all cors
Ok(CorsLayer::permissive())
}
}
fn format_cors(maybe_cors: Option<&Vec<String>>) -> String {
if let Some(cors) = maybe_cors {
format!("{:?}", cors)
} else {
format!("{:?}", ["*"])
}
}
+49 -85
View File
@@ -16,9 +16,9 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//! RPC middlware to collect prometheus metrics on RPC calls.
//! RPC middleware to collect prometheus metrics on RPC calls.
use jsonrpsee::core::middleware::{Headers, HttpMiddleware, MethodKind, Params, WsMiddleware};
use jsonrpsee::server::logger::{HttpRequest, Logger, MethodKind, Params, TransportProtocol};
use prometheus_endpoint::{
register, Counter, CounterVec, HistogramOpts, HistogramVec, Opts, PrometheusError, Registry,
U64,
@@ -54,9 +54,9 @@ pub struct RpcMetrics {
calls_started: CounterVec<U64>,
/// Number of calls completed.
calls_finished: CounterVec<U64>,
/// Number of Websocket sessions opened (Websocket only).
/// Number of Websocket sessions opened.
ws_sessions_opened: Option<Counter<U64>>,
/// Number of Websocket sessions closed (Websocket only).
/// Number of Websocket sessions closed.
ws_sessions_closed: Option<Counter<U64>>,
}
@@ -139,62 +139,61 @@ impl RpcMetrics {
}
}
#[derive(Clone)]
/// Middleware for RPC calls
pub struct RpcMiddleware {
metrics: RpcMetrics,
transport_label: &'static str,
}
impl Logger for RpcMetrics {
type Instant = std::time::Instant;
impl RpcMiddleware {
/// Create a new [`RpcMiddleware`] with the provided [`RpcMetrics`].
pub fn new(metrics: RpcMetrics, transport_label: &'static str) -> Self {
Self { metrics, transport_label }
fn on_connect(
&self,
_remote_addr: SocketAddr,
_request: &HttpRequest,
transport: TransportProtocol,
) {
if let TransportProtocol::WebSocket = transport {
self.ws_sessions_opened.as_ref().map(|counter| counter.inc());
}
}
/// Called when a new JSON-RPC request comes to the server.
fn on_request(&self) -> std::time::Instant {
fn on_request(&self, transport: TransportProtocol) -> Self::Instant {
let transport_label = transport_label_str(transport);
let now = std::time::Instant::now();
self.metrics.requests_started.with_label_values(&[self.transport_label]).inc();
self.requests_started.with_label_values(&[transport_label]).inc();
now
}
/// Called on each JSON-RPC method call, batch requests will trigger `on_call` multiple times.
fn on_call(&self, name: &str, params: Params, kind: MethodKind) {
fn on_call(&self, name: &str, params: Params, kind: MethodKind, transport: TransportProtocol) {
let transport_label = transport_label_str(transport);
log::trace!(
target: "rpc_metrics",
"[{}] on_call name={} params={:?} kind={}",
self.transport_label,
transport_label,
name,
params,
kind,
);
self.metrics
.calls_started
.with_label_values(&[self.transport_label, name])
.inc();
self.calls_started.with_label_values(&[transport_label, name]).inc();
}
/// Called on each JSON-RPC method completion, batch requests will trigger `on_result` multiple
/// times.
fn on_result(&self, name: &str, success: bool, started_at: std::time::Instant) {
fn on_result(
&self,
name: &str,
success: bool,
started_at: Self::Instant,
transport: TransportProtocol,
) {
let transport_label = transport_label_str(transport);
let micros = started_at.elapsed().as_micros();
log::debug!(
target: "rpc_metrics",
"[{}] {} call took {} μs",
self.transport_label,
transport_label,
name,
micros,
);
self.metrics
.calls_time
.with_label_values(&[self.transport_label, name])
.observe(micros as _);
self.calls_time.with_label_values(&[transport_label, name]).observe(micros as _);
self.metrics
.calls_finished
self.calls_finished
.with_label_values(&[
self.transport_label,
transport_label,
name,
// the label "is_error", so `success` should be regarded as false
// and vice-versa to be registrered correctly.
@@ -203,58 +202,23 @@ impl RpcMiddleware {
.inc();
}
/// Called once the JSON-RPC request is finished and response is sent to the output buffer.
fn on_response(&self, result: &str, started_at: std::time::Instant) {
log::trace!(target: "rpc_metrics", "[{}] on_response started_at={:?}", self.transport_label, started_at);
log::trace!(target: "rpc_metrics::extra", "[{}] result={:?}", self.transport_label, result);
self.metrics.requests_finished.with_label_values(&[self.transport_label]).inc();
fn on_response(&self, result: &str, started_at: Self::Instant, transport: TransportProtocol) {
let transport_label = transport_label_str(transport);
log::trace!(target: "rpc_metrics", "[{}] on_response started_at={:?}", transport_label, started_at);
log::trace!(target: "rpc_metrics::extra", "[{}] result={:?}", transport_label, result);
self.requests_finished.with_label_values(&[transport_label]).inc();
}
fn on_disconnect(&self, _remote_addr: SocketAddr, transport: TransportProtocol) {
if let TransportProtocol::WebSocket = transport {
self.ws_sessions_closed.as_ref().map(|counter| counter.inc());
}
}
}
impl WsMiddleware for RpcMiddleware {
type Instant = std::time::Instant;
fn on_connect(&self, _remote_addr: SocketAddr, _headers: &Headers) {
self.metrics.ws_sessions_opened.as_ref().map(|counter| counter.inc());
}
fn on_request(&self) -> Self::Instant {
self.on_request()
}
fn on_call(&self, name: &str, params: Params, kind: MethodKind) {
self.on_call(name, params, kind)
}
fn on_result(&self, name: &str, success: bool, started_at: Self::Instant) {
self.on_result(name, success, started_at)
}
fn on_response(&self, _result: &str, started_at: Self::Instant) {
self.on_response(_result, started_at)
}
fn on_disconnect(&self, _remote_addr: SocketAddr) {
self.metrics.ws_sessions_closed.as_ref().map(|counter| counter.inc());
}
}
impl HttpMiddleware for RpcMiddleware {
type Instant = std::time::Instant;
fn on_request(&self, _remote_addr: SocketAddr, _headers: &Headers) -> Self::Instant {
self.on_request()
}
fn on_call(&self, name: &str, params: Params, kind: MethodKind) {
self.on_call(name, params, kind)
}
fn on_result(&self, name: &str, success: bool, started_at: Self::Instant) {
self.on_result(name, success, started_at)
}
fn on_response(&self, _result: &str, started_at: Self::Instant) {
self.on_response(_result, started_at)
fn transport_label_str(t: TransportProtocol) -> &'static str {
match t {
TransportProtocol::Http => "http",
TransportProtocol::WebSocket => "ws",
}
}