Add variables support

This commit is contained in:
Omar Abdulla
2025-07-30 11:26:34 +03:00
parent f51693cb9f
commit b715228c35
2 changed files with 211 additions and 90 deletions
+123 -62
View File
@@ -4,10 +4,10 @@ use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::path::PathBuf;
use std::str::FromStr;
use alloy::json_abi::JsonAbi;
use alloy::network::{Ethereum, TransactionBuilder};
use alloy::primitives::U256;
use alloy::rpc::types::TransactionReceipt;
use alloy::rpc::types::trace::geth::{
CallFrame, GethDebugBuiltInTracerType, GethDebugTracerType, GethDebugTracingOptions, GethTrace,
@@ -55,6 +55,10 @@ pub struct State<'a, T: Platform> {
/// files.
deployed_contracts: HashMap<CaseIdx, HashMap<ContractInstance, (Address, JsonAbi)>>,
/// This map stores the variables used for each one of the cases contained in the metadata
/// file.
variables: HashMap<CaseIdx, HashMap<String, U256>>,
/// This is a map of the deployed libraries.
///
/// This map is not per case, but rather, per metadata file. This means that we do not redeploy
@@ -77,6 +81,7 @@ where
span,
compiled_contracts: Default::default(),
deployed_contracts: Default::default(),
variables: Default::default(),
deployed_libraries: Default::default(),
compiler_version: Default::default(),
phantom: Default::default(),
@@ -216,7 +221,16 @@ where
self.handle_contract_deployment(metadata, case_idx, input, node)?;
let execution_receipt =
self.handle_input_execution(case_idx, input, deployment_receipts, node)?;
self.handle_input_expectations(case_idx, input, &execution_receipt, node, mode)?;
let tracing_result = self.handle_input_call_frame_tracing(&execution_receipt, node)?;
self.handle_input_variable_assignment(case_idx, input, &tracing_result)?;
self.handle_input_expectations(
case_idx,
input,
&execution_receipt,
node,
mode,
&tracing_result,
)?;
self.handle_input_diff(case_idx, execution_receipt, node)
}
@@ -237,7 +251,12 @@ where
let mut instances_we_must_deploy = IndexMap::<ContractInstance, bool>::new();
for instance in input.find_all_contract_instances().into_iter() {
if !self.deployed_contracts(case_idx).contains_key(&instance) {
if !self
.deployed_contracts
.entry(case_idx)
.or_insert_with(|| self.deployed_libraries.clone())
.contains_key(&instance)
{
instances_we_must_deploy.entry(instance).or_insert(false);
}
}
@@ -289,7 +308,13 @@ where
.remove(&input.instance)
.context("Failed to find deployment receipt"),
Method::Fallback | Method::FunctionName(_) => {
let tx = match input.legacy_transaction(self.deployed_contracts(case_idx), node) {
let tx = match input.legacy_transaction(
self.deployed_contracts
.entry(case_idx)
.or_insert_with(|| self.deployed_libraries.clone()),
&*self.variables.entry(case_idx).or_default(),
node,
) {
Ok(tx) => {
tracing::debug!("Legacy transaction data: {tx:#?}");
tx
@@ -317,6 +342,56 @@ where
}
}
fn handle_input_call_frame_tracing(
&self,
execution_receipt: &TransactionReceipt,
node: &T::Blockchain,
) -> anyhow::Result<CallFrame> {
node.trace_transaction(
execution_receipt,
GethDebugTracingOptions {
tracer: Some(GethDebugTracerType::BuiltInTracer(
GethDebugBuiltInTracerType::CallTracer,
)),
..Default::default()
},
)
.map(|trace| {
trace
.try_into_call_frame()
.expect("Impossible - we requested a callframe trace so we must get it back")
})
}
fn handle_input_variable_assignment(
&mut self,
case_idx: CaseIdx,
input: &Input,
tracing_result: &CallFrame,
) -> anyhow::Result<()> {
let Some(ref assignments) = input.variable_assignments else {
return Ok(());
};
// Handling the return data variable assignments.
for (variable_name, output_word) in assignments.return_data.iter().zip(
tracing_result
.output
.as_ref()
.unwrap_or_default()
.to_vec()
.chunks(32),
) {
let value = U256::from_be_slice(output_word);
self.variables
.entry(case_idx)
.or_default()
.insert(variable_name.clone(), value);
}
Ok(())
}
fn handle_input_expectations(
&mut self,
case_idx: CaseIdx,
@@ -324,6 +399,7 @@ where
execution_receipt: &TransactionReceipt,
node: &T::Blockchain,
mode: &SolcMode,
tracing_result: &CallFrame,
) -> anyhow::Result<()> {
let span = tracing::info_span!("Handling input expectations");
let _guard = span.enter();
@@ -356,29 +432,13 @@ where
}
}
// Note: we need to do assertions and checks on the output of the last call and this isn't
// available in the receipt. The only way to get this information is through tracing on the
// node.
let tracing_result = node
.trace_transaction(
execution_receipt,
GethDebugTracingOptions {
tracer: Some(GethDebugTracerType::BuiltInTracer(
GethDebugBuiltInTracerType::CallTracer,
)),
..Default::default()
},
)?
.try_into_call_frame()
.expect("Impossible - we requested a callframe trace so we must get it back");
for expectation in expectations.iter() {
self.handle_input_expectation_item(
case_idx,
execution_receipt,
node,
expectation,
&tracing_result,
tracing_result,
mode,
)?;
}
@@ -405,14 +465,24 @@ where
}
}
let deployed_contracts = self.deployed_contracts(case_idx);
let deployed_contracts = self
.deployed_contracts
.entry(case_idx)
.or_insert_with(|| self.deployed_libraries.clone());
let variables = self.variables.entry(case_idx).or_default();
let chain_state_provider = node;
// Handling the receipt state assertion.
let expected = !expectation.exception;
let actual = execution_receipt.status();
if actual != expected {
tracing::error!(expected, actual, "Transaction status assertion failed",);
tracing::error!(
expected,
actual,
?execution_receipt,
?tracing_result,
"Transaction status assertion failed"
);
anyhow::bail!(
"Transaction status assertion failed - Expected {expected} but got {actual}",
);
@@ -422,7 +492,12 @@ where
if let Some(ref expected_calldata) = expectation.return_data {
let expected = expected_calldata;
let actual = &tracing_result.output.as_ref().unwrap_or_default();
if !expected.is_equivalent(actual, deployed_contracts, chain_state_provider)? {
if !expected.is_equivalent(
actual,
deployed_contracts,
&*variables,
chain_state_provider,
)? {
tracing::error!(
?execution_receipt,
?expected,
@@ -451,17 +526,12 @@ where
{
// Handling the emitter assertion.
if let Some(ref expected_address) = expected_event.address {
let expected = if let Some(contract_instance) = expected_address
.strip_suffix(".address")
.map(ContractInstance::new)
{
deployed_contracts
.get(&contract_instance)
.map(|(address, _)| *address)
} else {
Address::from_str(expected_address).ok()
}
.context("Failed to get the address of the event")?;
let expected = Address::from_slice(
Calldata::new_compound([expected_address])
.calldata(deployed_contracts, &*variables, node)?
.get(12..32)
.expect("Can't fail"),
);
let actual = actual_event.address();
if actual != expected {
tracing::error!(
@@ -486,6 +556,7 @@ where
if !expected.is_equivalent(
&actual.0,
deployed_contracts,
&*variables,
chain_state_provider,
)? {
tracing::error!(
@@ -503,7 +574,12 @@ where
// Handling the values assertion.
let expected = &expected_event.values;
let actual = &actual_event.data().data;
if !expected.is_equivalent(&actual.0, deployed_contracts, chain_state_provider)? {
if !expected.is_equivalent(
&actual.0,
deployed_contracts,
&*variables,
chain_state_provider,
)? {
tracing::error!(
?execution_receipt,
?expected,
@@ -541,19 +617,6 @@ where
Ok((execution_receipt, trace, diff))
}
fn deployed_contracts(
&mut self,
case_idx: impl Into<Option<CaseIdx>>,
) -> &mut HashMap<ContractInstance, (Address, JsonAbi)> {
match case_idx.into() {
Some(case_idx) => self
.deployed_contracts
.entry(case_idx)
.or_insert_with(|| self.deployed_libraries.clone()),
None => &mut self.deployed_libraries,
}
}
/// Gets the information of a deployed contract or library from the state. If it's found to not
/// be deployed then it will be deployed.
///
@@ -572,17 +635,16 @@ where
) -> anyhow::Result<(Address, JsonAbi, Option<TransactionReceipt>)> {
let case_idx = case_idx.into();
if let Some((address, abi)) = self.deployed_libraries.get(contract_instance) {
return Ok((*address, abi.clone(), None));
}
if let Some(case_idx) = case_idx {
if let Some((address, abi)) = self
let deployed_contracts = match case_idx {
Some(case_idx) => self
.deployed_contracts
.get(&case_idx)
.and_then(|contracts| contracts.get(contract_instance))
{
return Ok((*address, abi.clone(), None));
}
.entry(case_idx)
.or_insert_with(|| self.deployed_libraries.clone()),
None => &mut self.deployed_libraries,
};
if let Some((address, abi)) = deployed_contracts.get(contract_instance) {
return Ok((*address, abi.clone(), None));
}
let Some(ContractPathAndIdent {
@@ -628,7 +690,7 @@ where
};
if let Some(calldata) = calldata {
let calldata = calldata.calldata(self.deployed_contracts(case_idx), node)?;
let calldata = calldata.calldata(deployed_contracts, None, node)?;
code.extend(calldata);
}
@@ -663,8 +725,7 @@ where
"Deployed contract"
);
self.deployed_contracts(case_idx)
.insert(contract_instance.clone(), (address, abi.clone()));
deployed_contracts.insert(contract_instance.clone(), (address, abi.clone()));
Ok((address, abi, Some(receipt)))
}