// Copyright (C) Parity Technologies (UK) Ltd.
// This file is part of Polkadot.
// Polkadot is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Polkadot is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Polkadot. If not, see <http://www.gnu.org/licenses/>.
//! A small set of wrapping types to cover most of our adversary test cases.
//!
//! This allows types with internal mutability to synchronize across
//! multiple subsystems and intercept or replace incoming and outgoing
//! messages on the overseer level.
use polkadot_node_subsystem::*;
pub use polkadot_node_subsystem::{messages, messages::*, overseer, FromOrchestra};
use std::{future::Future, pin::Pin};
/// Filter incoming and outgoing messages.
///
/// All three hooks have pass-through defaults, so an implementor only needs to
/// set the `Message` type and override the hooks it wants to change.
// NOTE(review): the generic parameter/argument lists in this file look like they
// were stripped by an extraction step (e.g. `Option> {` below); restore them from
// the upstream source before compiling.
pub trait MessageInterceptor: Send + Sync + Clone + 'static
where
Sender: overseer::SubsystemSender<::OutgoingMessages>
+ Clone
+ 'static,
{
/// The message type the original subsystem handles incoming.
type Message: overseer::AssociateOutgoing + Send + 'static;
/// Filter messages that are to be received by
/// the subsystem.
///
/// For non-trivial cases, the `sender` can be used to send
/// multiple messages after doing some additional processing.
///
/// The default implementation ignores `_sender` and returns `Some(msg)`,
/// i.e. the message is passed on unmodified.
// NOTE(review): presumably returning `None` drops the message before the wrapped
// subsystem sees it — the receiving side is not visible here, confirm there.
fn intercept_incoming(
&self,
_sender: &mut Sender,
msg: FromOrchestra,
) -> Option> {
Some(msg)
}
/// Specifies if we need to replace some outgoing message with another (potentially empty)
/// message
///
/// `InterceptedSender` calls this for every outgoing message: when it returns
/// `false` (the default) the message is forwarded as is, when it returns `true`
/// the result of `intercept_outgoing` is used instead.
fn need_intercept_outgoing(
&self,
_msg: &::OutgoingMessages,
) -> bool {
false
}
/// Send modified message instead of the original one
///
/// Only consulted by `InterceptedSender` after `need_intercept_outgoing`
/// returned `true` for the same message. Returning `Some(replacement)` sends
/// the replacement; returning `None` (the default) sends nothing at all.
fn intercept_outgoing(
&self,
_msg: &::OutgoingMessages,
) -> Option<::OutgoingMessages> {
None
}
}
/// A sender with the outgoing messages filtered.
///
/// Wraps another sender and passes every outgoing message through
/// `message_filter` before handing it to `inner`.
#[derive(Clone)]
pub struct InterceptedSender {
// The wrapped sender that actually delivers the messages.
inner: Sender,
// The interceptor deciding whether a message is forwarded, replaced or dropped.
message_filter: Fil,
}
// Sender implementation that routes every outgoing message through the filter.
//
// Every method follows the same scheme: convert the message into the filter's
// `OutgoingMessages` type, ask `need_intercept_outgoing`, and then either forward
// the message, forward the replacement from `intercept_outgoing`, or send nothing
// when the filter returns `None`.
//
// NOTE(review): the generic parameter/argument lists of this impl look stripped
// (e.g. `TryFrom,` and `From + Send + Sync` below); restore them from upstream
// before compiling.
#[async_trait::async_trait]
impl overseer::SubsystemSender for InterceptedSender
where
OutgoingMessage: overseer::AssociateOutgoing + Send + 'static + TryFrom,
Sender: overseer::SubsystemSender
+ overseer::SubsystemSender<
<
>::Message as overseer::AssociateOutgoing
>::OutgoingMessages
>,
Fil: MessageInterceptor,
>::Message: overseer::AssociateOutgoing,
<
>::Message as overseer::AssociateOutgoing
>::OutgoingMessages:
From + Send + Sync,
>::Error: std::fmt::Debug,
{
/// Send `msg` through the filter, awaiting the inner sender.
///
/// If the filter intercepts the message and yields no replacement, nothing is sent.
async fn send_message(&mut self, msg: OutgoingMessage) {
// Convert into the type the filter operates on.
let msg = <
<>::Message as overseer::AssociateOutgoing
>::OutgoingMessages as From>::from(msg);
if self.message_filter.need_intercept_outgoing(&msg) {
// Intercepted: only a replacement (if any) reaches the inner sender.
if let Some(msg) = self.message_filter.intercept_outgoing(&msg) {
self.inner.send_message(msg).await;
}
}
else {
self.inner.send_message(msg).await;
}
}
/// Non-blocking variant of `send_message`.
///
/// Returns `Ok(())` when the filter drops the message. When sending a
/// replacement fails, the returned error carries the caller's original
/// message, not the replacement.
///
/// # Panics
///
/// Panics if the converted message cannot be turned back into
/// `OutgoingMessage` ("must be able to recover the original message").
fn try_send_message(&mut self, msg: OutgoingMessage) -> Result<(), TrySendError> {
// Convert into the type the filter operates on.
let msg = <
<>::Message as overseer::AssociateOutgoing
>::OutgoingMessages as From>::from(msg);
if self.message_filter.need_intercept_outgoing(&msg) {
if let Some(real_msg) = self.message_filter.intercept_outgoing(&msg) {
// Recover the original message up front so it can be handed back on failure.
let orig_msg : OutgoingMessage = msg.into().try_into().expect("must be able to recover the original message");
self.inner.try_send_message(real_msg).map_err(|e| {
// Keep the error kind, but swap the payload for the original message.
match e {
TrySendError::Full(_) => TrySendError::Full(orig_msg),
TrySendError::Closed(_) => TrySendError::Closed(orig_msg),
}
})
}
else {
// No message to send after intercepting
Ok(())
}
}
else {
// Not intercepted: convert back and send the original message unchanged.
let orig_msg : OutgoingMessage = msg.into().try_into().expect("must be able to recover the original message");
self.inner.try_send_message(orig_msg)
}
}
/// Send several messages, one after another.
///
/// Each message goes through `send_message`, so each one is filtered individually.
async fn send_messages(&mut self, msgs: T)
where
T: IntoIterator + Send,
T::IntoIter: Send,
{
for msg in msgs {
self.send_message(msg).await;
}
}
/// Send `msg` through the filter using the inner sender's unbounded send.
///
/// If the filter intercepts the message and yields no replacement, nothing is sent.
fn send_unbounded_message(&mut self, msg: OutgoingMessage) {
// Convert into the type the filter operates on.
let msg = <
<>::Message as overseer::AssociateOutgoing
>::OutgoingMessages as From>::from(msg);
if self.message_filter.need_intercept_outgoing(&msg) {
// Intercepted: only a replacement (if any) reaches the inner sender.
if let Some(msg) = self.message_filter.intercept_outgoing(&msg) {
self.inner.send_unbounded_message(msg);
}
}
else {
self.inner.send_unbounded_message(msg);
}
}
}
/// A subsystem context, that filters the outgoing messages.
///
/// Holds the wrapped context together with the filter and an
/// `InterceptedSender` built from the wrapped context's sender.
// NOTE(review): the generic parameter/argument lists look stripped here
// (e.g. `MessageInterceptor<::Sender>`); restore them from upstream before compiling.
pub struct InterceptedContext
where
Context: overseer::SubsystemContext,
Fil: MessageInterceptor<::Sender>,
::Sender:
overseer::SubsystemSender<
<
<
Fil as MessageInterceptor<::Sender>
>::Message as overseer::AssociateOutgoing
>::OutgoingMessages,
>,
{
// The wrapped subsystem context.
inner: Context,
// The interceptor; a clone of it also lives inside `sender`.
message_filter: Fil,
// Filtering sender wrapping a clone of the inner context's sender.
sender: InterceptedSender<::Sender, Fil>,
}
// Constructor for `InterceptedContext`; requires the filter's `Message` type to
// match the wrapped context's `Message` type.
impl InterceptedContext
where
Context: overseer::SubsystemContext,
Fil: MessageInterceptor<
::Sender,
Message = ::Message,
>,
::Message: overseer::AssociateOutgoing,
::Sender: overseer::SubsystemSender<
<::Message as overseer::AssociateOutgoing>::OutgoingMessages
>
{
/// Wrap `inner` so that its messages pass through `message_filter`.
///
/// The filter is cloned: one copy is stored in the context itself, the other
/// in the `InterceptedSender` that wraps a clone of `inner`'s sender.
pub fn new(mut inner: Context, message_filter: Fil) -> Self {
let sender = InterceptedSender::<::Sender, Fil> {
inner: inner.sender().clone(),
message_filter: message_filter.clone(),
};
Self { inner, message_filter, sender }
}
}
#[async_trait::async_trait]
impl overseer::SubsystemContext for InterceptedContext
where
Context: overseer::SubsystemContext,
::Message:
overseer::AssociateOutgoing,
::Sender:
overseer::SubsystemSender<
<::Message as overseer::AssociateOutgoing>::OutgoingMessages
>,
InterceptedSender<::Sender, Fil>:
overseer::SubsystemSender<
<::Message as overseer::AssociateOutgoing>::OutgoingMessages
>,
Fil: MessageInterceptor<
::Sender,
Message = ::Message,
>,
{
type Message = ::Message;
type Sender = InterceptedSender<::Sender, Fil>;
type Error = SubsystemError;
type OutgoingMessages = <::Message as overseer::AssociateOutgoing>::OutgoingMessages;
type Signal = OverseerSignal;
async fn try_recv(&mut self) -> Result