Move compile check to prepare_tests

This commit is contained in:
James Wilson
2025-08-14 16:12:48 +01:00
parent b2290b3177
commit e5c7bf300b
+71 -65
View File
@@ -13,7 +13,8 @@ use alloy::{
}; };
use anyhow::Context; use anyhow::Context;
use clap::Parser; use clap::Parser;
use futures::StreamExt; use futures::stream::futures_unordered::FuturesUnordered;
use futures::{Stream, StreamExt};
use revive_dt_common::iterators::FilesWithExtensionIterator; use revive_dt_common::iterators::FilesWithExtensionIterator;
use revive_dt_node_interaction::EthereumNode; use revive_dt_node_interaction::EthereumNode;
use semver::Version; use semver::Version;
@@ -61,7 +62,7 @@ struct Test {
} }
/// This represents the results that we gather from running test cases. /// This represents the results that we gather from running test cases.
type CaseResult = Result<Option<usize>, anyhow::Error>; type CaseResult = Result<usize, anyhow::Error>;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let args = init_cli()?; let args = init_cli()?;
@@ -144,7 +145,7 @@ where
{ {
let (report_tx, report_rx) = mpsc::unbounded_channel::<(Test, CaseResult)>(); let (report_tx, report_rx) = mpsc::unbounded_channel::<(Test, CaseResult)>();
let tests = prepare_tests::<L, F>(metadata_files); let tests = prepare_tests::<L, F>(args, metadata_files);
let driver_task = start_driver_task::<L, F>(args, tests, span, report_tx)?; let driver_task = start_driver_task::<L, F>(args, tests, span, report_tx)?;
let status_reporter_task = start_reporter_task(report_rx); let status_reporter_task = start_reporter_task(report_rx);
@@ -153,7 +154,10 @@ where
Ok(()) Ok(())
} }
fn prepare_tests<L, F>(metadata_files: &[MetadataFile]) -> impl Iterator<Item = Test> fn prepare_tests<'args, L, F>(
args: &'args Arguments,
metadata_files: &[MetadataFile],
) -> impl Stream<Item = Test>
where where
L: Platform, L: Platform,
F: Platform, F: Platform,
@@ -235,11 +239,49 @@ where
case: case.clone(), case: case.clone(),
} }
}) })
.map(async |test| test)
.collect::<FuturesUnordered<_>>()
.filter_map(async move |test| {
// Check that both compilers support this test, else we skip it
let is_supported = does_compiler_support_mode::<L>(args, &test.mode).await.ok().unwrap_or(false) &&
does_compiler_support_mode::<F>(args, &test.mode).await.ok().unwrap_or(false);
tracing::warn!(
metadata_file_path = %test.path.display(),
case_idx = %test.case_idx,
case_name = ?test.case.name,
mode = %test.mode,
"Skipping test as one or both of the compilers don't support it"
);
// We filter_map to avoid needing to clone `test`, but return it as-is.
if is_supported {
Some(test)
} else {
None
}
})
}
async fn does_compiler_support_mode<P: Platform>(
args: &Arguments,
mode: &Mode,
) -> anyhow::Result<bool> {
let compiler_version_or_requirement = mode.compiler_version_to_use(args.solc.clone());
let compiler_path =
P::Compiler::get_compiler_executable(args, compiler_version_or_requirement).await?;
let compiler_version = P::Compiler::new(compiler_path.clone()).version()?;
Ok(P::Compiler::supports_mode(
&compiler_version,
mode.optimize_setting,
mode.pipeline,
))
} }
fn start_driver_task<L, F>( fn start_driver_task<L, F>(
args: &Arguments, args: &Arguments,
tests: impl Iterator<Item = Test>, tests: impl Stream<Item = Test>,
span: Span, span: Span,
report_tx: mpsc::UnboundedSender<(Test, CaseResult)>, report_tx: mpsc::UnboundedSender<(Test, CaseResult)>,
) -> anyhow::Result<impl Future<Output = ()>> ) -> anyhow::Result<impl Future<Output = ()>>
@@ -254,7 +296,7 @@ where
let compilation_cache = Arc::new(RwLock::new(HashMap::new())); let compilation_cache = Arc::new(RwLock::new(HashMap::new()));
let number_concurrent_tasks = args.number_of_concurrent_tasks(); let number_concurrent_tasks = args.number_of_concurrent_tasks();
Ok(futures::stream::iter(tests).for_each_concurrent( Ok(tests.for_each_concurrent(
// We want to limit the concurrent tasks here because: // We want to limit the concurrent tasks here because:
// //
// 1. We don't want to overwhelm the nodes with too many requests, leading to responses timing out. // 1. We don't want to overwhelm the nodes with too many requests, leading to responses timing out.
@@ -309,14 +351,12 @@ async fn start_reporter_task(mut report_rx: mpsc::UnboundedReceiver<(Test, CaseR
const GREEN: &str = "\x1B[32m"; const GREEN: &str = "\x1B[32m";
const RED: &str = "\x1B[31m"; const RED: &str = "\x1B[31m";
const GRAY: &str = "\x1B[90m";
const COLOUR_RESET: &str = "\x1B[0m"; const COLOUR_RESET: &str = "\x1B[0m";
const BOLD: &str = "\x1B[1m"; const BOLD: &str = "\x1B[1m";
const BOLD_RESET: &str = "\x1B[22m"; const BOLD_RESET: &str = "\x1B[22m";
let mut number_of_successes = 0; let mut number_of_successes = 0;
let mut number_of_failures = 0; let mut number_of_failures = 0;
let mut number_of_ignored = 0;
let mut failures = vec![]; let mut failures = vec![];
// Wait for reports to come from our test runner. When the channel closes, this ends. // Wait for reports to come from our test runner. When the channel closes, this ends.
@@ -327,18 +367,12 @@ async fn start_reporter_task(mut report_rx: mpsc::UnboundedReceiver<(Test, CaseR
let test_mode = test.mode.clone(); let test_mode = test.mode.clone();
match case_result { match case_result {
Ok(Some(_inputs)) => { Ok(_inputs) => {
number_of_successes += 1; number_of_successes += 1;
eprintln!( eprintln!(
"{GREEN}Case Succeeded:{COLOUR_RESET} {test_path} -> {case_name}:{case_idx} (mode: {test_mode})" "{GREEN}Case Succeeded:{COLOUR_RESET} {test_path} -> {case_name}:{case_idx} (mode: {test_mode})"
); );
} }
Ok(None) => {
number_of_ignored += 1;
eprintln!(
"{GRAY}Case Ignored:{COLOUR_RESET} {test_path} -> {case_name}:{case_idx} (mode: {test_mode})"
);
}
Err(err) => { Err(err) => {
number_of_failures += 1; number_of_failures += 1;
eprintln!( eprintln!(
@@ -372,8 +406,8 @@ async fn start_reporter_task(mut report_rx: mpsc::UnboundedReceiver<(Test, CaseR
// Summary at the end. // Summary at the end.
eprintln!( eprintln!(
"{} cases: {GREEN}{number_of_successes}{COLOUR_RESET} cases succeeded, {RED}{number_of_failures}{COLOUR_RESET} cases failed, {GRAY}{number_of_ignored}{COLOUR_RESET} cases ignored in {} seconds", "{} cases: {GREEN}{number_of_successes}{COLOUR_RESET} cases succeeded, {RED}{number_of_failures}{COLOUR_RESET} cases failed in {} seconds",
number_of_successes + number_of_failures + number_of_ignored, number_of_successes + number_of_failures,
elapsed.as_secs() elapsed.as_secs()
); );
} }
@@ -390,14 +424,14 @@ async fn handle_case_driver<L, F>(
leader_node: &L::Blockchain, leader_node: &L::Blockchain,
follower_node: &F::Blockchain, follower_node: &F::Blockchain,
_: Span, _: Span,
) -> anyhow::Result<Option<usize>> ) -> anyhow::Result<usize>
where where
L: Platform, L: Platform,
F: Platform, F: Platform,
L::Blockchain: revive_dt_node::Node + Send + Sync + 'static, L::Blockchain: revive_dt_node::Node + Send + Sync + 'static,
F::Blockchain: revive_dt_node::Node + Send + Sync + 'static, F::Blockchain: revive_dt_node::Node + Send + Sync + 'static,
{ {
let Some(leader_pre_link_contracts) = get_or_build_contracts::<L>( let leader_pre_link_contracts = get_or_build_contracts::<L>(
metadata, metadata,
metadata_file_path, metadata_file_path,
mode.clone(), mode.clone(),
@@ -405,11 +439,8 @@ where
compilation_cache.clone(), compilation_cache.clone(),
&HashMap::new(), &HashMap::new(),
) )
.await? .await?;
else { let follower_pre_link_contracts = get_or_build_contracts::<F>(
return Ok(None);
};
let Some(follower_pre_link_contracts) = get_or_build_contracts::<F>(
metadata, metadata,
metadata_file_path, metadata_file_path,
mode.clone(), mode.clone(),
@@ -417,10 +448,7 @@ where
compilation_cache.clone(), compilation_cache.clone(),
&HashMap::new(), &HashMap::new(),
) )
.await? .await?;
else {
return Ok(None);
};
let mut leader_deployed_libraries = HashMap::new(); let mut leader_deployed_libraries = HashMap::new();
let mut follower_deployed_libraries = HashMap::new(); let mut follower_deployed_libraries = HashMap::new();
@@ -581,7 +609,7 @@ where
cache.remove(&follower_key); cache.remove(&follower_key);
} }
let Some(leader_post_link_contracts) = get_or_build_contracts::<L>( let leader_post_link_contracts = get_or_build_contracts::<L>(
metadata, metadata,
metadata_file_path, metadata_file_path,
mode.clone(), mode.clone(),
@@ -589,11 +617,8 @@ where
compilation_cache.clone(), compilation_cache.clone(),
&leader_deployed_libraries, &leader_deployed_libraries,
) )
.await? .await?;
else { let follower_post_link_contracts = get_or_build_contracts::<F>(
return Ok(None);
};
let Some(follower_post_link_contracts) = get_or_build_contracts::<F>(
metadata, metadata,
metadata_file_path, metadata_file_path,
mode.clone(), mode.clone(),
@@ -601,10 +626,7 @@ where
compilation_cache, compilation_cache,
&follower_deployed_libraries, &follower_deployed_libraries,
) )
.await? .await?;
else {
return Ok(None);
};
(leader_post_link_contracts, follower_post_link_contracts) (leader_post_link_contracts, follower_post_link_contracts)
} else { } else {
@@ -631,7 +653,7 @@ where
leader_state, leader_state,
follower_state, follower_state,
); );
driver.execute().await.map(Some) driver.execute().await
} }
async fn get_or_build_contracts<P: Platform>( async fn get_or_build_contracts<P: Platform>(
@@ -641,7 +663,7 @@ async fn get_or_build_contracts<P: Platform>(
config: &Arguments, config: &Arguments,
compilation_cache: CompilationCache, compilation_cache: CompilationCache,
deployed_libraries: &HashMap<ContractInstance, (Address, JsonAbi)>, deployed_libraries: &HashMap<ContractInstance, (Address, JsonAbi)>,
) -> anyhow::Result<Option<Arc<(Version, CompilerOutput)>>> { ) -> anyhow::Result<Arc<(Version, CompilerOutput)>> {
let key = ( let key = (
metadata_file_path.to_path_buf(), metadata_file_path.to_path_buf(),
mode.clone(), mode.clone(),
@@ -652,25 +674,22 @@ async fn get_or_build_contracts<P: Platform>(
match *compilation_artifact { match *compilation_artifact {
Some(ref compiled_contracts) => { Some(ref compiled_contracts) => {
tracing::debug!(?key, "Compiled contracts cache hit"); tracing::debug!(?key, "Compiled contracts cache hit");
return Ok(Some(compiled_contracts.clone())); return Ok(compiled_contracts.clone());
} }
None => { None => {
tracing::debug!(?key, "Compiled contracts cache miss"); tracing::debug!(?key, "Compiled contracts cache miss");
let Some(compiled_contracts) = compile_contracts::<P>( let compiled_contracts = compile_contracts::<P>(
metadata, metadata,
metadata_file_path, metadata_file_path,
&mode, &mode,
config, config,
deployed_libraries, deployed_libraries,
) )
.await? .await?;
else {
return Ok(None);
};
let compiled_contracts = Arc::new(compiled_contracts); let compiled_contracts = Arc::new(compiled_contracts);
*compilation_artifact = Some(compiled_contracts.clone()); *compilation_artifact = Some(compiled_contracts.clone());
return Ok(Some(compiled_contracts.clone())); return Ok(compiled_contracts.clone());
} }
} }
}; };
@@ -684,21 +703,18 @@ async fn get_or_build_contracts<P: Platform>(
}; };
let mut compilation_artifact = mutex.lock().await; let mut compilation_artifact = mutex.lock().await;
let Some(compiled_contracts) = compile_contracts::<P>( let compiled_contracts = compile_contracts::<P>(
metadata, metadata,
metadata_file_path, metadata_file_path,
&mode, &mode,
config, config,
deployed_libraries, deployed_libraries,
) )
.await? .await?;
else {
return Ok(None);
};
let compiled_contracts = Arc::new(compiled_contracts); let compiled_contracts = Arc::new(compiled_contracts);
*compilation_artifact = Some(compiled_contracts.clone()); *compilation_artifact = Some(compiled_contracts.clone());
Ok(Some(compiled_contracts.clone())) Ok(compiled_contracts.clone())
} }
async fn compile_contracts<P: Platform>( async fn compile_contracts<P: Platform>(
@@ -707,22 +723,12 @@ async fn compile_contracts<P: Platform>(
mode: &Mode, mode: &Mode,
config: &Arguments, config: &Arguments,
deployed_libraries: &HashMap<ContractInstance, (Address, JsonAbi)>, deployed_libraries: &HashMap<ContractInstance, (Address, JsonAbi)>,
) -> anyhow::Result<Option<(Version, CompilerOutput)>> { ) -> anyhow::Result<(Version, CompilerOutput)> {
let compiler_version_or_requirement = mode.compiler_version_to_use(config.solc.clone()); let compiler_version_or_requirement = mode.compiler_version_to_use(config.solc.clone());
let compiler_path = let compiler_path =
P::Compiler::get_compiler_executable(config, compiler_version_or_requirement).await?; P::Compiler::get_compiler_executable(config, compiler_version_or_requirement).await?;
let compiler_version = P::Compiler::new(compiler_path.clone()).version()?; let compiler_version = P::Compiler::new(compiler_path.clone()).version()?;
if !P::Compiler::supports_mode(&compiler_version, mode.optimize_setting, mode.pipeline) {
tracing::info!(
%compiler_version,
metadata_file_path = %metadata_file_path.display(),
mode = ?mode,
"Skipping compilation: compiler does not support this mode or version"
);
return Ok(None);
}
tracing::info!( tracing::info!(
%compiler_version, %compiler_version,
metadata_file_path = %metadata_file_path.display(), metadata_file_path = %metadata_file_path.display(),
@@ -759,7 +765,7 @@ async fn compile_contracts<P: Platform>(
let compiler_output = compiler.try_build(compiler_path).await?; let compiler_output = compiler.try_build(compiler_path).await?;
Ok(Some((compiler_version, compiler_output))) Ok((compiler_version, compiler_output))
} }
async fn execute_corpus( async fn execute_corpus(