[FRAME] Make MQ pallet re-entrancy safe (#2356)

Closes https://github.com/paritytech/polkadot-sdk/issues/2319

Changes:
- Ensure that only `enqueue_message(s)` is callable from within the
message processor. This prevents messed up storage that can currently
happen when the pallet is called into recursively.
- Use `H256` instead of `[u8; 32]` for clearer API.

## Details

The re-entrancy check is done with the `environmental` crate by adding a
`with_service_mutex(f)` function that runs the closure exclusively. This
works since the MQ pallet is not instantiable.

---------

Signed-off-by: Oliver Tale-Yazdi <oliver.tale-yazdi@parity.io>
Co-authored-by: Francisco Aguirre <franciscoaguirreperez@gmail.com>
This commit is contained in:
Oliver Tale-Yazdi
2023-12-07 17:48:13 +01:00
committed by GitHub
parent 95c3ee10bc
commit 7e7fe99047
11 changed files with 478 additions and 79 deletions
@@ -29,8 +29,8 @@
use crate::{
mock::{
build_and_execute, CountingMessageProcessor, IntoWeight, MockedWeightInfo,
NumMessagesProcessed, YieldingQueues,
build_and_execute, gen_seed, Callback, CountingMessageProcessor, IntoWeight,
MessagesProcessed, MockedWeightInfo, NumMessagesProcessed, YieldingQueues,
},
mock_helpers::MessageOrigin,
*,
@@ -120,13 +120,13 @@ impl Config for Test {
/// Processing all remaining 28639 messages
/// ```
#[test]
#[ignore] // Only run in the CI.
#[ignore] // Only run in the CI, otherwise its too slow.
fn stress_test_enqueue_and_service() {
let blocks = 20;
let max_queues = 10_000;
let max_messages_per_queue = 10_000;
let max_msg_len = MaxMessageLenOf::<Test>::get();
let mut rng = StdRng::seed_from_u64(43);
let mut rng = StdRng::seed_from_u64(gen_seed());
build_and_execute::<Test>(|| {
let mut msgs_remaining = 0;
@@ -148,6 +148,74 @@ fn stress_test_enqueue_and_service() {
});
}
/// Very similar to `stress_test_enqueue_and_service`, but enqueues messages while processing them.
///
/// A `Callback` is installed that calls `enqueue_messages` and records how much it enqueued in
/// thread-local counters. Each of the `blocks` iterations then enqueues one `callback={b}`
/// message and processes a random number of the outstanding messages. Afterwards the test keeps
/// processing until the callback has run `blocks` times, drains whatever is left and checks the
/// post conditions.
///
/// NOTE(review): the callback is presumably invoked by the mocked message processor whenever it
/// handles a `callback=` message - confirm against the mock.
#[test]
#[ignore] // Only run in the CI, otherwise it's too slow.
fn stress_test_recursive() {
let blocks = 20;
let mut rng = StdRng::seed_from_u64(gen_seed());
// We need to use thread-locals since the callback cannot capture anything.
parameter_types! {
// Bookkeeping for the number of messages that still have to be processed; it is increased by
// the callback, decreased after every processing step and read at the very end to drain the
// queues.
pub static TotalEnqueued: u32 = 0;
// Messages enqueued by the callback since the main loop last `take()`-ed this counter.
pub static Enqueued: u32 = 0;
// Number of times the callback ran.
pub static Called: u32 = 0;
}
// Discard whatever the counters currently hold before the test starts.
// NOTE(review): assumes `take()` resets the value to the declared default - confirm.
Called::take();
Enqueued::take();
TotalEnqueued::take();
Callback::set(Box::new(|_, _| {
// The outer `rng` cannot be captured, so a fresh one is seeded from the `Enqueued` counter.
let mut rng = StdRng::seed_from_u64(Enqueued::get() as u64);
let max_queues = 1_000;
let max_messages_per_queue = 1_000;
let max_msg_len = MaxMessageLenOf::<Test>::get();
// Instead of directly enqueueing, we enqueue inside a `service` call.
let enqueued = enqueue_messages(max_queues, max_messages_per_queue, max_msg_len, &mut rng);
// Record the newly enqueued messages in both counters and note that the callback ran.
TotalEnqueued::set(TotalEnqueued::get() + enqueued);
Enqueued::set(Enqueued::get() + enqueued);
Called::set(Called::get() + 1);
}));
build_and_execute::<Test>(|| {
let mut msgs_remaining = 0;
for b in 0..blocks {
log::info!("Block #{}", b);
// Enqueue one `callback=` message per block, using the block number as origin.
MessageQueue::enqueue_message(
BoundedSlice::defensive_truncate_from(format!("callback={b}").as_bytes()),
b.into(),
);
// Everything the callback enqueued since the last `take()`, plus the one message that was
// enqueued just above.
msgs_remaining += Enqueued::take() + 1;
// Pick a fraction of all messages currently in queue and process them.
let processed = rng.gen_range(1..=msgs_remaining);
log::info!("Processing {} of all messages {}", processed, msgs_remaining);
process_some_messages(processed); // This also advances the block.
msgs_remaining -= processed;
// The `+ 1` accounts for the `callback=` message of this block, which the callback itself
// never added to `TotalEnqueued`.
TotalEnqueued::set(TotalEnqueued::get() - processed + 1);
MessageQueue::do_try_state().unwrap();
}
// Keep processing until the callback has run `blocks` times. No further `callback=` messages
// are enqueued here, hence no `+ 1` in the bookkeeping below.
while Called::get() < blocks {
msgs_remaining += Enqueued::take();
// Pick a fraction of all messages currently in queue and process them.
let processed = rng.gen_range(1..=msgs_remaining);
log::info!("Processing {} of all messages {}", processed, msgs_remaining);
process_some_messages(processed); // This also advances the block.
msgs_remaining -= processed;
TotalEnqueued::set(TotalEnqueued::get() - processed);
MessageQueue::do_try_state().unwrap();
}
// Whatever `TotalEnqueued` still holds is the number of messages passed to
// `process_all_messages` as the expected count. This shadows the running `msgs_remaining`.
let msgs_remaining = TotalEnqueued::take();
log::info!("Processing all remaining {} messages", msgs_remaining);
process_all_messages(msgs_remaining);
// The callback must have run exactly `blocks` times in total.
assert_eq!(Called::get(), blocks);
post_conditions();
});
}
/// Simulates heavy usage of the suspension logic via `Yield`.
///
/// # Example output
@@ -164,14 +232,14 @@ fn stress_test_enqueue_and_service() {
/// Processing all remaining 430 messages
/// ```
#[test]
#[ignore] // Only run in the CI.
#[ignore] // Only run in the CI, otherwise its too slow.
fn stress_test_queue_suspension() {
let blocks = 20;
let max_queues = 10_000;
let max_messages_per_queue = 10_000;
let (max_suspend_per_block, max_resume_per_block) = (100, 50);
let max_msg_len = MaxMessageLenOf::<Test>::get();
let mut rng = StdRng::seed_from_u64(43);
let mut rng = StdRng::seed_from_u64(gen_seed());
build_and_execute::<Test>(|| {
let mut suspended = BTreeSet::<u32>::new();
@@ -300,6 +368,7 @@ fn process_all_messages(expected: u32) {
assert_eq!(consumed, Weight::from_all(expected as u64));
assert_eq!(NumMessagesProcessed::take(), expected as usize);
MessagesProcessed::take();
}
/// Returns the weight consumed by `MessageQueue::on_initialize()`.
@@ -327,5 +396,6 @@ fn post_conditions() {
assert!(ServiceHead::<Test>::get().is_none());
// This still works fine.
assert_eq!(MessageQueue::service_queues(Weight::MAX), Weight::zero(), "Nothing left");
MessageQueue::do_try_state().unwrap();
next_block();
}