diff --git a/polkadot/node/core/candidate-validation/src/lib.rs b/polkadot/node/core/candidate-validation/src/lib.rs index 2862d4c3f3..9d2852efcc 100644 --- a/polkadot/node/core/candidate-validation/src/lib.rs +++ b/polkadot/node/core/candidate-validation/src/lib.rs @@ -35,7 +35,7 @@ use polkadot_node_subsystem::{ CandidateValidationMessage, RuntimeApiMessage, RuntimeApiRequest, ValidationFailed, }, overseer, FromOverseer, OverseerSignal, SpawnedSubsystem, SubsystemContext, SubsystemError, - SubsystemResult, + SubsystemResult, SubsystemSender, }; use polkadot_node_subsystem_util::metrics::{self, prometheus}; use polkadot_parachain::primitives::{ValidationParams, ValidationResult as WasmValidationResult}; @@ -120,7 +120,7 @@ where Context: SubsystemContext, Context: overseer::SubsystemContext, { - let (mut validation_host, task) = polkadot_node_core_pvf::start( + let (validation_host, task) = polkadot_node_core_pvf::start( polkadot_node_core_pvf::Config::new(cache_path, program_path), pvf_metrics, ); @@ -137,24 +137,28 @@ where pov, response_sender, ) => { - let _timer = metrics.time_validate_from_chain_state(); + let bg = { + let mut sender = ctx.sender().clone(); + let metrics = metrics.clone(); + let validation_host = validation_host.clone(); - let res = spawn_validate_from_chain_state( - &mut ctx, - &mut validation_host, - descriptor, - pov, - &metrics, - ) - .await; + async move { + let _timer = metrics.time_validate_from_chain_state(); + let res = validate_from_chain_state( + &mut sender, + validation_host, + descriptor, + pov, + &metrics, + ) + .await; - match res { - Ok(x) => { - metrics.on_validation_event(&x); - let _ = response_sender.send(x); - }, - Err(e) => return Err(e), - } + metrics.on_validation_event(&res); + let _ = response_sender.send(res); + } + }; + + ctx.spawn("validate-from-chain-state", bg.boxed())?; }, CandidateValidationMessage::ValidateFromExhaustive( persisted_validation_data, @@ -163,50 +167,68 @@ where pov, response_sender, ) => { - let _timer = metrics.time_validate_from_exhaustive(); + let bg = { + let metrics = metrics.clone(); + let validation_host = validation_host.clone(); - let res = validate_candidate_exhaustive( - &mut validation_host, - persisted_validation_data, - validation_code, - descriptor, - pov, - &metrics, - ) - .await; + async move { + let _timer = metrics.time_validate_from_exhaustive(); + let res = validate_candidate_exhaustive( + validation_host, + persisted_validation_data, + validation_code, + descriptor, + pov, + &metrics, + ) + .await; - match res { - Ok(x) => { - metrics.on_validation_event(&x); + metrics.on_validation_event(&res); + let _ = response_sender.send(res); + } + }; - if let Err(_e) = response_sender.send(x) { - tracing::warn!( - target: LOG_TARGET, - "Requester of candidate validation dropped", - ) - } - }, - Err(e) => return Err(e), - } + ctx.spawn("validate-from-exhaustive", bg.boxed())?; }, }, } } } -async fn runtime_api_request( - ctx: &mut Context, +struct RuntimeRequestFailed; + +async fn runtime_api_request( + sender: &mut Sender, relay_parent: Hash, request: RuntimeApiRequest, receiver: oneshot::Receiver>, -) -> SubsystemResult> +) -> Result where - Context: SubsystemContext, - Context: overseer::SubsystemContext, + Sender: SubsystemSender, { - ctx.send_message(RuntimeApiMessage::Request(relay_parent, request)).await; + sender + .send_message(RuntimeApiMessage::Request(relay_parent, request).into()) + .await; - receiver.await.map_err(Into::into) + receiver + .await + .map_err(|_| { + tracing::debug!(target: LOG_TARGET, ?relay_parent, "Runtime API request dropped"); + + RuntimeRequestFailed + }) + .and_then(|res| { + res.map_err(|e| { + tracing::debug!( + target: LOG_TARGET, + ?relay_parent, + err = ?e, + "Runtime API request internal error" + ); + + RuntimeRequestFailed + }) + }) } #[derive(Debug)] @@ -216,61 +238,57 @@ enum AssumptionCheckOutcome { BadRequest, } -async fn check_assumption_validation_data( - ctx: &mut Context, +async fn check_assumption_validation_data( + sender: &mut Sender, descriptor: &CandidateDescriptor, assumption: OccupiedCoreAssumption, -) -> SubsystemResult +) -> AssumptionCheckOutcome where - Context: SubsystemContext, - Context: overseer::SubsystemContext, + Sender: SubsystemSender, { let validation_data = { let (tx, rx) = oneshot::channel(); let d = runtime_api_request( - ctx, + sender, descriptor.relay_parent, RuntimeApiRequest::PersistedValidationData(descriptor.para_id, assumption, tx), rx, ) - .await?; + .await; match d { - Ok(None) | Err(_) => return Ok(AssumptionCheckOutcome::BadRequest), + Ok(None) | Err(RuntimeRequestFailed) => return AssumptionCheckOutcome::BadRequest, Ok(Some(d)) => d, } }; let persisted_validation_data_hash = validation_data.hash(); - SubsystemResult::Ok( - if descriptor.persisted_validation_data_hash == persisted_validation_data_hash { - let (code_tx, code_rx) = oneshot::channel(); - let validation_code = runtime_api_request( - ctx, - descriptor.relay_parent, - RuntimeApiRequest::ValidationCode(descriptor.para_id, assumption, code_tx), - code_rx, - ) - .await?; + if descriptor.persisted_validation_data_hash == persisted_validation_data_hash { + let (code_tx, code_rx) = oneshot::channel(); + let validation_code = runtime_api_request( + sender, + descriptor.relay_parent, + RuntimeApiRequest::ValidationCode(descriptor.para_id, assumption, code_tx), + code_rx, + ) + .await; - match validation_code { - Ok(None) | Err(_) => AssumptionCheckOutcome::BadRequest, - Ok(Some(v)) => AssumptionCheckOutcome::Matches(validation_data, v), - } - } else { - AssumptionCheckOutcome::DoesNotMatch - }, - ) + match validation_code { + Ok(None) | Err(RuntimeRequestFailed) => AssumptionCheckOutcome::BadRequest, + Ok(Some(v)) => AssumptionCheckOutcome::Matches(validation_data, v), + } + } else { + AssumptionCheckOutcome::DoesNotMatch + } } -async fn find_assumed_validation_data( - ctx: &mut Context, +async fn find_assumed_validation_data( + sender: &mut Sender, descriptor: &CandidateDescriptor, -) -> SubsystemResult +) -> AssumptionCheckOutcome where - Context: SubsystemContext, - Context: overseer::SubsystemContext, + Sender: SubsystemSender, { // The candidate descriptor has a `persisted_validation_data_hash` which corresponds to // one of up to two possible values that we can derive from the state of the @@ -287,41 +305,40 @@ where // Consider running these checks in parallel to reduce validation latency. for assumption in ASSUMPTIONS { - let outcome = check_assumption_validation_data(ctx, descriptor, *assumption).await?; + let outcome = check_assumption_validation_data(sender, descriptor, *assumption).await; match outcome { - AssumptionCheckOutcome::Matches(_, _) => return Ok(outcome), - AssumptionCheckOutcome::BadRequest => return Ok(outcome), + AssumptionCheckOutcome::Matches(_, _) => return outcome, + AssumptionCheckOutcome::BadRequest => return outcome, AssumptionCheckOutcome::DoesNotMatch => continue, } } - Ok(AssumptionCheckOutcome::DoesNotMatch) + AssumptionCheckOutcome::DoesNotMatch } -async fn spawn_validate_from_chain_state( - ctx: &mut Context, - validation_host: &mut ValidationHost, +async fn validate_from_chain_state( + sender: &mut Sender, + validation_host: ValidationHost, descriptor: CandidateDescriptor, pov: Arc, metrics: &Metrics, -) -> SubsystemResult> +) -> Result where - Context: SubsystemContext, - Context: overseer::SubsystemContext, + Sender: SubsystemSender, { let (validation_data, validation_code) = - match find_assumed_validation_data(ctx, &descriptor).await? { + match find_assumed_validation_data(sender, &descriptor).await { AssumptionCheckOutcome::Matches(validation_data, validation_code) => (validation_data, validation_code), AssumptionCheckOutcome::DoesNotMatch => { // If neither the assumption of the occupied core having the para included or the assumption // of the occupied core timing out are valid, then the persisted_validation_data_hash in the descriptor // is not based on the relay parent and is thus invalid. - return Ok(Ok(ValidationResult::Invalid(InvalidCandidate::BadParent))) + return Ok(ValidationResult::Invalid(InvalidCandidate::BadParent)) }, AssumptionCheckOutcome::BadRequest => - return Ok(Err(ValidationFailed("Assumption Check: Bad request".into()))), + return Err(ValidationFailed("Assumption Check: Bad request".into())), }; let validation_result = validate_candidate_exhaustive( @@ -334,20 +351,20 @@ where ) .await; - if let Ok(Ok(ValidationResult::Valid(ref outputs, _))) = validation_result { + if let Ok(ValidationResult::Valid(ref outputs, _)) = validation_result { let (tx, rx) = oneshot::channel(); match runtime_api_request( - ctx, + sender, descriptor.relay_parent, RuntimeApiRequest::CheckValidationOutputs(descriptor.para_id, outputs.clone(), tx), rx, ) - .await? + .await { Ok(true) => {}, - Ok(false) => return Ok(Ok(ValidationResult::Invalid(InvalidCandidate::InvalidOutputs))), - Err(_) => - return Ok(Err(ValidationFailed("Check Validation Outputs: Bad request".into()))), + Ok(false) => return Ok(ValidationResult::Invalid(InvalidCandidate::InvalidOutputs)), + Err(RuntimeRequestFailed) => + return Err(ValidationFailed("Check Validation Outputs: Bad request".into())), } } @@ -361,7 +378,7 @@ async fn validate_candidate_exhaustive( descriptor: CandidateDescriptor, pov: Arc, metrics: &Metrics, -) -> SubsystemResult> { +) -> Result { let _timer = metrics.time_validate_candidate_exhaustive(); let validation_code_hash = validation_code.hash(); @@ -378,7 +395,7 @@ async fn validate_candidate_exhaustive( &*pov, &validation_code_hash, ) { - return Ok(Ok(ValidationResult::Invalid(e))) + return Ok(ValidationResult::Invalid(e)) } let raw_validation_code = match sp_maybe_compressed_blob::decompress( @@ -390,7 +407,7 @@ async fn validate_candidate_exhaustive( tracing::debug!(target: LOG_TARGET, err=?e, "Invalid validation code"); // If the validation code is invalid, the candidate certainly is. - return Ok(Ok(ValidationResult::Invalid(InvalidCandidate::CodeDecompressionFailure))) + return Ok(ValidationResult::Invalid(InvalidCandidate::CodeDecompressionFailure)) }, }; @@ -401,7 +418,7 @@ async fn validate_candidate_exhaustive( tracing::debug!(target: LOG_TARGET, err=?e, "Invalid PoV code"); // If the PoV is invalid, the candidate certainly is. - return Ok(Ok(ValidationResult::Invalid(InvalidCandidate::PoVDecompressionFailure))) + return Ok(ValidationResult::Invalid(InvalidCandidate::PoVDecompressionFailure)) }, }; @@ -424,7 +441,7 @@ async fn validate_candidate_exhaustive( ); } - let result = match result { + match result { Err(ValidationError::InternalError(e)) => Err(ValidationFailed(e)), Err(ValidationError::InvalidCandidate(WasmInvalidCandidate::HardTimeout)) => @@ -450,9 +467,7 @@ async fn validate_candidate_exhaustive( }; Ok(ValidationResult::Valid(outputs, persisted_validation_data)) }, - }; - - Ok(result) + } } #[async_trait] @@ -465,7 +480,7 @@ trait ValidationBackend { } #[async_trait] -impl ValidationBackend for &'_ mut ValidationHost { +impl ValidationBackend for ValidationHost { async fn validate_candidate( &mut self, raw_validation_code: Vec, diff --git a/polkadot/node/core/candidate-validation/src/tests.rs b/polkadot/node/core/candidate-validation/src/tests.rs index 15314c35ed..f067ead6cd 100644 --- a/polkadot/node/core/candidate-validation/src/tests.rs +++ b/polkadot/node/core/candidate-validation/src/tests.rs @@ -19,6 +19,7 @@ use assert_matches::assert_matches; use futures::executor; use polkadot_node_subsystem::messages::AllMessages; use polkadot_node_subsystem_test_helpers as test_helpers; +use polkadot_node_subsystem_util::reexports::SubsystemContext; use polkadot_primitives::v1::{HeadData, UpwardMessage}; use sp_core::testing::TaskExecutor; use sp_keyring::Sr25519Keyring; @@ -52,11 +53,15 @@ fn correctly_checks_included_assumption() { candidate.para_id = para_id; let pool = TaskExecutor::new(); - let (mut ctx, mut ctx_handle) = test_helpers::make_subsystem_context(pool.clone()); + let (mut ctx, mut ctx_handle) = + test_helpers::make_subsystem_context::(pool.clone()); - let (check_fut, check_result) = - check_assumption_validation_data(&mut ctx, &candidate, OccupiedCoreAssumption::Included) - .remote_handle(); + let (check_fut, check_result) = check_assumption_validation_data( + ctx.sender(), + &candidate, + OccupiedCoreAssumption::Included, + ) + .remote_handle(); let test_fut = async move { assert_matches!( @@ -89,7 +94,7 @@ fn correctly_checks_included_assumption() { } ); - assert_matches!(check_result.await.unwrap(), AssumptionCheckOutcome::Matches(o, v) => { + assert_matches!(check_result.await, AssumptionCheckOutcome::Matches(o, v) => { assert_eq!(o, validation_data); assert_eq!(v, validation_code); }); @@ -114,11 +119,15 @@ fn correctly_checks_timed_out_assumption() { candidate.para_id = para_id; let pool = TaskExecutor::new(); - let (mut ctx, mut ctx_handle) = test_helpers::make_subsystem_context(pool.clone()); + let (mut ctx, mut ctx_handle) = + test_helpers::make_subsystem_context::(pool.clone()); - let (check_fut, check_result) = - check_assumption_validation_data(&mut ctx, &candidate, OccupiedCoreAssumption::TimedOut) - .remote_handle(); + let (check_fut, check_result) = check_assumption_validation_data( + ctx.sender(), + &candidate, + OccupiedCoreAssumption::TimedOut, + ) + .remote_handle(); let test_fut = async move { assert_matches!( @@ -151,7 +160,7 @@ fn correctly_checks_timed_out_assumption() { } ); - assert_matches!(check_result.await.unwrap(), AssumptionCheckOutcome::Matches(o, v) => { + assert_matches!(check_result.await, AssumptionCheckOutcome::Matches(o, v) => { assert_eq!(o, validation_data); assert_eq!(v, validation_code); }); @@ -174,11 +183,15 @@ fn check_is_bad_request_if_no_validation_data() { candidate.para_id = para_id; let pool = TaskExecutor::new(); - let (mut ctx, mut ctx_handle) = test_helpers::make_subsystem_context(pool.clone()); + let (mut ctx, mut ctx_handle) = + test_helpers::make_subsystem_context::(pool.clone()); - let (check_fut, check_result) = - check_assumption_validation_data(&mut ctx, &candidate, OccupiedCoreAssumption::Included) - .remote_handle(); + let (check_fut, check_result) = check_assumption_validation_data( + ctx.sender(), + &candidate, + OccupiedCoreAssumption::Included, + ) + .remote_handle(); let test_fut = async move { assert_matches!( @@ -198,7 +211,7 @@ fn check_is_bad_request_if_no_validation_data() { } ); - assert_matches!(check_result.await.unwrap(), AssumptionCheckOutcome::BadRequest); + assert_matches!(check_result.await, AssumptionCheckOutcome::BadRequest); }; let test_fut = future::join(test_fut, check_fut); @@ -218,11 +231,15 @@ fn check_is_bad_request_if_no_validation_code() { candidate.para_id = para_id; let pool = TaskExecutor::new(); - let (mut ctx, mut ctx_handle) = test_helpers::make_subsystem_context(pool.clone()); + let (mut ctx, mut ctx_handle) = + test_helpers::make_subsystem_context::(pool.clone()); - let (check_fut, check_result) = - check_assumption_validation_data(&mut ctx, &candidate, OccupiedCoreAssumption::TimedOut) - .remote_handle(); + let (check_fut, check_result) = check_assumption_validation_data( + ctx.sender(), + &candidate, + OccupiedCoreAssumption::TimedOut, + ) + .remote_handle(); let test_fut = async move { assert_matches!( @@ -255,7 +272,7 @@ fn check_is_bad_request_if_no_validation_code() { } ); - assert_matches!(check_result.await.unwrap(), AssumptionCheckOutcome::BadRequest); + assert_matches!(check_result.await, AssumptionCheckOutcome::BadRequest); }; let test_fut = future::join(test_fut, check_fut); @@ -274,11 +291,15 @@ fn check_does_not_match() { candidate.para_id = para_id; let pool = TaskExecutor::new(); - let (mut ctx, mut ctx_handle) = test_helpers::make_subsystem_context(pool.clone()); + let (mut ctx, mut ctx_handle) = + test_helpers::make_subsystem_context::(pool.clone()); - let (check_fut, check_result) = - check_assumption_validation_data(&mut ctx, &candidate, OccupiedCoreAssumption::Included) - .remote_handle(); + let (check_fut, check_result) = check_assumption_validation_data( + ctx.sender(), + &candidate, + OccupiedCoreAssumption::Included, + ) + .remote_handle(); let test_fut = async move { assert_matches!( @@ -298,7 +319,7 @@ fn check_does_not_match() { } ); - assert_matches!(check_result.await.unwrap(), AssumptionCheckOutcome::DoesNotMatch); + assert_matches!(check_result.await, AssumptionCheckOutcome::DoesNotMatch); }; let test_fut = future::join(test_fut, check_fut); @@ -365,7 +386,6 @@ fn candidate_validation_ok_is_ok() { Arc::new(pov), &Default::default(), )) - .unwrap() .unwrap(); assert_matches!(v, ValidationResult::Valid(outputs, used_validation_data) => { @@ -408,7 +428,6 @@ fn candidate_validation_bad_return_is_invalid() { Arc::new(pov), &Default::default(), )) - .unwrap() .unwrap(); assert_matches!(v, ValidationResult::Invalid(InvalidCandidate::ExecutionError(_))); @@ -443,8 +462,7 @@ fn candidate_validation_timeout_is_internal_error() { descriptor, Arc::new(pov), &Default::default(), - )) - .unwrap(); + )); assert_matches!(v, Ok(ValidationResult::Invalid(InvalidCandidate::Timeout))); } @@ -479,7 +497,6 @@ fn candidate_validation_code_mismatch_is_invalid() { Arc::new(pov), &Default::default(), )) - .unwrap() .unwrap(); assert_matches!(v, ValidationResult::Invalid(InvalidCandidate::CodeHashMismatch)); @@ -518,8 +535,7 @@ fn compressed_code_works() { descriptor, Arc::new(pov), &Default::default(), - )) - .unwrap(); + )); assert_matches!(v, Ok(ValidationResult::Valid(_, _))); } @@ -558,8 +574,7 @@ fn code_decompression_failure_is_invalid() { descriptor, Arc::new(pov), &Default::default(), - )) - .unwrap(); + )); assert_matches!(v, Ok(ValidationResult::Invalid(InvalidCandidate::CodeDecompressionFailure))); } @@ -599,8 +614,7 @@ fn pov_decompression_failure_is_invalid() { descriptor, Arc::new(pov), &Default::default(), - )) - .unwrap(); + )); assert_matches!(v, Ok(ValidationResult::Invalid(InvalidCandidate::PoVDecompressionFailure))); }