Make the code even more concurrent

This commit is contained in:
Omar Abdulla
2025-08-18 18:39:03 +03:00
parent d45b8da8e3
commit ad3d580df9
6 changed files with 130 additions and 107 deletions
Generated
+1
View File
@@ -4545,6 +4545,7 @@ dependencies = [
"alloy-primitives", "alloy-primitives",
"alloy-sol-types", "alloy-sol-types",
"anyhow", "anyhow",
"futures",
"regex", "regex",
"revive-common", "revive-common",
"revive-dt-common", "revive-dt-common",
+44 -41
View File
@@ -19,6 +19,7 @@ use alloy::{
rpc::types::{TransactionRequest, trace::geth::DiffMode}, rpc::types::{TransactionRequest, trace::geth::DiffMode},
}; };
use anyhow::Context; use anyhow::Context;
use futures::TryStreamExt;
use indexmap::IndexMap; use indexmap::IndexMap;
use revive_dt_format::traits::{ResolutionContext, ResolverApi}; use revive_dt_format::traits::{ResolutionContext, ResolverApi};
use semver::Version; use semver::Version;
@@ -31,6 +32,7 @@ use revive_dt_format::input::{
use revive_dt_format::metadata::{ContractIdent, ContractInstance, ContractPathAndIdent}; use revive_dt_format::metadata::{ContractIdent, ContractInstance, ContractPathAndIdent};
use revive_dt_format::{input::Step, metadata::Metadata}; use revive_dt_format::{input::Step, metadata::Metadata};
use revive_dt_node_interaction::EthereumNode; use revive_dt_node_interaction::EthereumNode;
use tokio::try_join;
use tracing::{Instrument, info, info_span, instrument}; use tracing::{Instrument, info, info_span, instrument};
use crate::Platform; use crate::Platform;
@@ -113,9 +115,11 @@ where
.handle_input_call_frame_tracing(&execution_receipt, node) .handle_input_call_frame_tracing(&execution_receipt, node)
.await?; .await?;
self.handle_input_variable_assignment(input, &tracing_result)?; self.handle_input_variable_assignment(input, &tracing_result)?;
self.handle_input_expectations(input, &execution_receipt, node, &tracing_result) let (_, (geth_trace, diff_mode)) = try_join!(
.await?; self.handle_input_expectations(input, &execution_receipt, node, &tracing_result),
self.handle_input_diff(execution_receipt, node).await self.handle_input_diff(&execution_receipt, node)
)?;
Ok((execution_receipt, geth_trace, diff_mode))
} }
#[instrument(level = "info", name = "Handling Balance Assertion", skip_all)] #[instrument(level = "info", name = "Handling Balance Assertion", skip_all)]
@@ -287,7 +291,7 @@ where
#[instrument(level = "info", skip_all)] #[instrument(level = "info", skip_all)]
async fn handle_input_expectations( async fn handle_input_expectations(
&mut self, &self,
input: &Input, input: &Input,
execution_receipt: &TransactionReceipt, execution_receipt: &TransactionReceipt,
resolver: &impl ResolverApi, resolver: &impl ResolverApi,
@@ -321,25 +325,25 @@ where
} }
} }
for expectation in expectations.iter() { futures::stream::iter(expectations.into_iter().map(Ok))
self.handle_input_expectation_item( .try_for_each_concurrent(None, |expectation| async move {
execution_receipt, self.handle_input_expectation_item(
resolver, execution_receipt,
expectation, resolver,
tracing_result, expectation,
) tracing_result,
.await?; )
} .await
})
Ok(()) .await
} }
#[instrument(level = "info", skip_all)] #[instrument(level = "info", skip_all)]
async fn handle_input_expectation_item( async fn handle_input_expectation_item(
&mut self, &self,
execution_receipt: &TransactionReceipt, execution_receipt: &TransactionReceipt,
resolver: &impl ResolverApi, resolver: &impl ResolverApi,
expectation: &ExpectedOutput, expectation: ExpectedOutput,
tracing_result: &CallFrame, tracing_result: &CallFrame,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
if let Some(ref version_requirement) = expectation.compiler_version { if let Some(ref version_requirement) = expectation.compiler_version {
@@ -479,10 +483,10 @@ where
#[instrument(level = "info", skip_all)] #[instrument(level = "info", skip_all)]
async fn handle_input_diff( async fn handle_input_diff(
&mut self, &self,
execution_receipt: TransactionReceipt, execution_receipt: &TransactionReceipt,
node: &T::Blockchain, node: &T::Blockchain,
) -> anyhow::Result<(TransactionReceipt, GethTrace, DiffMode)> { ) -> anyhow::Result<(GethTrace, DiffMode)> {
let trace_options = GethDebugTracingOptions::prestate_tracer(PreStateConfig { let trace_options = GethDebugTracingOptions::prestate_tracer(PreStateConfig {
diff_mode: Some(true), diff_mode: Some(true),
disable_code: None, disable_code: None,
@@ -490,11 +494,11 @@ where
}); });
let trace = node let trace = node
.trace_transaction(&execution_receipt, trace_options) .trace_transaction(execution_receipt, trace_options)
.await?; .await?;
let diff = node.state_diff(&execution_receipt).await?; let diff = node.state_diff(execution_receipt).await?;
Ok((execution_receipt, trace, diff)) Ok((trace, diff))
} }
#[instrument(level = "info", skip_all)] #[instrument(level = "info", skip_all)]
@@ -772,24 +776,23 @@ where
.enumerate() .enumerate()
.map(|(idx, v)| (StepIdx::new(idx), v)) .map(|(idx, v)| (StepIdx::new(idx), v))
{ {
let leader_step_output = self let (leader_step_output, follower_step_output) = try_join!(
.leader_state self.leader_state
.handle_step(self.metadata, &step, self.leader_node) .handle_step(self.metadata, &step, self.leader_node)
.instrument(info_span!( .instrument(info_span!(
"Handling Step", "Handling Step",
%step_idx, %step_idx,
target = "Leader", target = "Leader",
)) )),
.await?; self.follower_state
let follower_step_output = self .handle_step(self.metadata, &step, self.follower_node)
.follower_state .instrument(info_span!(
.handle_step(self.metadata, &step, self.follower_node) "Handling Step",
.instrument(info_span!( %step_idx,
"Handling Step", target = "Follower",
%step_idx, ))
target = "Follower", )?;
))
.await?;
match (leader_step_output, follower_step_output) { match (leader_step_output, follower_step_output) {
(StepOutput::FunctionCall(..), StepOutput::FunctionCall(..)) => { (StepOutput::FunctionCall(..), StepOutput::FunctionCall(..)) => {
// TODO: We need to actually work out how/if we will compare the diff between // TODO: We need to actually work out how/if we will compare the diff between
+45 -35
View File
@@ -19,13 +19,13 @@ use futures::{Stream, StreamExt};
use indexmap::IndexMap; use indexmap::IndexMap;
use revive_dt_node_interaction::EthereumNode; use revive_dt_node_interaction::EthereumNode;
use temp_dir::TempDir; use temp_dir::TempDir;
use tokio::sync::mpsc; use tokio::{sync::mpsc, try_join};
use tracing::{debug, info, info_span, instrument}; use tracing::{debug, info, info_span, instrument};
use tracing_appender::non_blocking::WorkerGuard; use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::{EnvFilter, FmtSubscriber}; use tracing_subscriber::{EnvFilter, FmtSubscriber};
use revive_dt_common::types::Mode; use revive_dt_common::types::Mode;
use revive_dt_compiler::SolidityCompiler; use revive_dt_compiler::{CompilerOutput, SolidityCompiler};
use revive_dt_config::*; use revive_dt_config::*;
use revive_dt_core::{ use revive_dt_core::{
Geth, Kitchensink, Platform, Geth, Kitchensink, Platform,
@@ -500,16 +500,23 @@ where
L::Blockchain: revive_dt_node::Node + Send + Sync + 'static, L::Blockchain: revive_dt_node::Node + Send + Sync + 'static,
F::Blockchain: revive_dt_node::Node + Send + Sync + 'static, F::Blockchain: revive_dt_node::Node + Send + Sync + 'static,
{ {
let leader_pre_link_contracts = cached_compiler let (
.compile_contracts::<L>(metadata, metadata_file_path, &mode, config, None) (
.await? CompilerOutput {
.0 contracts: leader_pre_link_contracts,
.contracts; },
let follower_pre_link_contracts = cached_compiler _,
.compile_contracts::<F>(metadata, metadata_file_path, &mode, config, None) ),
.await? (
.0 CompilerOutput {
.contracts; contracts: follower_pre_link_contracts,
},
_,
),
) = try_join!(
cached_compiler.compile_contracts::<L>(metadata, metadata_file_path, &mode, config, None),
cached_compiler.compile_contracts::<F>(metadata, metadata_file_path, &mode, config, None)
)?;
let mut leader_deployed_libraries = None::<HashMap<_, _>>; let mut leader_deployed_libraries = None::<HashMap<_, _>>;
let mut follower_deployed_libraries = None::<HashMap<_, _>>; let mut follower_deployed_libraries = None::<HashMap<_, _>>;
@@ -573,18 +580,10 @@ where
follower_code, follower_code,
); );
let leader_receipt = match leader_node.execute_transaction(leader_tx).await { let (leader_receipt, follower_receipt) = try_join!(
Ok(receipt) => receipt, leader_node.execute_transaction(leader_tx),
Err(error) => { follower_node.execute_transaction(follower_tx)
return Err(error); )?;
}
};
let follower_receipt = match follower_node.execute_transaction(follower_tx).await {
Ok(receipt) => receipt,
Err(error) => {
return Err(error);
}
};
debug!( debug!(
?library_instance, ?library_instance,
@@ -622,33 +621,44 @@ where
); );
} }
let (leader_post_link_contracts, leader_compiler_version) = cached_compiler let (
.compile_contracts::<L>( (
CompilerOutput {
contracts: leader_post_link_contracts,
},
leader_compiler_version,
),
(
CompilerOutput {
contracts: follower_post_link_contracts,
},
follower_compiler_version,
),
) = try_join!(
cached_compiler.compile_contracts::<L>(
metadata, metadata,
metadata_file_path, metadata_file_path,
&mode, &mode,
config, config,
leader_deployed_libraries.as_ref(), leader_deployed_libraries.as_ref()
) ),
.await?; cached_compiler.compile_contracts::<F>(
let (follower_post_link_contracts, follower_compiler_version) = cached_compiler
.compile_contracts::<F>(
metadata, metadata,
metadata_file_path, metadata_file_path,
&mode, &mode,
config, config,
follower_deployed_libraries.as_ref(), follower_deployed_libraries.as_ref()
) )
.await?; )?;
let leader_state = CaseState::<L>::new( let leader_state = CaseState::<L>::new(
leader_compiler_version, leader_compiler_version,
leader_post_link_contracts.contracts, leader_post_link_contracts,
leader_deployed_libraries.unwrap_or_default(), leader_deployed_libraries.unwrap_or_default(),
); );
let follower_state = CaseState::<F>::new( let follower_state = CaseState::<F>::new(
follower_compiler_version, follower_compiler_version,
follower_post_link_contracts.contracts, follower_post_link_contracts,
follower_deployed_libraries.unwrap_or_default(), follower_deployed_libraries.unwrap_or_default(),
); );
+1
View File
@@ -17,6 +17,7 @@ alloy = { workspace = true }
alloy-primitives = { workspace = true } alloy-primitives = { workspace = true }
alloy-sol-types = { workspace = true } alloy-sol-types = { workspace = true }
anyhow = { workspace = true } anyhow = { workspace = true }
futures = { workspace = true }
regex = { workspace = true } regex = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
semver = { workspace = true } semver = { workspace = true }
+33 -28
View File
@@ -9,6 +9,7 @@ use alloy::{
}; };
use alloy_primitives::{FixedBytes, utils::parse_units}; use alloy_primitives::{FixedBytes, utils::parse_units};
use anyhow::Context; use anyhow::Context;
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, stream};
use semver::VersionReq; use semver::VersionReq;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -427,14 +428,18 @@ impl Calldata {
buffer.extend_from_slice(bytes); buffer.extend_from_slice(bytes);
} }
Calldata::Compound(items) => { Calldata::Compound(items) => {
for (arg_idx, arg) in items.iter().enumerate() { let resolved = stream::iter(items.iter().enumerate())
buffer.extend( .map(|(arg_idx, arg)| async move {
arg.resolve(resolver, context) arg.resolve(resolver, context)
.instrument(info_span!("Resolving argument", %arg, arg_idx)) .instrument(info_span!("Resolving argument", %arg, arg_idx))
.await? .map_ok(|value| value.to_be_bytes::<32>())
.to_be_bytes::<32>(), .await
); })
} .buffered(0xFF)
.try_collect::<Vec<_>>()
.await?;
buffer.extend(resolved.into_iter().flatten());
} }
}; };
Ok(()) Ok(())
@@ -457,30 +462,30 @@ impl Calldata {
match self { match self {
Calldata::Single(calldata) => Ok(calldata == other), Calldata::Single(calldata) => Ok(calldata == other),
Calldata::Compound(items) => { Calldata::Compound(items) => {
// Chunking the "other" calldata into 32 byte chunks since each stream::iter(items.iter().zip(other.chunks(32)))
// one of the items in the compound calldata represents 32 bytes .map(|(this, other)| async move {
for (this, other) in items.iter().zip(other.chunks(32)) { // The matterlabs format supports wildcards and therefore we
// The matterlabs format supports wildcards and therefore we // also need to support them.
// also need to support them. if this.as_ref() == "*" {
if this.as_ref() == "*" { return Ok::<_, anyhow::Error>(true);
continue; }
}
let other = if other.len() < 32 { let other = if other.len() < 32 {
let mut vec = other.to_vec(); let mut vec = other.to_vec();
vec.resize(32, 0); vec.resize(32, 0);
std::borrow::Cow::Owned(vec) std::borrow::Cow::Owned(vec)
} else { } else {
std::borrow::Cow::Borrowed(other) std::borrow::Cow::Borrowed(other)
}; };
let this = this.resolve(resolver, context).await?; let this = this.resolve(resolver, context).await?;
let other = U256::from_be_slice(&other); let other = U256::from_be_slice(&other);
if this != other { Ok(this == other)
return Ok(false); })
} .buffered(0xFF)
} .all(|v| async move { v.is_ok_and(|v| v) })
Ok(true) .map(Ok)
.await
} }
} }
} }
+6 -3
View File
@@ -35,7 +35,10 @@ use anyhow::Context;
use revive_common::EVMVersion; use revive_common::EVMVersion;
use tracing::{Instrument, instrument}; use tracing::{Instrument, instrument};
use revive_dt_common::{fs::clear_directory, futures::poll}; use revive_dt_common::{
fs::clear_directory,
futures::{PollingWaitBehavior, poll},
};
use revive_dt_config::Arguments; use revive_dt_config::Arguments;
use revive_dt_format::traits::ResolverApi; use revive_dt_format::traits::ResolverApi;
use revive_dt_node_interaction::EthereumNode; use revive_dt_node_interaction::EthereumNode;
@@ -301,7 +304,7 @@ impl EthereumNode for GethNode {
let provider = Arc::new(provider); let provider = Arc::new(provider);
poll( poll(
Self::RECEIPT_POLLING_DURATION, Self::RECEIPT_POLLING_DURATION,
Default::default(), PollingWaitBehavior::Constant(Duration::from_millis(200)),
move || { move || {
let provider = provider.clone(); let provider = provider.clone();
async move { async move {
@@ -335,7 +338,7 @@ impl EthereumNode for GethNode {
let provider = Arc::new(self.provider().await?); let provider = Arc::new(self.provider().await?);
poll( poll(
Self::TRACE_POLLING_DURATION, Self::TRACE_POLLING_DURATION,
Default::default(), PollingWaitBehavior::Constant(Duration::from_millis(200)),
move || { move || {
let provider = provider.clone(); let provider = provider.clone();
let trace_options = trace_options.clone(); let trace_options = trace_options.clone();