// Copyright 2017-2021 Parity Technologies (UK) Ltd.
// This file is part of Polkadot.
// Polkadot is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Polkadot is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Polkadot. If not, see <http://www.gnu.org/licenses/>.
//! A small set of wrapping types to cover most of our adversary test cases.
//!
//! This allows types with internal mutability to synchronize across
//! multiple subsystems and intercept or replace incoming and outgoing
//! messages on the overseer level.
use polkadot_node_subsystem::*;
pub use polkadot_node_subsystem::{messages::AllMessages, FromOverseer};
use std::future::Future;
use std::pin::Pin;
/// Filter incoming and outgoing messages.
///
/// Both hooks default to passing the message through unchanged, so an
/// implementor only overrides the direction it wants to tamper with.
/// `Clone` is required because the filter is cloned when a filtering
/// sender is built for a wrapped context.
// NOTE(review): generic arguments in this trait appear to have been stripped
// (e.g. `FromOverseer` has no argument and `filter_in` returns `Option>`);
// presumably `filter_in` takes/returns `FromOverseer<Self::Message>` and
// `filter_out` returns `Option<AllMessages>` — confirm and restore before building.
pub trait MsgFilter: Send + Sync + Clone + 'static {
/// The message type the original subsystem handles incoming.
type Message: Send + 'static;
/// Filter messages that are to be received by
/// the subsystem.
///
/// The default implementation returns `Some(msg)`, i.e. forwards every
/// message untouched.
// NOTE(review): presumably returning `None` suppresses the message on the
// receive path of `FilteredContext` — confirm against that impl.
fn filter_in(&self, msg: FromOverseer) -> Option> {
Some(msg)
}
/// Modify outgoing messages.
///
/// Return `Some(msg)` (possibly a replacement message) to forward it, or
/// `None` to drop it: `FilteredSender` only forwards messages for which this
/// returns `Some`. The default implementation forwards `msg` unchanged.
fn filter_out(&self, msg: AllMessages) -> Option {
Some(msg)
}
}
/// A sender with the outgoing messages filtered.
///
/// Every message sent through its `SubsystemSender` impl is first passed to
/// `message_filter`'s `filter_out`; only the surviving messages reach `inner`.
// NOTE(review): the type-parameter list (the fields use `Sender` and `Fil`)
// appears to be missing from this declaration — restore before building.
#[derive(Clone)]
pub struct FilteredSender {
// The wrapped sender that actually delivers the (filtered) messages.
inner: Sender,
// Decides which outgoing messages are forwarded, and may replace them.
message_filter: Fil,
}
// NOTE(review): the generic parameter lists of this impl and of `send_messages`
// appear to be missing in this copy (no type parameters are declared for
// `FilteredSender`, and `T` is never introduced) — restore before building.
/// Forwards outgoing messages to the wrapped sender after running each one
/// through [`MsgFilter::filter_out`]; a `None` from the filter means the
/// message is dropped without reaching the wrapped sender.
#[async_trait::async_trait]
impl SubsystemSender for FilteredSender
where
    Sender: SubsystemSender,
    Fil: MsgFilter,
{
    /// Filter `msg` and, if it survives, await its delivery on the wrapped sender.
    async fn send_message(&mut self, msg: AllMessages) {
        let filtered = self.message_filter.filter_out(msg);
        if let Some(outgoing) = filtered {
            self.inner.send_message(outgoing).await;
        }
    }

    /// Filter every message of `msgs`, sending the survivors one at a time,
    /// in iteration order, each awaited before the next is filtered.
    async fn send_messages(&mut self, msgs: T)
    where
        T: IntoIterator + Send,
        T::IntoIter: Send,
    {
        for candidate in msgs {
            let filtered = self.message_filter.filter_out(candidate);
            if let Some(outgoing) = filtered {
                self.inner.send_message(outgoing).await;
            }
        }
    }

    /// Unbounded counterpart of `send_message`: filter `msg`, then hand any
    /// survivor to the wrapped sender's `send_unbounded_message` (not awaited).
    fn send_unbounded_message(&mut self, msg: AllMessages) {
        let filtered = self.message_filter.filter_out(msg);
        if let Some(outgoing) = filtered {
            self.inner.send_unbounded_message(outgoing);
        }
    }
}
/// A subsystem context that filters the outgoing messages.
///
/// Wraps an inner context and holds a `FilteredSender` built around a clone of
/// that context's sender, so outgoing messages pass through `message_filter`.
// NOTE(review): the type-parameter list (fields use `Context` and `Fil`) and the
// qualified path in the `sender` field type (`<::Sender`) appear to have been
// stripped — presumably `FilteredSender<<Context as SubsystemContext>::Sender, Fil>`;
// confirm and restore before building.
pub struct FilteredContext {
// The wrapped subsystem context.
inner: Context,
// The message filter; a clone of it also lives inside `sender`.
message_filter: Fil,
// Filtering wrapper around a clone of `inner`'s sender.
sender: FilteredSender<::Sender, Fil>,
}
// NOTE(review): this impl header appears to have lost its generic parameters
// (`FilteredContext` is named without arguments and the `Fil` bound reads
// `MsgFilter::Message>`) — restore before building.
impl FilteredContext
where
    Context: SubsystemContext,
    Fil: MsgFilter::Message>,
{
    /// Wrap `inner` so that the messages it sends pass through `message_filter`.
    ///
    /// The resulting context keeps `inner` and `message_filter`, plus a
    /// [`FilteredSender`] made from a clone of `inner`'s sender and a clone of
    /// `message_filter`.
    pub fn new(mut inner: Context, message_filter: Fil) -> Self {
        Self {
            // Listed first on purpose: building the sender calls `inner.sender()`
            // and clones the filter, which must happen before both are moved
            // into the fields below. The sender's type parameters are inferred
            // from the `sender` field's declared type.
            sender: FilteredSender {
                inner: inner.sender().clone(),
                message_filter: message_filter.clone(),
            },
            inner,
            message_filter,
        }
    }
}
#[async_trait::async_trait]
impl SubsystemContext for FilteredContext
where
Context: SubsystemContext,
Fil: MsgFilter::Message>,
{
type Message = ::Message;
type Sender = FilteredSender<::Sender, Fil>;
async fn try_recv(&mut self) -> Result