From ad3d580df9c466e68b528b3b1bc8b8cdab8a40ae Mon Sep 17 00:00:00 2001 From: Omar Abdulla Date: Mon, 18 Aug 2025 18:39:03 +0300 Subject: [PATCH] Make the code even more concurrent --- Cargo.lock | 1 + crates/core/src/driver/mod.rs | 85 ++++++++++++++++++----------------- crates/core/src/main.rs | 80 ++++++++++++++++++--------------- crates/format/Cargo.toml | 1 + crates/format/src/input.rs | 61 +++++++++++++------------ crates/node/src/geth.rs | 9 ++-- 6 files changed, 130 insertions(+), 107 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 099e0ae..a04261a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4545,6 +4545,7 @@ dependencies = [ "alloy-primitives", "alloy-sol-types", "anyhow", + "futures", "regex", "revive-common", "revive-dt-common", diff --git a/crates/core/src/driver/mod.rs b/crates/core/src/driver/mod.rs index bea8c6b..4912b9a 100644 --- a/crates/core/src/driver/mod.rs +++ b/crates/core/src/driver/mod.rs @@ -19,6 +19,7 @@ use alloy::{ rpc::types::{TransactionRequest, trace::geth::DiffMode}, }; use anyhow::Context; +use futures::TryStreamExt; use indexmap::IndexMap; use revive_dt_format::traits::{ResolutionContext, ResolverApi}; use semver::Version; @@ -31,6 +32,7 @@ use revive_dt_format::input::{ use revive_dt_format::metadata::{ContractIdent, ContractInstance, ContractPathAndIdent}; use revive_dt_format::{input::Step, metadata::Metadata}; use revive_dt_node_interaction::EthereumNode; +use tokio::try_join; use tracing::{Instrument, info, info_span, instrument}; use crate::Platform; @@ -113,9 +115,11 @@ where .handle_input_call_frame_tracing(&execution_receipt, node) .await?; self.handle_input_variable_assignment(input, &tracing_result)?; - self.handle_input_expectations(input, &execution_receipt, node, &tracing_result) - .await?; - self.handle_input_diff(execution_receipt, node).await + let (_, (geth_trace, diff_mode)) = try_join!( + self.handle_input_expectations(input, &execution_receipt, node, &tracing_result), + self.handle_input_diff(&execution_receipt, node) + )?; + Ok((execution_receipt, geth_trace, diff_mode)) } #[instrument(level = "info", name = "Handling Balance Assertion", skip_all)] @@ -287,7 +291,7 @@ where #[instrument(level = "info", skip_all)] async fn handle_input_expectations( - &mut self, + &self, input: &Input, execution_receipt: &TransactionReceipt, resolver: &impl ResolverApi, @@ -321,25 +325,25 @@ where } } - for expectation in expectations.iter() { - self.handle_input_expectation_item( - execution_receipt, - resolver, - expectation, - tracing_result, - ) - .await?; - } - - Ok(()) + futures::stream::iter(expectations.into_iter().map(Ok)) + .try_for_each_concurrent(None, |expectation| async move { + self.handle_input_expectation_item( + execution_receipt, + resolver, + expectation, + tracing_result, + ) + .await + }) + .await } #[instrument(level = "info", skip_all)] async fn handle_input_expectation_item( - &mut self, + &self, execution_receipt: &TransactionReceipt, resolver: &impl ResolverApi, - expectation: &ExpectedOutput, + expectation: ExpectedOutput, tracing_result: &CallFrame, ) -> anyhow::Result<()> { if let Some(ref version_requirement) = expectation.compiler_version { @@ -479,10 +483,10 @@ where #[instrument(level = "info", skip_all)] async fn handle_input_diff( - &mut self, - execution_receipt: TransactionReceipt, + &self, + execution_receipt: &TransactionReceipt, node: &T::Blockchain, - ) -> anyhow::Result<(TransactionReceipt, GethTrace, DiffMode)> { + ) -> anyhow::Result<(GethTrace, DiffMode)> { let trace_options = GethDebugTracingOptions::prestate_tracer(PreStateConfig { diff_mode: Some(true), disable_code: None, @@ -490,11 +494,11 @@ where }); let trace = node - .trace_transaction(&execution_receipt, trace_options) + .trace_transaction(execution_receipt, trace_options) .await?; - let diff = node.state_diff(&execution_receipt).await?; + let diff = node.state_diff(execution_receipt).await?; - Ok((execution_receipt, trace, diff)) + Ok((trace, diff)) } #[instrument(level = "info", skip_all)] @@ -772,24 +776,23 @@ where .enumerate() .map(|(idx, v)| (StepIdx::new(idx), v)) { - let leader_step_output = self - .leader_state - .handle_step(self.metadata, &step, self.leader_node) - .instrument(info_span!( - "Handling Step", - %step_idx, - target = "Leader", - )) - .await?; - let follower_step_output = self - .follower_state - .handle_step(self.metadata, &step, self.follower_node) - .instrument(info_span!( - "Handling Step", - %step_idx, - target = "Follower", - )) - .await?; + let (leader_step_output, follower_step_output) = try_join!( + self.leader_state + .handle_step(self.metadata, &step, self.leader_node) + .instrument(info_span!( + "Handling Step", + %step_idx, + target = "Leader", + )), + self.follower_state + .handle_step(self.metadata, &step, self.follower_node) + .instrument(info_span!( + "Handling Step", + %step_idx, + target = "Follower", + )) + )?; + match (leader_step_output, follower_step_output) { (StepOutput::FunctionCall(..), StepOutput::FunctionCall(..)) => { // TODO: We need to actually work out how/if we will compare the diff between diff --git a/crates/core/src/main.rs b/crates/core/src/main.rs index f70ee6d..f74101d 100644 --- a/crates/core/src/main.rs +++ b/crates/core/src/main.rs @@ -19,13 +19,13 @@ use futures::{Stream, StreamExt}; use indexmap::IndexMap; use revive_dt_node_interaction::EthereumNode; use temp_dir::TempDir; -use tokio::sync::mpsc; +use tokio::{sync::mpsc, try_join}; use tracing::{debug, info, info_span, instrument}; use tracing_appender::non_blocking::WorkerGuard; use tracing_subscriber::{EnvFilter, FmtSubscriber}; use revive_dt_common::types::Mode; -use revive_dt_compiler::SolidityCompiler; +use revive_dt_compiler::{CompilerOutput, SolidityCompiler}; use revive_dt_config::*; use revive_dt_core::{ Geth, Kitchensink, Platform, @@ -500,16 +500,23 @@ where L::Blockchain: revive_dt_node::Node + Send + Sync + 'static, F::Blockchain: revive_dt_node::Node + Send + Sync + 'static, { - let leader_pre_link_contracts = cached_compiler - .compile_contracts::(metadata, metadata_file_path, &mode, config, None) - .await? - .0 - .contracts; - let follower_pre_link_contracts = cached_compiler - .compile_contracts::(metadata, metadata_file_path, &mode, config, None) - .await? - .0 - .contracts; + let ( + ( + CompilerOutput { + contracts: leader_pre_link_contracts, + }, + _, + ), + ( + CompilerOutput { + contracts: follower_pre_link_contracts, + }, + _, + ), + ) = try_join!( + cached_compiler.compile_contracts::(metadata, metadata_file_path, &mode, config, None), + cached_compiler.compile_contracts::(metadata, metadata_file_path, &mode, config, None) + )?; let mut leader_deployed_libraries = None::>; let mut follower_deployed_libraries = None::>; @@ -573,18 +580,10 @@ where follower_code, ); - let leader_receipt = match leader_node.execute_transaction(leader_tx).await { - Ok(receipt) => receipt, - Err(error) => { - return Err(error); - } - }; - let follower_receipt = match follower_node.execute_transaction(follower_tx).await { - Ok(receipt) => receipt, - Err(error) => { - return Err(error); - } - }; + let (leader_receipt, follower_receipt) = try_join!( + leader_node.execute_transaction(leader_tx), + follower_node.execute_transaction(follower_tx) + )?; debug!( ?library_instance, @@ -622,33 +621,44 @@ where ); } - let (leader_post_link_contracts, leader_compiler_version) = cached_compiler - .compile_contracts::( + let ( + ( + CompilerOutput { + contracts: leader_post_link_contracts, + }, + leader_compiler_version, + ), + ( + CompilerOutput { + contracts: follower_post_link_contracts, + }, + follower_compiler_version, + ), + ) = try_join!( + cached_compiler.compile_contracts::( metadata, metadata_file_path, &mode, config, - leader_deployed_libraries.as_ref(), - ) - .await?; - let (follower_post_link_contracts, follower_compiler_version) = cached_compiler - .compile_contracts::( + leader_deployed_libraries.as_ref() + ), + cached_compiler.compile_contracts::( metadata, metadata_file_path, &mode, config, - follower_deployed_libraries.as_ref(), + follower_deployed_libraries.as_ref() ) - .await?; + )?; let leader_state = CaseState::::new( leader_compiler_version, - leader_post_link_contracts.contracts, + leader_post_link_contracts, leader_deployed_libraries.unwrap_or_default(), ); let follower_state = CaseState::::new( follower_compiler_version, - follower_post_link_contracts.contracts, + follower_post_link_contracts, follower_deployed_libraries.unwrap_or_default(), ); diff --git a/crates/format/Cargo.toml b/crates/format/Cargo.toml index 2aa03a9..e49872f 100644 --- a/crates/format/Cargo.toml +++ b/crates/format/Cargo.toml @@ -17,6 +17,7 @@ alloy = { workspace = true } alloy-primitives = { workspace = true } alloy-sol-types = { workspace = true } anyhow = { workspace = true } +futures = { workspace = true } regex = { workspace = true } tracing = { workspace = true } semver = { workspace = true } diff --git a/crates/format/src/input.rs b/crates/format/src/input.rs index 897d650..918c2d4 100644 --- a/crates/format/src/input.rs +++ b/crates/format/src/input.rs @@ -9,6 +9,7 @@ use alloy::{ }; use alloy_primitives::{FixedBytes, utils::parse_units}; use anyhow::Context; +use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, stream}; use semver::VersionReq; use serde::{Deserialize, Serialize}; @@ -427,14 +428,18 @@ impl Calldata { buffer.extend_from_slice(bytes); } Calldata::Compound(items) => { - for (arg_idx, arg) in items.iter().enumerate() { - buffer.extend( + let resolved = stream::iter(items.iter().enumerate()) + .map(|(arg_idx, arg)| async move { arg.resolve(resolver, context) .instrument(info_span!("Resolving argument", %arg, arg_idx)) - .await? - .to_be_bytes::<32>(), - ); - } + .map_ok(|value| value.to_be_bytes::<32>()) + .await + }) + .buffered(0xFF) + .try_collect::>() + .await?; + + buffer.extend(resolved.into_iter().flatten()); } }; Ok(()) @@ -457,30 +462,30 @@ impl Calldata { match self { Calldata::Single(calldata) => Ok(calldata == other), Calldata::Compound(items) => { - // Chunking the "other" calldata into 32 byte chunks since each - // one of the items in the compound calldata represents 32 bytes - for (this, other) in items.iter().zip(other.chunks(32)) { - // The matterlabs format supports wildcards and therefore we - // also need to support them. - if this.as_ref() == "*" { - continue; - } + stream::iter(items.iter().zip(other.chunks(32))) + .map(|(this, other)| async move { + // The matterlabs format supports wildcards and therefore we + // also need to support them. + if this.as_ref() == "*" { + return Ok::<_, anyhow::Error>(true); + } - let other = if other.len() < 32 { - let mut vec = other.to_vec(); - vec.resize(32, 0); - std::borrow::Cow::Owned(vec) - } else { - std::borrow::Cow::Borrowed(other) - }; + let other = if other.len() < 32 { + let mut vec = other.to_vec(); + vec.resize(32, 0); + std::borrow::Cow::Owned(vec) + } else { + std::borrow::Cow::Borrowed(other) + }; - let this = this.resolve(resolver, context).await?; - let other = U256::from_be_slice(&other); - if this != other { - return Ok(false); - } - } - Ok(true) + let this = this.resolve(resolver, context).await?; + let other = U256::from_be_slice(&other); + Ok(this == other) + }) + .buffered(0xFF) + .all(|v| async move { v.is_ok_and(|v| v) }) + .map(Ok) + .await } } } diff --git a/crates/node/src/geth.rs b/crates/node/src/geth.rs index b96c05e..e72a7ed 100644 --- a/crates/node/src/geth.rs +++ b/crates/node/src/geth.rs @@ -35,7 +35,10 @@ use anyhow::Context; use revive_common::EVMVersion; use tracing::{Instrument, instrument}; -use revive_dt_common::{fs::clear_directory, futures::poll}; +use revive_dt_common::{ + fs::clear_directory, + futures::{PollingWaitBehavior, poll}, +}; use revive_dt_config::Arguments; use revive_dt_format::traits::ResolverApi; use revive_dt_node_interaction::EthereumNode; @@ -301,7 +304,7 @@ impl EthereumNode for GethNode { let provider = Arc::new(provider); poll( Self::RECEIPT_POLLING_DURATION, - Default::default(), + PollingWaitBehavior::Constant(Duration::from_millis(200)), move || { let provider = provider.clone(); async move { @@ -335,7 +338,7 @@ impl EthereumNode for GethNode { let provider = Arc::new(self.provider().await?); poll( Self::TRACE_POLLING_DURATION, - Default::default(), + PollingWaitBehavior::Constant(Duration::from_millis(200)), move || { let provider = provider.clone(); let trace_options = trace_options.clone();