availability-recovery: move cpu burners in blocking tasks (#7417)

* Move expensive computations to blocking thread

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>

* fix test

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>

* add internal error and fix dependent subsystems

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>

* fmt

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>

* fix test fix

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>

* minor refactor and TODOs

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>

* Impl Feedback for Review

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>

* review feedback

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>

* More docs

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>

* add some example timings in comments

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>

---------

Signed-off-by: Andrei Sandu <andrei-mihail@parity.io>
This commit is contained in:
Andrei Sandu
2023-07-04 12:50:49 +03:00
committed by GitHub
parent e53d15aa20
commit a0814490d2
5 changed files with 288 additions and 41 deletions
@@ -2519,6 +2519,17 @@ async fn launch_approval<Context>(
// do nothing. we'll just be a no-show and that'll cause others to rise up.
metrics_guard.take().on_approval_unavailable();
},
&RecoveryError::ChannelClosed => {
gum::warn!(
target: LOG_TARGET,
?para_id,
?candidate_hash,
"Channel closed while recovering data for candidate {:?}",
(candidate_hash, candidate.descriptor.para_id),
);
// do nothing. we'll just be a no-show and that'll cause others to rise up.
metrics_guard.take().on_approval_unavailable();
},
&RecoveryError::Invalid => {
gum::warn!(
target: LOG_TARGET,
@@ -319,7 +319,7 @@ async fn participate(
send_result(&mut result_sender, req, ParticipationOutcome::Invalid).await;
return
},
Ok(Err(RecoveryError::Unavailable)) => {
Ok(Err(RecoveryError::Unavailable)) | Ok(Err(RecoveryError::ChannelClosed)) => {
send_result(&mut result_sender, req, ParticipationOutcome::Unavailable).await;
return
},
@@ -20,24 +20,28 @@
use std::{
collections::{HashMap, VecDeque},
iter::Iterator,
num::NonZeroUsize,
pin::Pin,
time::Duration,
};
use futures::{
channel::oneshot,
future::{FutureExt, RemoteHandle},
channel::oneshot::{self, channel},
future::{Future, FutureExt, RemoteHandle},
pin_mut,
prelude::*,
stream::FuturesUnordered,
sink::SinkExt,
stream::{FuturesUnordered, StreamExt},
task::{Context, Poll},
};
use lru::LruCache;
use rand::seq::SliceRandom;
use fatality::Nested;
use polkadot_erasure_coding::{branch_hash, branches, obtain_chunks_v1, recovery_threshold};
use polkadot_erasure_coding::{
branch_hash, branches, obtain_chunks_v1, recovery_threshold, Error as ErasureEncodingError,
};
#[cfg(not(test))]
use polkadot_node_network_protocol::request_response::CHUNK_REQUEST_TIMEOUT;
use polkadot_node_network_protocol::{
@@ -150,6 +154,8 @@ struct RequestFromBackers {
// a random shuffling of the validators from the backing group which indicates the order
// in which we connect to them and request the chunk.
shuffled_backers: Vec<ValidatorIndex>,
// channel to the erasure task handler.
erasure_task_tx: futures::channel::mpsc::Sender<ErasureTask>,
}
struct RequestChunksFromValidators {
@@ -162,9 +168,12 @@ struct RequestChunksFromValidators {
/// a random shuffling of the validators which indicates the order in which we connect to the validators and
/// request the chunk from them.
shuffling: VecDeque<ValidatorIndex>,
/// Chunks received so far.
received_chunks: HashMap<ValidatorIndex, ErasureChunk>,
/// Pending chunk requests with soft timeout.
requesting_chunks: FuturesUndead<Result<Option<ErasureChunk>, (ValidatorIndex, RequestError)>>,
// channel to the erasure task handler.
erasure_task_tx: futures::channel::mpsc::Sender<ErasureTask>,
}
struct RecoveryParams {
@@ -198,6 +207,18 @@ enum Source {
RequestChunks(RequestChunksFromValidators),
}
/// Expensive erasure coding computations that we want to run on a blocking thread.
///
/// Each variant carries its inputs plus a `oneshot::Sender` over which the worker
/// that executes the task reports the outcome back to the requester.
enum ErasureTask {
/// Reconstructs `AvailableData` from chunks given `n_validators`.
///
/// Fields, in order:
/// 1. `n_validators` - passed to `reconstruct_v1` as the validator count.
/// 2. The received chunks, keyed by the validator index they came from.
/// 3. Response channel carrying the reconstructed data or the erasure coding error.
Reconstruct(
usize,
HashMap<ValidatorIndex, ErasureChunk>,
oneshot::Sender<Result<AvailableData, ErasureEncodingError>>,
),
/// Re-encode `AvailableData` into erasure chunks in order to verify the provided root hash of the Merkle tree.
///
/// Fields, in order: `n_validators`, the expected erasure root, the data to check and the
/// response channel. The worker answers with `Some(data)` (handing the data back) when
/// `reconstructed_data_matches_root` accepts it and with `None` otherwise.
Reencode(usize, Hash, AvailableData, oneshot::Sender<Option<AvailableData>>),
}
/// A stateful reconstruction of availability data in reference to
/// a candidate hash.
struct RecoveryTask<Sender> {
@@ -208,13 +229,19 @@ struct RecoveryTask<Sender> {
/// The source to obtain the availability data from.
source: Source,
// channel to the erasure task handler.
erasure_task_tx: futures::channel::mpsc::Sender<ErasureTask>,
}
impl RequestFromBackers {
fn new(mut backers: Vec<ValidatorIndex>) -> Self {
fn new(
mut backers: Vec<ValidatorIndex>,
erasure_task_tx: futures::channel::mpsc::Sender<ErasureTask>,
) -> Self {
backers.shuffle(&mut rand::thread_rng());
RequestFromBackers { shuffled_backers: backers }
RequestFromBackers { shuffled_backers: backers, erasure_task_tx }
}
// Run this phase to completion.
@@ -251,12 +278,21 @@ impl RequestFromBackers {
match response.await {
Ok(req_res::v1::AvailableDataFetchingResponse::AvailableData(data)) => {
if reconstructed_data_matches_root(
params.validators.len(),
&params.erasure_root,
&data,
&params.metrics,
) {
let (reencode_tx, reencode_rx) = channel();
self.erasure_task_tx
.send(ErasureTask::Reencode(
params.validators.len(),
params.erasure_root,
data,
reencode_tx,
))
.await
.map_err(|_| RecoveryError::ChannelClosed)?;
let reencode_response =
reencode_rx.await.map_err(|_| RecoveryError::ChannelClosed)?;
if let Some(data) = reencode_response {
gum::trace!(
target: LOG_TARGET,
candidate_hash = ?params.candidate_hash,
@@ -289,7 +325,10 @@ impl RequestFromBackers {
}
impl RequestChunksFromValidators {
fn new(n_validators: u32) -> Self {
fn new(
n_validators: u32,
erasure_task_tx: futures::channel::mpsc::Sender<ErasureTask>,
) -> Self {
let mut shuffling: Vec<_> = (0..n_validators).map(ValidatorIndex).collect();
shuffling.shuffle(&mut rand::thread_rng());
@@ -299,20 +338,29 @@ impl RequestChunksFromValidators {
shuffling: shuffling.into(),
received_chunks: HashMap::new(),
requesting_chunks: FuturesUndead::new(),
erasure_task_tx,
}
}
fn is_unavailable(&self, params: &RecoveryParams) -> bool {
is_unavailable(
self.received_chunks.len(),
self.chunk_count(),
self.requesting_chunks.total_len(),
self.shuffling.len(),
params.threshold,
)
}
fn chunk_count(&self) -> usize {
self.received_chunks.len()
}
fn insert_chunk(&mut self, validator_index: ValidatorIndex, chunk: ErasureChunk) {
self.received_chunks.insert(validator_index, chunk);
}
fn can_conclude(&self, params: &RecoveryParams) -> bool {
self.received_chunks.len() >= params.threshold || self.is_unavailable(params)
self.chunk_count() >= params.threshold || self.is_unavailable(params)
}
/// Desired number of parallel requests.
@@ -329,7 +377,7 @@ impl RequestChunksFromValidators {
// 4. We request more chunks to make up for it ...
let max_requests_boundary = std::cmp::min(N_PARALLEL, threshold);
// How many chunks are still needed?
let remaining_chunks = threshold.saturating_sub(self.received_chunks.len());
let remaining_chunks = threshold.saturating_sub(self.chunk_count());
// What is the current error rate, so we can make up for it?
let inv_error_rate =
self.total_received_responses.checked_div(self.error_count).unwrap_or(0);
@@ -430,7 +478,7 @@ impl RequestChunksFromValidators {
validator_index = ?chunk.index,
"Received valid chunk",
);
self.received_chunks.insert(chunk.index, chunk);
self.insert_chunk(chunk.index, chunk);
} else {
metrics.on_chunk_request_invalid();
self.error_count += 1;
@@ -488,7 +536,7 @@ impl RequestChunksFromValidators {
gum::debug!(
target: LOG_TARGET,
candidate_hash = ?params.candidate_hash,
received_chunks_count = ?self.received_chunks.len(),
received_chunks_count = ?self.chunk_count(),
requested_chunks_count = ?self.requesting_chunks.len(),
threshold = ?params.threshold,
"Can conclude availability for a candidate",
@@ -530,7 +578,7 @@ impl RequestChunksFromValidators {
validator_index = ?chunk.index,
"Found valid chunk on disk"
);
self.received_chunks.insert(chunk.index, chunk);
self.insert_chunk(chunk.index, chunk);
} else {
gum::error!(
target: LOG_TARGET,
@@ -557,7 +605,7 @@ impl RequestChunksFromValidators {
target: LOG_TARGET,
candidate_hash = ?params.candidate_hash,
erasure_root = ?params.erasure_root,
received = %self.received_chunks.len(),
received = %self.chunk_count(),
requesting = %self.requesting_chunks.len(),
total_requesting = %self.requesting_chunks.total_len(),
n_validators = %params.validators.len(),
@@ -575,20 +623,41 @@ impl RequestChunksFromValidators {
// If received_chunks has more than threshold entries, attempt to recover the data.
// If that fails, or a re-encoding of it doesn't match the expected erasure root,
// return Err(RecoveryError::Invalid)
if self.received_chunks.len() >= params.threshold {
if self.chunk_count() >= params.threshold {
let recovery_duration = metrics.time_erasure_recovery();
return match polkadot_erasure_coding::reconstruct_v1(
params.validators.len(),
self.received_chunks.values().map(|c| (&c.chunk[..], c.index.0 as usize)),
) {
// Send request to reconstruct available data from chunks.
let (avilable_data_tx, available_data_rx) = channel();
self.erasure_task_tx
.send(ErasureTask::Reconstruct(
params.validators.len(),
std::mem::take(&mut self.received_chunks),
avilable_data_tx,
))
.await
.map_err(|_| RecoveryError::ChannelClosed)?;
let available_data_response =
available_data_rx.await.map_err(|_| RecoveryError::ChannelClosed)?;
return match available_data_response {
Ok(data) => {
if reconstructed_data_matches_root(
params.validators.len(),
&params.erasure_root,
&data,
&metrics,
) {
// Send request to re-encode the chunks and check merkle root.
let (reencode_tx, reencode_rx) = channel();
self.erasure_task_tx
.send(ErasureTask::Reencode(
params.validators.len(),
params.erasure_root,
data,
reencode_tx,
))
.await
.map_err(|_| RecoveryError::ChannelClosed)?;
let reencode_response =
reencode_rx.await.map_err(|_| RecoveryError::ChannelClosed)?;
if let Some(data) = reencode_response {
gum::trace!(
target: LOG_TARGET,
candidate_hash = ?params.candidate_hash,
@@ -746,9 +815,12 @@ where
match from_backers.run(&self.params, &mut self.sender).await {
Ok(data) => break Ok(data),
Err(RecoveryError::Invalid) => break Err(RecoveryError::Invalid),
Err(RecoveryError::ChannelClosed) =>
break Err(RecoveryError::ChannelClosed),
Err(RecoveryError::Unavailable) =>
self.source = Source::RequestChunks(RequestChunksFromValidators::new(
self.params.validators.len() as _,
self.erasure_task_tx.clone(),
)),
}
},
@@ -838,6 +910,7 @@ impl TryFrom<Result<AvailableData, RecoveryError>> for CachedRecovery {
// We don't want to cache unavailable state, as that state might change, so if
// requested again we want to try again!
Err(RecoveryError::Unavailable) => Err(()),
Err(RecoveryError::ChannelClosed) => Err(()),
}
}
}
@@ -904,9 +977,9 @@ async fn launch_recovery_task<Context>(
response_sender: oneshot::Sender<Result<AvailableData, RecoveryError>>,
metrics: &Metrics,
recovery_strategy: &RecoveryStrategy,
erasure_task_tx: futures::channel::mpsc::Sender<ErasureTask>,
) -> error::Result<()> {
let candidate_hash = receipt.hash();
let params = RecoveryParams {
validator_authority_keys: session_info.discovery_keys.clone(),
validators: session_info.validators.clone(),
@@ -943,12 +1016,21 @@ async fn launch_recovery_task<Context>(
let phase = backing_group
.and_then(|g| session_info.validator_groups.get(g))
.map(|group| Source::RequestFromBackers(RequestFromBackers::new(group.clone())))
.map(|group| {
Source::RequestFromBackers(RequestFromBackers::new(
group.clone(),
erasure_task_tx.clone(),
))
})
.unwrap_or_else(|| {
Source::RequestChunks(RequestChunksFromValidators::new(params.validators.len() as _))
Source::RequestChunks(RequestChunksFromValidators::new(
params.validators.len() as _,
erasure_task_tx.clone(),
))
});
let recovery_task = RecoveryTask { sender: ctx.sender().clone(), params, source: phase };
let recovery_task =
RecoveryTask { sender: ctx.sender().clone(), params, source: phase, erasure_task_tx };
let (remote, remote_handle) = recovery_task.run().remote_handle();
@@ -980,6 +1062,7 @@ async fn handle_recover<Context>(
response_sender: oneshot::Sender<Result<AvailableData, RecoveryError>>,
metrics: &Metrics,
recovery_strategy: &RecoveryStrategy,
erasure_task_tx: futures::channel::mpsc::Sender<ErasureTask>,
) -> error::Result<()> {
let candidate_hash = receipt.hash();
@@ -1024,6 +1107,7 @@ async fn handle_recover<Context>(
response_sender,
metrics,
recovery_strategy,
erasure_task_tx,
)
.await,
None => {
@@ -1061,6 +1145,7 @@ async fn query_chunk_size<Context>(
rx.await.map_err(error::Error::CanceledQueryFullData)
}
#[overseer::contextbounds(AvailabilityRecovery, prefix = self::overseer)]
impl AvailabilityRecoverySubsystem {
/// Create a new instance of `AvailabilityRecoverySubsystem` which never requests the
@@ -1106,10 +1191,65 @@ impl AvailabilityRecoverySubsystem {
let mut state = State::default();
let Self { recovery_strategy, mut req_receiver, metrics } = self;
let (erasure_task_tx, erasure_task_rx) = futures::channel::mpsc::channel(16);
let mut erasure_task_rx = erasure_task_rx.fuse();
// `ThreadPoolBuilder` spawns the tasks using `spawn_blocking`. For each worker there will be a `mpsc` channel created.
// Each of these workers takes the `Receiver` and polls it in an infinite loop.
// All of the sender ends of the channel are sent as a vec which we then use to create a `Cycle` iterator.
// We use this iterator to assign work in a round-robin fashion to the workers in the pool.
//
// How work is dispatched to the pool from the recovery tasks:
// - Once a recovery task finishes retrieving the availability data, it needs to reconstruct from chunks and/or
// re-encode the data which are heavy CPU computations.
// To do so it sends an `ErasureTask` to the main loop via the `erasure_task` channel, and waits for the results
// over a `oneshot` channel.
// - In the subsystem main loop we poll the `erasure_task_rx` receiver.
// - We forward the received `ErasureTask` to the `next()` sender yielded by the `Cycle` iterator.
// - Some worker thread handles it and sends the response over the `oneshot` channel.
// Create a thread pool with 2 workers.
let mut to_pool = ThreadPoolBuilder::build(
// Pool is guaranteed to have at least 1 worker thread.
NonZeroUsize::new(2).expect("There are 2 threads; qed"),
metrics.clone(),
&mut ctx,
)
.into_iter()
.cycle();
loop {
let recv_req = req_receiver.recv(|| vec![COST_INVALID_REQUEST]).fuse();
pin_mut!(recv_req);
futures::select! {
erasure_task = erasure_task_rx.next() => {
match erasure_task {
Some(task) => {
let send_result = to_pool
.next()
.expect("Pool size is `NonZeroUsize`; qed")
.send(task)
.await
.map_err(|_| RecoveryError::ChannelClosed);
if let Err(err) = send_result {
gum::warn!(
target: LOG_TARGET,
?err,
"Failed to send erasure coding task",
);
}
},
None => {
gum::debug!(
target: LOG_TARGET,
"Erasure task channel closed",
);
return Err(SubsystemError::with_origin("availability-recovery", RecoveryError::ChannelClosed))
}
}
}
v = ctx.recv().fuse() => {
match v? {
FromOrchestra::Signal(signal) => if handle_signal(
@@ -1135,6 +1275,7 @@ impl AvailabilityRecoverySubsystem {
response_sender,
&metrics,
&recovery_strategy,
erasure_task_tx.clone(),
).await {
gum::warn!(
target: LOG_TARGET,
@@ -1194,3 +1335,92 @@ impl AvailabilityRecoverySubsystem {
}
}
}
// A simple thread pool implementation using `spawn_blocking` threads.
//
// This is a stateless unit type: it only namespaces `build`, which spawns the workers
// and returns the sending halves of their task channels. Routing tasks to the workers
// is left to the caller.
struct ThreadPoolBuilder;
// Upper bound on the number of workers `ThreadPoolBuilder::build` will spawn; larger
// requested sizes are clamped to this value.
//
// `NonZeroUsize::new` returns an `Option`, so the value is unwrapped with a `match`.
// Because this is evaluated in a const context, the `panic!` arm would surface as a
// compile-time error if the literal were ever changed to zero.
const MAX_THREADS: NonZeroUsize = match NonZeroUsize::new(4) {
Some(max_threads) => max_threads,
None => panic!("MAX_THREADS must be non-zero"),
};
impl ThreadPoolBuilder {
    // Spawns a pool of erasure coding workers and returns one task sender per worker.
    //
    // The requested `size` is clamped so that 1 <= workers <= `MAX_THREADS`.
    //
    // For every worker a bounded `mpsc` channel is created. The receiving half is moved
    // into an `erasure_task_thread` future started via `spawn_blocking`, the sending half
    // is collected into the returned vector. The tasks are expected to be CPU intensive:
    // the larger the input (more or larger chunks/availability data), the more CPU cycles
    // will be spent.
    //
    // For example, for 32KB PoVs, we'd expect re-encode to eat as much as 90ms and 500ms
    // for 2.5MiB.
    //
    // Once a task is done, the worker replies over the `oneshot` sender embedded in the
    // task. A spawn failure is only logged; the sender for that worker is still returned.
    //
    // Distributing the work across the returned senders is up to the caller.
    #[overseer::contextbounds(AvailabilityRecovery, prefix = self::overseer)]
    pub fn build<Context>(
        size: NonZeroUsize,
        metrics: Metrics,
        ctx: &mut Context,
    ) -> Vec<futures::channel::mpsc::Sender<ErasureTask>> {
        // At least 1 worker, at most `MAX_THREADS`.
        let worker_count = size.min(MAX_THREADS).get();

        (0..worker_count)
            .map(|index| {
                let (task_tx, task_rx) = futures::channel::mpsc::channel(8);

                let spawned = ctx.spawn_blocking(
                    "erasure-task",
                    Box::pin(erasure_task_thread(metrics.clone(), task_rx)),
                );
                if let Err(error) = spawned {
                    gum::warn!(
                        target: LOG_TARGET,
                        err = ?error,
                        index,
                        "Failed to spawn a erasure task",
                    );
                }

                task_tx
            })
            .collect()
    }
}
/// Handles CPU intensive operations on a dedicated blocking thread.
///
/// Pulls `ErasureTask`s from `ingress` one at a time, runs the requested erasure coding
/// computation and reports the outcome over the `oneshot` sender carried by the task:
///
/// - `Reconstruct`: runs `reconstruct_v1` over the provided chunks and replies with its result.
/// - `Reencode`: checks the data with `reconstructed_data_matches_root` and replies with
///   `Some(data)` on a match, `None` otherwise.
///
/// The worker terminates once `ingress` is closed, i.e. when every sender has been dropped.
async fn erasure_task_thread(
    metrics: Metrics,
    mut ingress: futures::channel::mpsc::Receiver<ErasureTask>,
) {
    // A terminated receiver keeps yielding `None` without ever pending, so the loop must
    // stop at the first `None`; otherwise the blocking thread would spin forever.
    while let Some(task) = ingress.next().await {
        match task {
            ErasureTask::Reconstruct(n_validators, chunks, sender) => {
                // A failed send only means the requester dropped its receiver and no longer
                // cares about the result, so it is deliberately ignored.
                let _ = sender.send(polkadot_erasure_coding::reconstruct_v1(
                    n_validators,
                    chunks.values().map(|c| (&c.chunk[..], c.index.0 as usize)),
                ));
            },
            ErasureTask::Reencode(n_validators, root, available_data, sender) => {
                let maybe_data = if reconstructed_data_matches_root(
                    n_validators,
                    &root,
                    &available_data,
                    &metrics,
                ) {
                    Some(available_data)
                } else {
                    None
                };

                // Same as above: the requester may be gone already.
                let _ = sender.send(maybe_data);
            },
        }
    }

    gum::debug!(
        target: LOG_TARGET,
        "Erasure task channel closed. Node shutting down ?",
    );
}
@@ -1584,7 +1584,9 @@ fn invalid_local_chunk_is_ignored() {
fn parallel_request_calculation_works_as_expected() {
let num_validators = 100;
let threshold = recovery_threshold(num_validators).unwrap();
let mut phase = RequestChunksFromValidators::new(100);
let (erasure_task_tx, _erasure_task_rx) = futures::channel::mpsc::channel(16);
let mut phase = RequestChunksFromValidators::new(100, erasure_task_tx);
assert_eq!(phase.get_desired_request_count(threshold), threshold);
phase.error_count = 1;
phase.total_received_responses = 1;
@@ -1593,20 +1595,20 @@ fn parallel_request_calculation_works_as_expected() {
let dummy_chunk =
ErasureChunk { chunk: Vec::new(), index: ValidatorIndex(0), proof: Proof::dummy_proof() };
phase.received_chunks.insert(ValidatorIndex(0), dummy_chunk.clone());
phase.insert_chunk(ValidatorIndex(0), dummy_chunk.clone());
phase.total_received_responses = 2;
// With given error rate - still saturating:
assert_eq!(phase.get_desired_request_count(threshold), threshold);
for i in 1..9 {
phase.received_chunks.insert(ValidatorIndex(i), dummy_chunk.clone());
phase.insert_chunk(ValidatorIndex(i), dummy_chunk.clone());
}
phase.total_received_responses += 8;
// error rate: 1/10
// remaining chunks needed: threshold (34) - 9
// expected: 24 * (1+ 1/10) = (next greater integer) = 27
assert_eq!(phase.get_desired_request_count(threshold), 27);
phase.received_chunks.insert(ValidatorIndex(9), dummy_chunk.clone());
phase.insert_chunk(ValidatorIndex(9), dummy_chunk.clone());
phase.error_count = 0;
// With error count zero - we should fetch exactly as needed:
assert_eq!(phase.get_desired_request_count(threshold), threshold - phase.received_chunks.len());
assert_eq!(phase.get_desired_request_count(threshold), threshold - phase.chunk_count());
}
@@ -75,6 +75,9 @@ pub enum RecoveryError {
/// A requested chunk is unavailable.
Unavailable,
/// Erasure task channel closed, usually means node is shutting down.
ChannelClosed,
}
impl std::fmt::Display for RecoveryError {
@@ -82,6 +85,7 @@ impl std::fmt::Display for RecoveryError {
let msg = match self {
RecoveryError::Invalid => "Invalid",
RecoveryError::Unavailable => "Unavailable",
RecoveryError::ChannelClosed => "ChannelClosed",
};
write!(f, "{}", msg)