Allow account allocations (#161)

* Support repetitions in the tool

* Add support for account allocations

* Update the JSON schema
This commit is contained in:
Omar
2025-09-22 06:19:55 +03:00
committed by GitHub
parent 9b700bfec2
commit 7189361a58
12 changed files with 274 additions and 73 deletions
+69 -50
View File
@@ -2,6 +2,7 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use alloy::consensus::EMPTY_ROOT_HASH;
use alloy::hex;
@@ -17,22 +18,23 @@ use alloy::{
primitives::Address,
rpc::types::{TransactionRequest, trace::geth::DiffMode},
};
use anyhow::Context as _;
use anyhow::{Context as _, bail};
use futures::{TryStreamExt, future::try_join_all};
use indexmap::IndexMap;
use revive_dt_common::types::PlatformIdentifier;
use revive_dt_common::types::{PlatformIdentifier, PrivateKeyAllocator};
use revive_dt_format::traits::{ResolutionContext, ResolverApi};
use revive_dt_report::ExecutionSpecificReporter;
use semver::Version;
use revive_dt_format::case::Case;
use revive_dt_format::input::{
use revive_dt_format::metadata::{ContractIdent, ContractInstance, ContractPathAndIdent};
use revive_dt_format::steps::{
BalanceAssertionStep, Calldata, EtherValue, Expected, ExpectedOutput, FunctionCallStep, Method,
StepIdx, StorageEmptyAssertionStep,
};
use revive_dt_format::metadata::{ContractIdent, ContractInstance, ContractPathAndIdent};
use revive_dt_format::{input::Step, metadata::Metadata};
use revive_dt_format::{metadata::Metadata, steps::Step};
use revive_dt_node_interaction::EthereumNode;
use tokio::sync::Mutex;
use tokio::try_join;
use tracing::{Instrument, info, info_span, instrument};
@@ -53,6 +55,10 @@ pub struct CaseState {
/// The execution reporter.
execution_reporter: ExecutionSpecificReporter,
/// The private key allocator used for this case state. This is an Arc Mutex to allow for the
/// state to be cloned and for all of the clones to refer to the same allocator.
private_key_allocator: Arc<Mutex<PrivateKeyAllocator>>,
}
impl CaseState {
@@ -61,6 +67,7 @@ impl CaseState {
compiled_contracts: HashMap<PathBuf, HashMap<String, (String, JsonAbi)>>,
deployed_contracts: HashMap<ContractInstance, (ContractIdent, Address, JsonAbi)>,
execution_reporter: ExecutionSpecificReporter,
private_key_allocator: Arc<Mutex<PrivateKeyAllocator>>,
) -> Self {
Self {
compiled_contracts,
@@ -68,6 +75,7 @@ impl CaseState {
variables: Default::default(),
compiler_version,
execution_reporter,
private_key_allocator,
}
}
@@ -108,6 +116,12 @@ impl CaseState {
.context("Failed to handle the repetition step")?;
Ok(StepOutput::Repetition)
}
Step::AllocateAccount(account_allocation) => {
self.handle_account_allocation(account_allocation.variable_name.as_str())
.await
.context("Failed to allocate account")?;
Ok(StepOutput::AccountAllocation)
}
}
.inspect(|_| info!("Step Succeeded"))
}
@@ -201,6 +215,21 @@ impl CaseState {
Ok(())
}
#[instrument(level = "info", name = "Handling Account Allocation", skip_all)]
pub async fn handle_account_allocation(&mut self, variable_name: &str) -> anyhow::Result<()> {
let Some(variable_name) = variable_name.strip_prefix("$VARIABLE:") else {
bail!("Account allocation must start with $VARIABLE:");
};
let private_key = self.private_key_allocator.lock().await.allocate()?;
let account = private_key.address();
let variable = U256::from_be_slice(account.0.as_slice());
self.variables.insert(variable_name.to_string(), variable);
Ok(())
}
/// Handles the contract deployment for a given input performing it if it needs to be performed.
#[instrument(level = "info", skip_all)]
async fn handle_input_contract_deployment(
@@ -227,15 +256,16 @@ impl CaseState {
.then_some(input.value)
.flatten();
let caller = {
let context = self.default_resolution_context();
let resolver = node.resolver().await?;
input
.caller
.resolve_address(resolver.as_ref(), context)
.await?
};
if let (_, _, Some(receipt)) = self
.get_or_deploy_contract_instance(
&instance,
metadata,
input.caller,
calldata,
value,
node,
)
.get_or_deploy_contract_instance(&instance, metadata, caller, calldata, value, node)
.await
.context("Failed to get or deploy contract instance during input execution")?
{
@@ -465,13 +495,9 @@ impl CaseState {
{
// Handling the emitter assertion.
if let Some(ref expected_address) = expected_event.address {
let expected = Address::from_slice(
Calldata::new_compound([expected_address])
.calldata(resolver, resolution_context)
.await?
.get(12..32)
.expect("Can't fail"),
);
let expected = expected_address
.resolve_address(resolver, resolution_context)
.await?;
let actual = actual_event.address();
if actual != expected {
tracing::error!(
@@ -568,17 +594,17 @@ impl CaseState {
balance_assertion: &BalanceAssertionStep,
node: &dyn EthereumNode,
) -> anyhow::Result<()> {
let Some(instance) = balance_assertion
.address
.strip_suffix(".address")
.map(ContractInstance::new)
else {
let Some(address) = balance_assertion.address.as_resolvable_address() else {
return Ok(());
};
let Some(instance) = address.strip_suffix(".address").map(ContractInstance::new) else {
return Ok(());
};
self.get_or_deploy_contract_instance(
&instance,
metadata,
FunctionCallStep::default_caller(),
FunctionCallStep::default_caller_address(),
None,
None,
node,
@@ -591,20 +617,16 @@ impl CaseState {
pub async fn handle_balance_assertion_execution(
&mut self,
BalanceAssertionStep {
address: address_string,
address,
expected_balance: amount,
..
}: &BalanceAssertionStep,
node: &dyn EthereumNode,
) -> anyhow::Result<()> {
let resolver = node.resolver().await?;
let address = Address::from_slice(
Calldata::new_compound([address_string])
.calldata(resolver.as_ref(), self.default_resolution_context())
.await?
.get(12..32)
.expect("Can't fail"),
);
let address = address
.resolve_address(resolver.as_ref(), self.default_resolution_context())
.await?;
let balance = node.balance_of(address).await?;
@@ -616,7 +638,7 @@ impl CaseState {
"Balance assertion failed - Expected {} but got {} for {} resolved to {}",
expected,
actual,
address_string,
address,
address,
)
}
@@ -631,17 +653,17 @@ impl CaseState {
storage_empty_assertion: &StorageEmptyAssertionStep,
node: &dyn EthereumNode,
) -> anyhow::Result<()> {
let Some(instance) = storage_empty_assertion
.address
.strip_suffix(".address")
.map(ContractInstance::new)
else {
let Some(address) = storage_empty_assertion.address.as_resolvable_address() else {
return Ok(());
};
let Some(instance) = address.strip_suffix(".address").map(ContractInstance::new) else {
return Ok(());
};
self.get_or_deploy_contract_instance(
&instance,
metadata,
FunctionCallStep::default_caller(),
FunctionCallStep::default_caller_address(),
None,
None,
node,
@@ -654,20 +676,16 @@ impl CaseState {
pub async fn handle_storage_empty_assertion_execution(
&mut self,
StorageEmptyAssertionStep {
address: address_string,
address,
is_storage_empty,
..
}: &StorageEmptyAssertionStep,
node: &dyn EthereumNode,
) -> anyhow::Result<()> {
let resolver = node.resolver().await?;
let address = Address::from_slice(
Calldata::new_compound([address_string])
.calldata(resolver.as_ref(), self.default_resolution_context())
.await?
.get(12..32)
.expect("Can't fail"),
);
let address = address
.resolve_address(resolver.as_ref(), self.default_resolution_context())
.await?;
let storage = node.latest_state_proof(address, Default::default()).await?;
let is_empty = storage.storage_hash == EMPTY_ROOT_HASH;
@@ -681,7 +699,7 @@ impl CaseState {
"Storage Empty Assertion failed - Expected {} but got {} for {} resolved to {}",
expected,
actual,
address_string,
address,
address,
)
};
@@ -875,4 +893,5 @@ pub enum StepOutput {
BalanceAssertion,
StorageEmptyAssertion,
Repetition,
AccountAllocation,
}
+18 -5
View File
@@ -26,10 +26,14 @@ use revive_dt_report::{
};
use schemars::schema_for;
use serde_json::{Value, json};
use tokio::sync::Mutex;
use tracing::{debug, error, info, info_span, instrument};
use tracing_subscriber::{EnvFilter, FmtSubscriber};
use revive_dt_common::{iterators::EitherIter, types::Mode};
use revive_dt_common::{
iterators::EitherIter,
types::{Mode, PrivateKeyAllocator},
};
use revive_dt_compiler::SolidityCompiler;
use revive_dt_config::{Context, *};
use revive_dt_core::{
@@ -39,9 +43,9 @@ use revive_dt_core::{
use revive_dt_format::{
case::{Case, CaseIdx},
corpus::Corpus,
input::{FunctionCallStep, Step},
metadata::{ContractPathAndIdent, Metadata, MetadataFile},
mode::ParsedMode,
steps::{FunctionCallStep, Step},
};
use crate::cached_compiler::CachedCompiler;
@@ -326,8 +330,13 @@ async fn start_driver_task<'a>(
.expect("Can't fail");
}
let private_key_allocator = Arc::new(Mutex::new(PrivateKeyAllocator::new(
context.wallet_configuration.highest_private_key_exclusive(),
)));
let reporter = test.reporter.clone();
let result = handle_case_driver(&test, cached_compiler).await;
let result =
handle_case_driver(&test, cached_compiler, private_key_allocator).await;
match result {
Ok(steps_executed) => reporter
@@ -438,6 +447,7 @@ async fn start_cli_reporting_task(reporter: Reporter) {
async fn handle_case_driver<'a>(
test: &Test<'a>,
cached_compiler: Arc<CachedCompiler<'a>>,
private_key_allocator: Arc<Mutex<PrivateKeyAllocator>>,
) -> anyhow::Result<usize> {
let platform_state = stream::iter(test.platforms.iter())
// Compiling the pre-link contracts.
@@ -511,13 +521,14 @@ async fn handle_case_driver<'a>(
.steps
.iter()
.filter_map(|step| match step {
Step::FunctionCall(input) => Some(input.caller),
Step::FunctionCall(input) => input.caller.as_address().copied(),
Step::BalanceAssertion(..) => None,
Step::StorageEmptyAssertion(..) => None,
Step::Repeat(..) => None,
Step::AllocateAccount(..) => None,
})
.next()
.unwrap_or(FunctionCallStep::default_caller());
.unwrap_or(FunctionCallStep::default_caller_address());
let tx = TransactionBuilder::<Ethereum>::with_deploy_code(
TransactionRequest::default().from(deployer_address),
code,
@@ -564,6 +575,7 @@ async fn handle_case_driver<'a>(
.filter_map(
|(test, platform, node, compiler, reporter, _, deployed_libraries)| {
let cached_compiler = cached_compiler.clone();
let private_key_allocator = private_key_allocator.clone();
async move {
let compiler_output = cached_compiler
@@ -591,6 +603,7 @@ async fn handle_case_driver<'a>(
compiler_output.contracts,
deployed_libraries.unwrap_or_default(),
reporter.clone(),
private_key_allocator,
);
Some((*node, platform.platform_identifier(), case_state))