[FRAME] Make MQ pallet re-entrancy safe (#2356)

Closes https://github.com/paritytech/polkadot-sdk/issues/2319

Changes:
- Ensure that only `enqueue_message(s)` is callable from within the
message processor. This prevents messed up storage that can currently
happen when the pallet is called into recursively.
- Use `H256` instead of `[u8; 32]` for clearer API.

## Details

The re-entracy check is done with the `environmental` crate by adding a
`with_service_mutex(f)` function that runs the closure exclusively. This
works since the MQ pallet is not instantiable.

---------

Signed-off-by: Oliver Tale-Yazdi <oliver.tale-yazdi@parity.io>
Co-authored-by: Francisco Aguirre <franciscoaguirreperez@gmail.com>
This commit is contained in:
Oliver Tale-Yazdi
2023-12-07 17:48:13 +01:00
committed by GitHub
parent 95c3ee10bc
commit 7e7fe99047
11 changed files with 478 additions and 79 deletions
Generated
+1
View File
@@ -10035,6 +10035,7 @@ dependencies = [
name = "pallet-message-queue" name = "pallet-message-queue"
version = "7.0.0-dev" version = "7.0.0-dev"
dependencies = [ dependencies = [
"environmental",
"frame-benchmarking", "frame-benchmarking",
"frame-support", "frame-support",
"frame-system", "frame-system",
+2 -9
View File
@@ -22,8 +22,7 @@ use XcmpMessageFormat::*;
use codec::Input; use codec::Input;
use cumulus_primitives_core::{ParaId, XcmpMessageHandler}; use cumulus_primitives_core::{ParaId, XcmpMessageHandler};
use frame_support::{ use frame_support::{
assert_err, assert_noop, assert_ok, assert_storage_noop, hypothetically, assert_err, assert_noop, assert_ok, assert_storage_noop, hypothetically, traits::Hooks,
traits::{Footprint, Hooks},
StorageNoopGuard, StorageNoopGuard,
}; };
use mock::{new_test_ext, ParachainSystem, RuntimeOrigin as Origin, Test, XcmpQueue}; use mock::{new_test_ext, ParachainSystem, RuntimeOrigin as Origin, Test, XcmpQueue};
@@ -100,7 +99,7 @@ fn xcm_enqueueing_multiple_times_works() {
} }
#[test] #[test]
#[cfg_attr(debug_assertions, should_panic = "Defensive failure")] #[cfg_attr(debug_assertions, should_panic = "Could not enqueue XCMP messages.")]
fn xcm_enqueueing_starts_dropping_on_overflow() { fn xcm_enqueueing_starts_dropping_on_overflow() {
new_test_ext().execute_with(|| { new_test_ext().execute_with(|| {
let xcm = VersionedXcm::<Test>::from(Xcm::<Test>(vec![ClearOrigin])); let xcm = VersionedXcm::<Test>::from(Xcm::<Test>(vec![ClearOrigin]));
@@ -112,12 +111,6 @@ fn xcm_enqueueing_starts_dropping_on_overflow() {
repeat((1000.into(), 1, data.as_slice())).take(limit * 2), repeat((1000.into(), 1, data.as_slice())).take(limit * 2),
Weight::MAX, Weight::MAX,
); );
assert_eq!(EnqueuedMessages::get().len(), limit);
// The drop threshold for pages is 48, the others numbers dont really matter:
assert_eq!(
<Test as Config>::XcmpQueue::footprint(1000.into()),
QueueFootprint { storage: Footprint { count: 256, size: 768 }, pages: 48 }
);
}) })
} }
+4 -4
View File
@@ -523,21 +523,21 @@ fn overweight_queue_works() {
assert_last_events( assert_last_events(
[ [
pallet_message_queue::Event::<Test>::Processed { pallet_message_queue::Event::<Test>::Processed {
id: hash_1, id: hash_1.into(),
origin: Ump(UmpQueueId::Para(para_a)), origin: Ump(UmpQueueId::Para(para_a)),
weight_used: Weight::from_parts(301, 301), weight_used: Weight::from_parts(301, 301),
success: true, success: true,
} }
.into(), .into(),
pallet_message_queue::Event::<Test>::OverweightEnqueued { pallet_message_queue::Event::<Test>::OverweightEnqueued {
id: hash_2, id: hash_2.into(),
origin: Ump(UmpQueueId::Para(para_a)), origin: Ump(UmpQueueId::Para(para_a)),
page_index: 0, page_index: 0,
message_index: 1, message_index: 1,
} }
.into(), .into(),
pallet_message_queue::Event::<Test>::OverweightEnqueued { pallet_message_queue::Event::<Test>::OverweightEnqueued {
id: hash_3, id: hash_3.into(),
origin: Ump(UmpQueueId::Para(para_a)), origin: Ump(UmpQueueId::Para(para_a)),
page_index: 0, page_index: 0,
message_index: 2, message_index: 2,
@@ -565,7 +565,7 @@ fn overweight_queue_works() {
)); ));
assert_last_event( assert_last_event(
pallet_message_queue::Event::<Test>::Processed { pallet_message_queue::Event::<Test>::Processed {
id: hash_3, id: hash_3.into(),
origin: Ump(UmpQueueId::Para(para_a)), origin: Ump(UmpQueueId::Para(para_a)),
weight_used: Weight::from_parts(501, 501), weight_used: Weight::from_parts(501, 501),
success: true, success: true,
+2
View File
@@ -13,6 +13,7 @@ codec = { package = "parity-scale-codec", version = "3.6.1", default-features =
scale-info = { version = "2.10.0", default-features = false, features = ["derive"] } scale-info = { version = "2.10.0", default-features = false, features = ["derive"] }
serde = { version = "1.0.193", optional = true, features = ["derive"] } serde = { version = "1.0.193", optional = true, features = ["derive"] }
log = { version = "0.4.17", default-features = false } log = { version = "0.4.17", default-features = false }
environmental = { version = "1.1.4", default-features = false }
sp-core = { path = "../../primitives/core", default-features = false } sp-core = { path = "../../primitives/core", default-features = false }
sp-io = { path = "../../primitives/io", default-features = false } sp-io = { path = "../../primitives/io", default-features = false }
@@ -34,6 +35,7 @@ rand_distr = "0.4.3"
default = ["std"] default = ["std"]
std = [ std = [
"codec/std", "codec/std",
"environmental/std",
"frame-benchmarking?/std", "frame-benchmarking?/std",
"frame-support/std", "frame-support/std",
"frame-system/std", "frame-system/std",
@@ -25,6 +25,7 @@ use super::{mock_helpers::*, Pallet as MessageQueue, *};
use frame_benchmarking::v2::*; use frame_benchmarking::v2::*;
use frame_support::traits::Get; use frame_support::traits::Get;
use frame_system::RawOrigin; use frame_system::RawOrigin;
use sp_io::hashing::blake2_256;
use sp_std::prelude::*; use sp_std::prelude::*;
#[benchmarks( #[benchmarks(
@@ -142,7 +143,7 @@ mod benchmarks {
// Check that it was processed. // Check that it was processed.
assert_last_event::<T>( assert_last_event::<T>(
Event::Processed { Event::Processed {
id: sp_io::hashing::blake2_256(&msg), id: blake2_256(&msg).into(),
origin: 0.into(), origin: 0.into(),
weight_used: 1.into_weight(), weight_used: 1.into_weight(),
success: true, success: true,
@@ -227,7 +228,7 @@ mod benchmarks {
assert_last_event::<T>( assert_last_event::<T>(
Event::Processed { Event::Processed {
id: sp_io::hashing::blake2_256(&((msgs - 1) as u32).encode()), id: blake2_256(&((msgs - 1) as u32).encode()).into(),
origin: 0.into(), origin: 0.into(),
weight_used: Weight::from_parts(1, 1), weight_used: Weight::from_parts(1, 1),
success: true, success: true,
@@ -264,7 +265,7 @@ mod benchmarks {
assert_last_event::<T>( assert_last_event::<T>(
Event::Processed { Event::Processed {
id: sp_io::hashing::blake2_256(&((msgs - 1) as u32).encode()), id: blake2_256(&((msgs - 1) as u32).encode()).into(),
origin: 0.into(), origin: 0.into(),
weight_used: Weight::from_parts(1, 1), weight_used: Weight::from_parts(1, 1),
success: true, success: true,
@@ -29,8 +29,8 @@
use crate::{ use crate::{
mock::{ mock::{
build_and_execute, CountingMessageProcessor, IntoWeight, MockedWeightInfo, build_and_execute, gen_seed, Callback, CountingMessageProcessor, IntoWeight,
NumMessagesProcessed, YieldingQueues, MessagesProcessed, MockedWeightInfo, NumMessagesProcessed, YieldingQueues,
}, },
mock_helpers::MessageOrigin, mock_helpers::MessageOrigin,
*, *,
@@ -120,13 +120,13 @@ impl Config for Test {
/// Processing all remaining 28639 messages /// Processing all remaining 28639 messages
/// ``` /// ```
#[test] #[test]
#[ignore] // Only run in the CI. #[ignore] // Only run in the CI, otherwise its too slow.
fn stress_test_enqueue_and_service() { fn stress_test_enqueue_and_service() {
let blocks = 20; let blocks = 20;
let max_queues = 10_000; let max_queues = 10_000;
let max_messages_per_queue = 10_000; let max_messages_per_queue = 10_000;
let max_msg_len = MaxMessageLenOf::<Test>::get(); let max_msg_len = MaxMessageLenOf::<Test>::get();
let mut rng = StdRng::seed_from_u64(43); let mut rng = StdRng::seed_from_u64(gen_seed());
build_and_execute::<Test>(|| { build_and_execute::<Test>(|| {
let mut msgs_remaining = 0; let mut msgs_remaining = 0;
@@ -148,6 +148,74 @@ fn stress_test_enqueue_and_service() {
}); });
} }
/// Very similar to `stress_test_enqueue_and_service`, but enqueues messages while processing them.
#[test]
#[ignore] // Only run in the CI, otherwise its too slow.
fn stress_test_recursive() {
let blocks = 20;
let mut rng = StdRng::seed_from_u64(gen_seed());
// We need to use thread-locals since the callback cannot capture anything.
parameter_types! {
pub static TotalEnqueued: u32 = 0;
pub static Enqueued: u32 = 0;
pub static Called: u32 = 0;
}
Called::take();
Enqueued::take();
TotalEnqueued::take();
Callback::set(Box::new(|_, _| {
let mut rng = StdRng::seed_from_u64(Enqueued::get() as u64);
let max_queues = 1_000;
let max_messages_per_queue = 1_000;
let max_msg_len = MaxMessageLenOf::<Test>::get();
// Instead of directly enqueueing, we enqueue inside a `service` call.
let enqueued = enqueue_messages(max_queues, max_messages_per_queue, max_msg_len, &mut rng);
TotalEnqueued::set(TotalEnqueued::get() + enqueued);
Enqueued::set(Enqueued::get() + enqueued);
Called::set(Called::get() + 1);
}));
build_and_execute::<Test>(|| {
let mut msgs_remaining = 0;
for b in 0..blocks {
log::info!("Block #{}", b);
MessageQueue::enqueue_message(
BoundedSlice::defensive_truncate_from(format!("callback={b}").as_bytes()),
b.into(),
);
msgs_remaining += Enqueued::take() + 1;
// Pick a fraction of all messages currently in queue and process them.
let processed = rng.gen_range(1..=msgs_remaining);
log::info!("Processing {} of all messages {}", processed, msgs_remaining);
process_some_messages(processed); // This also advances the block.
msgs_remaining -= processed;
TotalEnqueued::set(TotalEnqueued::get() - processed + 1);
MessageQueue::do_try_state().unwrap();
}
while Called::get() < blocks {
msgs_remaining += Enqueued::take();
// Pick a fraction of all messages currently in queue and process them.
let processed = rng.gen_range(1..=msgs_remaining);
log::info!("Processing {} of all messages {}", processed, msgs_remaining);
process_some_messages(processed); // This also advances the block.
msgs_remaining -= processed;
TotalEnqueued::set(TotalEnqueued::get() - processed);
MessageQueue::do_try_state().unwrap();
}
let msgs_remaining = TotalEnqueued::take();
log::info!("Processing all remaining {} messages", msgs_remaining);
process_all_messages(msgs_remaining);
assert_eq!(Called::get(), blocks);
post_conditions();
});
}
/// Simulates heavy usage of the suspension logic via `Yield`. /// Simulates heavy usage of the suspension logic via `Yield`.
/// ///
/// # Example output /// # Example output
@@ -164,14 +232,14 @@ fn stress_test_enqueue_and_service() {
/// Processing all remaining 430 messages /// Processing all remaining 430 messages
/// ``` /// ```
#[test] #[test]
#[ignore] // Only run in the CI. #[ignore] // Only run in the CI, otherwise its too slow.
fn stress_test_queue_suspension() { fn stress_test_queue_suspension() {
let blocks = 20; let blocks = 20;
let max_queues = 10_000; let max_queues = 10_000;
let max_messages_per_queue = 10_000; let max_messages_per_queue = 10_000;
let (max_suspend_per_block, max_resume_per_block) = (100, 50); let (max_suspend_per_block, max_resume_per_block) = (100, 50);
let max_msg_len = MaxMessageLenOf::<Test>::get(); let max_msg_len = MaxMessageLenOf::<Test>::get();
let mut rng = StdRng::seed_from_u64(43); let mut rng = StdRng::seed_from_u64(gen_seed());
build_and_execute::<Test>(|| { build_and_execute::<Test>(|| {
let mut suspended = BTreeSet::<u32>::new(); let mut suspended = BTreeSet::<u32>::new();
@@ -300,6 +368,7 @@ fn process_all_messages(expected: u32) {
assert_eq!(consumed, Weight::from_all(expected as u64)); assert_eq!(consumed, Weight::from_all(expected as u64));
assert_eq!(NumMessagesProcessed::take(), expected as usize); assert_eq!(NumMessagesProcessed::take(), expected as usize);
MessagesProcessed::take();
} }
/// Returns the weight consumed by `MessageQueue::on_initialize()`. /// Returns the weight consumed by `MessageQueue::on_initialize()`.
@@ -327,5 +396,6 @@ fn post_conditions() {
assert!(ServiceHead::<Test>::get().is_none()); assert!(ServiceHead::<Test>::get().is_none());
// This still works fine. // This still works fine.
assert_eq!(MessageQueue::service_queues(Weight::MAX), Weight::zero(), "Nothing left"); assert_eq!(MessageQueue::service_queues(Weight::MAX), Weight::zero(), "Nothing left");
MessageQueue::do_try_state().unwrap();
next_block(); next_block();
} }
+176 -49
View File
@@ -49,9 +49,21 @@
//! **Message Execution** //! **Message Execution**
//! //!
//! Executing a message is offloaded to the [`Config::MessageProcessor`] which contains the actual //! Executing a message is offloaded to the [`Config::MessageProcessor`] which contains the actual
//! logic of how to handle the message since they are blobs. A message can be temporarily or //! logic of how to handle the message since they are blobs. Storage changes are not rolled back on
//! permanently overweight. The pallet will perpetually try to execute a temporarily overweight //! error.
//! message. A permanently overweight message is skipped and must be executed manually. //!
//! A failed message can be temporarily or permanently overweight. The pallet will perpetually try
//! to execute a temporarily overweight message. A permanently overweight message is skipped and
//! must be executed manually.
//!
//! **Reentrancy**
//!
//! This pallet has two entry points for executing (possibly recursive) logic;
//! [`Pallet::service_queues`] and [`Pallet::execute_overweight`]. Both entry points are guarded by
//! the same mutex to error on reentrancy. The only functions that are explicitly **allowed** to be
//! called by a message processor are: [`Pallet::enqueue_message`] and
//! [`Pallet::enqueue_messages`]. All other functions are forbidden and error with
//! [`Error::RecursiveDisallowed`].
//! //!
//! **Pagination** //! **Pagination**
//! //!
@@ -146,6 +158,7 @@
//! which is the default state for a message after being enqueued. //! which is the default state for a message after being enqueued.
//! - `knitting`/`unknitting`: The means of adding or removing a `Queue` from the `ReadyRing`. //! - `knitting`/`unknitting`: The means of adding or removing a `Queue` from the `ReadyRing`.
//! - `MEL`: The Max Encoded Length of a type, see [`codec::MaxEncodedLen`]. //! - `MEL`: The Max Encoded Length of a type, see [`codec::MaxEncodedLen`].
//! - `Reentrance`: To enter an execution context again before it has completed.
//! //!
//! # Properties //! # Properties
//! //!
@@ -180,6 +193,7 @@
//! expensive. Currently this is archived by having one queue per para-chain/thread, which keeps the //! expensive. Currently this is archived by having one queue per para-chain/thread, which keeps the
//! number of queues within `O(n)` and should be "good enough". //! number of queues within `O(n)` and should be "good enough".
#![deny(missing_docs)]
#![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(not(feature = "std"), no_std)]
mod benchmarking; mod benchmarking;
@@ -194,8 +208,8 @@ use frame_support::{
defensive, defensive,
pallet_prelude::*, pallet_prelude::*,
traits::{ traits::{
DefensiveTruncateFrom, EnqueueMessage, ExecuteOverweightError, Footprint, ProcessMessage, Defensive, DefensiveTruncateFrom, EnqueueMessage, ExecuteOverweightError, Footprint,
ProcessMessageError, QueueFootprint, QueuePausedQuery, ServiceQueues, ProcessMessage, ProcessMessageError, QueueFootprint, QueuePausedQuery, ServiceQueues,
}, },
BoundedSlice, CloneNoBound, DefaultNoBound, BoundedSlice, CloneNoBound, DefaultNoBound,
}; };
@@ -203,6 +217,7 @@ use frame_system::pallet_prelude::*;
pub use pallet::*; pub use pallet::*;
use scale_info::TypeInfo; use scale_info::TypeInfo;
use sp_arithmetic::traits::{BaseArithmetic, Unsigned}; use sp_arithmetic::traits::{BaseArithmetic, Unsigned};
use sp_core::{defer, H256};
use sp_runtime::{ use sp_runtime::{
traits::{One, Zero}, traits::{One, Zero},
SaturatedConversion, Saturating, SaturatedConversion, Saturating,
@@ -460,6 +475,10 @@ pub mod pallet {
/// Processor for a message. /// Processor for a message.
/// ///
/// Storage changes are not rolled back on error.
///
/// # Benchmarking
///
/// Must be set to [`mock_helpers::NoopMessageProcessor`] for benchmarking. /// Must be set to [`mock_helpers::NoopMessageProcessor`] for benchmarking.
/// Other message processors that consumes exactly (1, 1) weight for any give message will /// Other message processors that consumes exactly (1, 1) weight for any give message will
/// work as well. Otherwise the benchmarking will also measure the weight of the message /// work as well. Otherwise the benchmarking will also measure the weight of the message
@@ -516,18 +535,51 @@ pub mod pallet {
#[pallet::generate_deposit(pub(super) fn deposit_event)] #[pallet::generate_deposit(pub(super) fn deposit_event)]
pub enum Event<T: Config> { pub enum Event<T: Config> {
/// Message discarded due to an error in the `MessageProcessor` (usually a format error). /// Message discarded due to an error in the `MessageProcessor` (usually a format error).
ProcessingFailed { id: [u8; 32], origin: MessageOriginOf<T>, error: ProcessMessageError }, ProcessingFailed {
/// The `blake2_256` hash of the message.
id: H256,
/// The queue of the message.
origin: MessageOriginOf<T>,
/// The error that occurred.
///
/// This error is pretty opaque. More fine-grained errors need to be emitted as events
/// by the `MessageProcessor`.
error: ProcessMessageError,
},
/// Message is processed. /// Message is processed.
Processed { id: [u8; 32], origin: MessageOriginOf<T>, weight_used: Weight, success: bool }, Processed {
/// The `blake2_256` hash of the message.
id: H256,
/// The queue of the message.
origin: MessageOriginOf<T>,
/// How much weight was used to process the message.
weight_used: Weight,
/// Whether the message was processed.
///
/// Note that this does not mean that the underlying `MessageProcessor` was internally
/// successful. It *solely* means that the MQ pallet will treat this as a success
/// condition and discard the message. Any internal error needs to be emitted as events
/// by the `MessageProcessor`.
success: bool,
},
/// Message placed in overweight queue. /// Message placed in overweight queue.
OverweightEnqueued { OverweightEnqueued {
/// The `blake2_256` hash of the message.
id: [u8; 32], id: [u8; 32],
/// The queue of the message.
origin: MessageOriginOf<T>, origin: MessageOriginOf<T>,
/// The page of the message.
page_index: PageIndex, page_index: PageIndex,
/// The index of the message within the page.
message_index: T::Size, message_index: T::Size,
}, },
/// This page was reaped. /// This page was reaped.
PageReaped { origin: MessageOriginOf<T>, index: PageIndex }, PageReaped {
/// The queue of the page.
origin: MessageOriginOf<T>,
/// The index of the page.
index: PageIndex,
},
} }
#[pallet::error] #[pallet::error]
@@ -554,6 +606,8 @@ pub mod pallet {
/// ///
/// This can change at any time and may resolve in the future by re-trying. /// This can change at any time and may resolve in the future by re-trying.
QueuePaused, QueuePaused,
/// Another call is in progress and needs to finish before this call can happen.
RecursiveDisallowed,
} }
/// The index of the first and last (non-empty) pages. /// The index of the first and last (non-empty) pages.
@@ -868,6 +922,21 @@ impl<T: Config> Pallet<T> {
page_index: PageIndex, page_index: PageIndex,
index: T::Size, index: T::Size,
weight_limit: Weight, weight_limit: Weight,
) -> Result<Weight, Error<T>> {
match with_service_mutex(|| {
Self::do_execute_overweight_inner(origin, page_index, index, weight_limit)
}) {
Err(()) => Err(Error::<T>::RecursiveDisallowed),
Ok(x) => x,
}
}
/// Same as `do_execute_overweight` but must be called while holding the `service_mutex`.
fn do_execute_overweight_inner(
origin: MessageOriginOf<T>,
page_index: PageIndex,
index: T::Size,
weight_limit: Weight,
) -> Result<Weight, Error<T>> { ) -> Result<Weight, Error<T>> {
let mut book_state = BookStateFor::<T>::get(&origin); let mut book_state = BookStateFor::<T>::get(&origin);
ensure!(!T::QueuePausedQuery::is_paused(&origin), Error::<T>::QueuePaused); ensure!(!T::QueuePausedQuery::is_paused(&origin), Error::<T>::QueuePaused);
@@ -924,6 +993,14 @@ impl<T: Config> Pallet<T> {
/// Remove a stale page or one which has no more messages remaining to be processed. /// Remove a stale page or one which has no more messages remaining to be processed.
fn do_reap_page(origin: &MessageOriginOf<T>, page_index: PageIndex) -> DispatchResult { fn do_reap_page(origin: &MessageOriginOf<T>, page_index: PageIndex) -> DispatchResult {
match with_service_mutex(|| Self::do_reap_page_inner(origin, page_index)) {
Err(()) => Err(Error::<T>::RecursiveDisallowed.into()),
Ok(x) => x,
}
}
/// Same as `do_reap_page` but must be called while holding the `service_mutex`.
fn do_reap_page_inner(origin: &MessageOriginOf<T>, page_index: PageIndex) -> DispatchResult {
let mut book_state = BookStateFor::<T>::get(origin); let mut book_state = BookStateFor::<T>::get(origin);
// definitely not reapable if the page's index is no less than the `begin`ning of ready // definitely not reapable if the page's index is no less than the `begin`ning of ready
// pages. // pages.
@@ -1112,6 +1189,7 @@ impl<T: Config> Pallet<T> {
weight: &mut WeightMeter, weight: &mut WeightMeter,
overweight_limit: Weight, overweight_limit: Weight,
) -> ItemExecutionStatus { ) -> ItemExecutionStatus {
use MessageExecutionStatus::*;
// This ugly pre-checking is needed for the invariant // This ugly pre-checking is needed for the invariant
// "we never bail if a page became complete". // "we never bail if a page became complete".
if page.is_complete() { if page.is_complete() {
@@ -1125,16 +1203,31 @@ impl<T: Config> Pallet<T> {
Some(m) => m, Some(m) => m,
None => return ItemExecutionStatus::NoItem, None => return ItemExecutionStatus::NoItem,
}[..]; }[..];
let payload_len = payload.len() as u64;
use MessageExecutionStatus::*; // Store these for the case that `process_message_payload` is recursive.
let is_processed = match Self::process_message_payload( Pages::<T>::insert(origin, page_index, &*page);
BookStateFor::<T>::insert(origin, &*book_state);
let res = Self::process_message_payload(
origin.clone(), origin.clone(),
page_index, page_index,
page.first_index, page.first_index,
payload, payload,
weight, weight,
overweight_limit, overweight_limit,
) { );
// And restore them afterwards to see the changes of a recursive call.
*book_state = BookStateFor::<T>::get(origin);
if let Some(new_page) = Pages::<T>::get(origin, page_index) {
*page = new_page;
} else {
defensive!("page must exist since we just inserted it and recursive calls are not allowed to remove anything");
return ItemExecutionStatus::NoItem
};
let is_processed = match res {
InsufficientWeight => return ItemExecutionStatus::Bailed, InsufficientWeight => return ItemExecutionStatus::Bailed,
Unprocessable { permanent: false } => return ItemExecutionStatus::NoProgress, Unprocessable { permanent: false } => return ItemExecutionStatus::NoProgress,
Processed | Unprocessable { permanent: true } => true, Processed | Unprocessable { permanent: true } => true,
@@ -1143,7 +1236,7 @@ impl<T: Config> Pallet<T> {
if is_processed { if is_processed {
book_state.message_count.saturating_dec(); book_state.message_count.saturating_dec();
book_state.size.saturating_reduce(payload.len() as u64); book_state.size.saturating_reduce(payload_len as u64);
} }
page.skip_first(is_processed); page.skip_first(is_processed);
ItemExecutionStatus::Executed(is_processed) ItemExecutionStatus::Executed(is_processed)
@@ -1168,7 +1261,7 @@ impl<T: Config> Pallet<T> {
/// * `remaining_size` > 0 /// * `remaining_size` > 0
/// * `first` <= `last` /// * `first` <= `last`
/// * Every page can be decoded into peek_* functions /// * Every page can be decoded into peek_* functions
#[cfg(any(test, feature = "try-runtime"))] #[cfg(any(test, feature = "try-runtime", feature = "std"))]
pub fn do_try_state() -> Result<(), sp_runtime::TryRuntimeError> { pub fn do_try_state() -> Result<(), sp_runtime::TryRuntimeError> {
// Checking memory corruption for BookStateFor // Checking memory corruption for BookStateFor
ensure!( ensure!(
@@ -1181,13 +1274,17 @@ impl<T: Config> Pallet<T> {
"Memory Corruption in Pages" "Memory Corruption in Pages"
); );
// No state to check // Basic checks for each book
if ServiceHead::<T>::get().is_none() { for book in BookStateFor::<T>::iter_values() {
return Ok(()) ensure!(book.end >= book.begin, "Invariant");
ensure!(book.end < 1 << 30, "Likely overflow or corruption");
ensure!(book.message_count < 1 << 30, "Likely overflow or corruption");
ensure!(book.size < 1 << 30, "Likely overflow or corruption");
ensure!(book.count < 1 << 30, "Likely overflow or corruption");
} }
//loop around this origin //loop around this origin
let starting_origin = ServiceHead::<T>::get().unwrap(); let Some(starting_origin) = ServiceHead::<T>::get() else { return Ok(()) };
while let Some(head) = Self::bump_service_head(&mut WeightMeter::new()) { while let Some(head) = Self::bump_service_head(&mut WeightMeter::new()) {
ensure!( ensure!(
@@ -1220,7 +1317,7 @@ impl<T: Config> Pallet<T> {
for page_index in head_book_state.begin..head_book_state.end { for page_index in head_book_state.begin..head_book_state.end {
let page = Pages::<T>::get(&head, page_index).unwrap(); let page = Pages::<T>::get(&head, page_index).unwrap();
let remaining_messages = page.remaining; let remaining_messages = page.remaining;
let mut counted_remaining_messages = 0; let mut counted_remaining_messages: u32 = 0;
ensure!( ensure!(
remaining_messages > 0.into(), remaining_messages > 0.into(),
"These must be some messages that have not been processed yet!" "These must be some messages that have not been processed yet!"
@@ -1237,7 +1334,7 @@ impl<T: Config> Pallet<T> {
} }
ensure!( ensure!(
remaining_messages == counted_remaining_messages.into(), remaining_messages.into() == counted_remaining_messages,
"Memory Corruption" "Memory Corruption"
); );
} }
@@ -1312,10 +1409,9 @@ impl<T: Config> Pallet<T> {
meter: &mut WeightMeter, meter: &mut WeightMeter,
overweight_limit: Weight, overweight_limit: Weight,
) -> MessageExecutionStatus { ) -> MessageExecutionStatus {
let hash = sp_io::hashing::blake2_256(message); let mut id = sp_io::hashing::blake2_256(message);
use ProcessMessageError::*; use ProcessMessageError::*;
let prev_consumed = meter.consumed(); let prev_consumed = meter.consumed();
let mut id = hash;
match T::MessageProcessor::process_message(message, origin.clone(), meter, &mut id) { match T::MessageProcessor::process_message(message, origin.clone(), meter, &mut id) {
Err(Overweight(w)) if w.any_gt(overweight_limit) => { Err(Overweight(w)) if w.any_gt(overweight_limit) => {
@@ -1339,19 +1435,44 @@ impl<T: Config> Pallet<T> {
}, },
Err(error @ BadFormat | error @ Corrupt | error @ Unsupported) => { Err(error @ BadFormat | error @ Corrupt | error @ Unsupported) => {
// Permanent error - drop // Permanent error - drop
Self::deposit_event(Event::<T>::ProcessingFailed { id, origin, error }); Self::deposit_event(Event::<T>::ProcessingFailed { id: id.into(), origin, error });
MessageExecutionStatus::Unprocessable { permanent: true } MessageExecutionStatus::Unprocessable { permanent: true }
}, },
Ok(success) => { Ok(success) => {
// Success // Success
let weight_used = meter.consumed().saturating_sub(prev_consumed); let weight_used = meter.consumed().saturating_sub(prev_consumed);
Self::deposit_event(Event::<T>::Processed { id, origin, weight_used, success }); Self::deposit_event(Event::<T>::Processed {
id: id.into(),
origin,
weight_used,
success,
});
MessageExecutionStatus::Processed MessageExecutionStatus::Processed
}, },
} }
} }
} }
/// Run a closure that errors on re-entrance. Meant to be used by anything that services queues.
pub(crate) fn with_service_mutex<F: FnOnce() -> R, R>(f: F) -> Result<R, ()> {
// Holds the singelton token instance.
environmental::environmental!(token: Option<()>);
token::using_once(&mut Some(()), || {
// The first `ok_or` should always be `Ok` since we are inside a `using_once`.
let hold = token::with(|t| t.take()).ok_or(()).defensive()?.ok_or(())?;
// Put the token back when we're done.
defer! {
token::with(|t| {
*t = Some(hold);
});
}
Ok(f())
})
}
/// Provides a [`sp_core::Get`] to access the `MEL` of a [`codec::MaxEncodedLen`] type. /// Provides a [`sp_core::Get`] to access the `MEL` of a [`codec::MaxEncodedLen`] type.
pub struct MaxEncodedLenOf<T>(sp_std::marker::PhantomData<T>); pub struct MaxEncodedLenOf<T>(sp_std::marker::PhantomData<T>);
impl<T: MaxEncodedLen> Get<u32> for MaxEncodedLenOf<T> { impl<T: MaxEncodedLen> Get<u32> for MaxEncodedLenOf<T> {
@@ -1407,35 +1528,40 @@ impl<T: Config> ServiceQueues for Pallet<T> {
Weight::zero() Weight::zero()
}); });
let mut next = match Self::bump_service_head(&mut weight) { match with_service_mutex(|| {
Some(h) => h, let mut next = match Self::bump_service_head(&mut weight) {
None => return weight.consumed(), Some(h) => h,
}; None => return weight.consumed(),
// The last queue that did not make any progress. };
// The loop aborts as soon as it arrives at this queue again without making any progress // The last queue that did not make any progress.
// on other queues in between. // The loop aborts as soon as it arrives at this queue again without making any progress
let mut last_no_progress = None; // on other queues in between.
let mut last_no_progress = None;
loop { loop {
let (progressed, n) = Self::service_queue(next.clone(), &mut weight, max_weight); let (progressed, n) = Self::service_queue(next.clone(), &mut weight, max_weight);
next = match n { next = match n {
Some(n) => Some(n) =>
if !progressed { if !progressed {
if last_no_progress == Some(n.clone()) { if last_no_progress == Some(n.clone()) {
break break
} }
if last_no_progress.is_none() { if last_no_progress.is_none() {
last_no_progress = Some(next.clone()) last_no_progress = Some(next.clone())
} }
n n
} else { } else {
last_no_progress = None; last_no_progress = None;
n n
}, },
None => break, None => break,
}
} }
weight.consumed()
}) {
Err(()) => weight.consumed(),
Ok(w) => w,
} }
weight.consumed()
} }
/// Execute a single overweight message. /// Execute a single overweight message.
@@ -1463,6 +1589,7 @@ impl<T: Config> ServiceQueues for Pallet<T> {
Error::<T>::QueuePaused => ExecuteOverweightError::QueuePaused, Error::<T>::QueuePaused => ExecuteOverweightError::QueuePaused,
Error::<T>::NoPage | Error::<T>::NoMessage | Error::<T>::Queued => Error::<T>::NoPage | Error::<T>::NoMessage | Error::<T>::Queued =>
ExecuteOverweightError::NotFound, ExecuteOverweightError::NotFound,
Error::<T>::RecursiveDisallowed => ExecuteOverweightError::RecursiveDisallowed,
_ => ExecuteOverweightError::Other, _ => ExecuteOverweightError::Other,
}, },
) )
+29 -1
View File
@@ -108,7 +108,10 @@ impl MockedWeightInfo {
impl crate::weights::WeightInfo for MockedWeightInfo { impl crate::weights::WeightInfo for MockedWeightInfo {
fn reap_page() -> Weight { fn reap_page() -> Weight {
WeightForCall::get().get("reap_page").copied().unwrap_or_default() WeightForCall::get()
.get("reap_page")
.copied()
.unwrap_or(DefaultWeightForCall::get())
} }
fn execute_overweight_page_updated() -> Weight { fn execute_overweight_page_updated() -> Weight {
WeightForCall::get() WeightForCall::get()
@@ -207,6 +210,10 @@ impl ProcessMessage for RecordingMessageProcessor {
let required = Weight::from_parts(weight, weight); let required = Weight::from_parts(weight, weight);
if meter.try_consume(required).is_ok() { if meter.try_consume(required).is_ok() {
if let Some(p) = message.strip_prefix(&b"callback="[..]) {
let s = String::from_utf8(p.to_vec()).expect("Need valid UTF8");
Callback::get()(&origin, s.parse().expect("Expected an u32"));
}
let mut m = MessagesProcessed::get(); let mut m = MessagesProcessed::get();
m.push((message.to_vec(), origin)); m.push((message.to_vec(), origin));
MessagesProcessed::set(m); MessagesProcessed::set(m);
@@ -217,6 +224,10 @@ impl ProcessMessage for RecordingMessageProcessor {
} }
} }
parameter_types! {
pub static Callback: Box<fn (&MessageOrigin, u32)> = Box::new(|_, _| {});
}
/// Processed a mocked message. Messages that end with `badformat`, `corrupt`, `unsupported` or /// Processed a mocked message. Messages that end with `badformat`, `corrupt`, `unsupported` or
/// `yield` will fail with an error respectively. /// `yield` will fail with an error respectively.
fn processing_message(msg: &[u8], origin: &MessageOrigin) -> Result<(), ProcessMessageError> { fn processing_message(msg: &[u8], origin: &MessageOrigin) -> Result<(), ProcessMessageError> {
@@ -264,6 +275,10 @@ impl ProcessMessage for CountingMessageProcessor {
let required = Weight::from_parts(1, 1); let required = Weight::from_parts(1, 1);
if meter.try_consume(required).is_ok() { if meter.try_consume(required).is_ok() {
if let Some(p) = message.strip_prefix(&b"callback="[..]) {
let s = String::from_utf8(p.to_vec()).expect("Need valid UTF8");
Callback::get()(&origin, s.parse().expect("Expected an u32"));
}
NumMessagesProcessed::set(NumMessagesProcessed::get() + 1); NumMessagesProcessed::set(NumMessagesProcessed::get() + 1);
Ok(true) Ok(true)
} else { } else {
@@ -372,3 +387,16 @@ pub fn num_overweight_enqueued_events() -> u32 {
pub fn fp(pages: u32, count: u64, size: u64) -> QueueFootprint { pub fn fp(pages: u32, count: u64, size: u64) -> QueueFootprint {
QueueFootprint { storage: Footprint { count, size }, pages } QueueFootprint { storage: Footprint { count, size }, pages }
} }
/// A random seed that can be overwritten with `MQ_SEED`.
pub fn gen_seed() -> u64 {
use rand::Rng;
let seed = if let Ok(seed) = std::env::var("MQ_SEED") {
seed.parse().expect("Need valid u64 as MQ_SEED env variable")
} else {
rand::thread_rng().gen::<u64>()
};
println!("Using seed: {}", seed);
seed
}
@@ -15,6 +15,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#![allow(missing_docs)]
//! Std setup helpers for testing and benchmarking. //! Std setup helpers for testing and benchmarking.
//! //!
//! Cannot be put into mock.rs since benchmarks require no-std and mock.rs is std. //! Cannot be put into mock.rs since benchmarks require no-std and mock.rs is std.
@@ -88,10 +90,12 @@ pub fn page<T: Config>(msg: &[u8]) -> PageOf<T> {
PageOf::<T>::from_message::<T>(msg.try_into().unwrap()) PageOf::<T>::from_message::<T>(msg.try_into().unwrap())
} }
/// Create a book with a single message of one byte.
pub fn single_page_book<T: Config>() -> BookStateOf<T> { pub fn single_page_book<T: Config>() -> BookStateOf<T> {
BookState { begin: 0, end: 1, count: 1, message_count: 1, size: 1, ..Default::default() } BookState { begin: 0, end: 1, count: 1, message_count: 1, size: 1, ..Default::default() }
} }
/// Create an empty book.
pub fn empty_book<T: Config>() -> BookStateOf<T> { pub fn empty_book<T: Config>() -> BookStateOf<T> {
BookState { begin: 0, end: 1, count: 1, ..Default::default() } BookState { begin: 0, end: 1, count: 1, ..Default::default() }
} }
+178 -7
View File
@@ -181,7 +181,7 @@ fn service_queues_failing_messages_works() {
assert_eq!(MessageQueue::service_queues(1.into_weight()), 1.into_weight()); assert_eq!(MessageQueue::service_queues(1.into_weight()), 1.into_weight());
assert_last_event::<Test>( assert_last_event::<Test>(
Event::ProcessingFailed { Event::ProcessingFailed {
id: blake2_256(b"badformat"), id: blake2_256(b"badformat").into(),
origin: MessageOrigin::Here, origin: MessageOrigin::Here,
error: ProcessMessageError::BadFormat, error: ProcessMessageError::BadFormat,
} }
@@ -190,7 +190,7 @@ fn service_queues_failing_messages_works() {
assert_eq!(MessageQueue::service_queues(1.into_weight()), 1.into_weight()); assert_eq!(MessageQueue::service_queues(1.into_weight()), 1.into_weight());
assert_last_event::<Test>( assert_last_event::<Test>(
Event::ProcessingFailed { Event::ProcessingFailed {
id: blake2_256(b"corrupt"), id: blake2_256(b"corrupt").into(),
origin: MessageOrigin::Here, origin: MessageOrigin::Here,
error: ProcessMessageError::Corrupt, error: ProcessMessageError::Corrupt,
} }
@@ -199,7 +199,7 @@ fn service_queues_failing_messages_works() {
assert_eq!(MessageQueue::service_queues(1.into_weight()), 1.into_weight()); assert_eq!(MessageQueue::service_queues(1.into_weight()), 1.into_weight());
assert_last_event::<Test>( assert_last_event::<Test>(
Event::ProcessingFailed { Event::ProcessingFailed {
id: blake2_256(b"unsupported"), id: blake2_256(b"unsupported").into(),
origin: MessageOrigin::Here, origin: MessageOrigin::Here,
error: ProcessMessageError::Unsupported, error: ProcessMessageError::Unsupported,
} }
@@ -1264,7 +1264,7 @@ fn permanently_overweight_limit_is_valid_basic() {
RuntimeEvent::MessageQueue(Event::Processed { RuntimeEvent::MessageQueue(Event::Processed {
origin: Here, origin: Here,
weight_used: 200.into(), weight_used: 200.into(),
id: blake2_256(m.as_bytes()), id: blake2_256(m.as_bytes()).into(),
success: true, success: true,
}) })
); );
@@ -1321,7 +1321,7 @@ fn permanently_overweight_limit_is_valid_fuzzy() {
RuntimeEvent::MessageQueue(Event::Processed { RuntimeEvent::MessageQueue(Event::Processed {
origin: Here, origin: Here,
weight_used: 200.into(), weight_used: 200.into(),
id: blake2_256(m.as_bytes()), id: blake2_256(m.as_bytes()).into(),
success: true, success: true,
}) })
); );
@@ -1592,7 +1592,7 @@ fn execute_overweight_respects_suspension() {
assert_last_event::<Test>( assert_last_event::<Test>(
Event::Processed { Event::Processed {
id: blake2_256(b"weight=5"), id: blake2_256(b"weight=5").into(),
origin, origin,
weight_used: 5.into_weight(), weight_used: 5.into_weight(),
success: true, success: true,
@@ -1619,7 +1619,7 @@ fn service_queue_suspension_ready_ring_works() {
MessageQueue::service_queues(Weight::MAX); MessageQueue::service_queues(Weight::MAX);
assert_last_event::<Test>( assert_last_event::<Test>(
Event::Processed { Event::Processed {
id: blake2_256(b"weight=5"), id: blake2_256(b"weight=5").into(),
origin, origin,
weight_used: 5.into_weight(), weight_used: 5.into_weight(),
success: true, success: true,
@@ -1662,3 +1662,174 @@ fn integrity_test_checks_service_weight() {
} }
}); });
} }
/// Test for <https://github.com/paritytech/polkadot-sdk/issues/2319>.
#[test]
fn regression_issue_2319() {
build_and_execute::<Test>(|| {
Callback::set(Box::new(|_, _| {
MessageQueue::enqueue_message(mock_helpers::msg("anothermessage"), There);
}));
use MessageOrigin::*;
MessageQueue::enqueue_message(msg("callback=0"), Here);
// while servicing queue Here, "anothermessage" of origin There is enqueued in
// "firstmessage"'s process_message
assert_eq!(MessageQueue::service_queues(1.into_weight()), 1.into_weight());
assert_eq!(MessagesProcessed::take(), vec![(b"callback=0".to_vec(), Here)]);
assert_eq!(MessageQueue::service_queues(1.into_weight()), 1.into_weight());
// It used to fail here but got fixed.
assert_eq!(MessagesProcessed::take(), vec![(b"anothermessage".to_vec(), There)]);
});
}
/// Enqueueing a message from within `service_queues` works.
#[test]
fn recursive_enqueue_works() {
build_and_execute::<Test>(|| {
Callback::set(Box::new(|o, i| match i {
0 => {
MessageQueue::enqueue_message(msg(&format!("callback={}", 1)), *o);
},
1 => {
for _ in 0..100 {
MessageQueue::enqueue_message(msg(&format!("callback={}", 2)), *o);
}
for i in 0..100 {
MessageQueue::enqueue_message(msg(&format!("callback={}", 3)), i.into());
}
},
2 | 3 => {
MessageQueue::enqueue_message(msg(&format!("callback={}", 4)), *o);
},
4 => (),
_ => unreachable!(),
}));
MessageQueue::enqueue_message(msg("callback=0"), MessageOrigin::Here);
for _ in 0..402 {
assert_eq!(MessageQueue::service_queues(1.into_weight()), 1.into_weight());
}
assert_eq!(MessageQueue::service_queues(Weight::MAX), Weight::zero());
assert_eq!(MessagesProcessed::take().len(), 402);
});
}
/// Calling `service_queues` from within `service_queues` is forbidden.
#[test]
fn recursive_service_is_forbidden() {
use MessageOrigin::*;
build_and_execute::<Test>(|| {
Callback::set(Box::new(|_, _| {
MessageQueue::enqueue_message(msg("m1"), There);
// This call will fail since it is recursive. But it will not mess up the state.
assert_storage_noop!(MessageQueue::service_queues(10.into_weight()));
MessageQueue::enqueue_message(msg("m2"), There);
}));
for _ in 0..5 {
MessageQueue::enqueue_message(msg("callback=0"), Here);
MessageQueue::service_queues(3.into_weight());
// All three messages are correctly processed.
assert_eq!(
MessagesProcessed::take(),
vec![
(b"callback=0".to_vec(), Here),
(b"m1".to_vec(), There),
(b"m2".to_vec(), There)
]
);
}
});
}
/// Calling `service_queues` from within `service_queues` is forbidden.
#[test]
fn recursive_overweight_while_service_is_forbidden() {
use MessageOrigin::*;
build_and_execute::<Test>(|| {
Callback::set(Box::new(|_, _| {
// Check that the message was permanently overweight.
assert_last_event::<Test>(
Event::OverweightEnqueued {
id: blake2_256(b"weight=10"),
origin: There,
message_index: 0,
page_index: 0,
}
.into(),
);
// This call will fail since it is recursive. But it will not mess up the state.
assert_noop!(
<MessageQueue as ServiceQueues>::execute_overweight(
10.into_weight(),
(There, 0, 0)
),
ExecuteOverweightError::RecursiveDisallowed
);
}));
MessageQueue::enqueue_message(msg("weight=10"), There);
MessageQueue::enqueue_message(msg("callback=0"), Here);
// Mark it as permanently overweight.
MessageQueue::service_queues(5.into_weight());
assert_ok!(<MessageQueue as ServiceQueues>::execute_overweight(
10.into_weight(),
(There, 0, 0)
));
});
}
/// Calling `reap_page` from within `service_queues` is forbidden.
#[test]
fn recursive_reap_page_is_forbidden() {
use MessageOrigin::*;
build_and_execute::<Test>(|| {
Callback::set(Box::new(|_, _| {
// This call will fail since it is recursive. But it will not mess up the state.
assert_noop!(MessageQueue::do_reap_page(&Here, 0), Error::<Test>::RecursiveDisallowed);
}));
// Create 10 pages more than the stale limit.
let n = (MaxStale::get() + 10) as usize;
for _ in 0..n {
MessageQueue::enqueue_message(msg("weight=2"), Here);
}
// Mark all pages as stale since their message is permanently overweight.
MessageQueue::service_queues(1.into_weight());
assert_ok!(MessageQueue::do_reap_page(&Here, 0));
assert_last_event::<Test>(Event::PageReaped { origin: Here, index: 0 }.into());
});
}
#[test]
fn with_service_mutex_works() {
let mut called = 0;
with_service_mutex(|| called = 1).unwrap();
assert_eq!(called, 1);
// The outer one is fine but the inner one errors.
with_service_mutex(|| with_service_mutex(|| unreachable!()))
.unwrap()
.unwrap_err();
with_service_mutex(|| with_service_mutex(|| unreachable!()).unwrap_err()).unwrap();
with_service_mutex(|| {
with_service_mutex(|| unreachable!()).unwrap_err();
with_service_mutex(|| unreachable!()).unwrap_err();
called = 2;
})
.unwrap();
assert_eq!(called, 2);
// Still works.
with_service_mutex(|| called = 3).unwrap();
assert_eq!(called, 3);
}
@@ -82,6 +82,8 @@ pub enum ExecuteOverweightError {
QueuePaused, QueuePaused,
/// An unspecified error. /// An unspecified error.
Other, Other,
/// Another call is currently ongoing and prevents this call from executing.
RecursiveDisallowed,
} }
/// Can service queues and execute overweight messages. /// Can service queues and execute overweight messages.