diff --git a/crates/core/src/driver/mod.rs b/crates/core/src/driver/mod.rs index cac0aa8..721003e 100644 --- a/crates/core/src/driver/mod.rs +++ b/crates/core/src/driver/mod.rs @@ -4,10 +4,10 @@ use std::collections::HashMap; use std::fmt::Debug; use std::marker::PhantomData; use std::path::PathBuf; -use std::str::FromStr; use alloy::json_abi::JsonAbi; use alloy::network::{Ethereum, TransactionBuilder}; +use alloy::primitives::U256; use alloy::rpc::types::TransactionReceipt; use alloy::rpc::types::trace::geth::{ CallFrame, GethDebugBuiltInTracerType, GethDebugTracerType, GethDebugTracingOptions, GethTrace, @@ -55,6 +55,10 @@ pub struct State<'a, T: Platform> { /// files. deployed_contracts: HashMap>, + /// This map stores the variables used for each one of the cases contained in the metadata + /// file. + variables: HashMap>, + /// This is a map of the deployed libraries. /// /// This map is not per case, but rather, per metadata file. This means that we do not redeploy @@ -77,6 +81,7 @@ where span, compiled_contracts: Default::default(), deployed_contracts: Default::default(), + variables: Default::default(), deployed_libraries: Default::default(), compiler_version: Default::default(), phantom: Default::default(), @@ -216,7 +221,16 @@ where self.handle_contract_deployment(metadata, case_idx, input, node)?; let execution_receipt = self.handle_input_execution(case_idx, input, deployment_receipts, node)?; - self.handle_input_expectations(case_idx, input, &execution_receipt, node, mode)?; + let tracing_result = self.handle_input_call_frame_tracing(&execution_receipt, node)?; + self.handle_input_variable_assignment(case_idx, input, &tracing_result)?; + self.handle_input_expectations( + case_idx, + input, + &execution_receipt, + node, + mode, + &tracing_result, + )?; self.handle_input_diff(case_idx, execution_receipt, node) } @@ -237,7 +251,12 @@ where let mut instances_we_must_deploy = IndexMap::::new(); for instance in input.find_all_contract_instances().into_iter() { - if !self.deployed_contracts(case_idx).contains_key(&instance) { + if !self + .deployed_contracts + .entry(case_idx) + .or_insert_with(|| self.deployed_libraries.clone()) + .contains_key(&instance) + { instances_we_must_deploy.entry(instance).or_insert(false); } } @@ -289,7 +308,13 @@ where .remove(&input.instance) .context("Failed to find deployment receipt"), Method::Fallback | Method::FunctionName(_) => { - let tx = match input.legacy_transaction(self.deployed_contracts(case_idx), node) { + let tx = match input.legacy_transaction( + self.deployed_contracts + .entry(case_idx) + .or_insert_with(|| self.deployed_libraries.clone()), + &*self.variables.entry(case_idx).or_default(), + node, + ) { Ok(tx) => { tracing::debug!("Legacy transaction data: {tx:#?}"); tx @@ -317,6 +342,56 @@ where } } + fn handle_input_call_frame_tracing( + &self, + execution_receipt: &TransactionReceipt, + node: &T::Blockchain, + ) -> anyhow::Result { + node.trace_transaction( + execution_receipt, + GethDebugTracingOptions { + tracer: Some(GethDebugTracerType::BuiltInTracer( + GethDebugBuiltInTracerType::CallTracer, + )), + ..Default::default() + }, + ) + .map(|trace| { + trace + .try_into_call_frame() + .expect("Impossible - we requested a callframe trace so we must get it back") + }) + } + + fn handle_input_variable_assignment( + &mut self, + case_idx: CaseIdx, + input: &Input, + tracing_result: &CallFrame, + ) -> anyhow::Result<()> { + let Some(ref assignments) = input.variable_assignments else { + return Ok(()); + }; + + // Handling the return data variable assignments. + for (variable_name, output_word) in assignments.return_data.iter().zip( + tracing_result + .output + .as_ref() + .unwrap_or_default() + .to_vec() + .chunks(32), + ) { + let value = U256::from_be_slice(output_word); + self.variables + .entry(case_idx) + .or_default() + .insert(variable_name.clone(), value); + } + + Ok(()) + } + fn handle_input_expectations( &mut self, case_idx: CaseIdx, @@ -324,6 +399,7 @@ where execution_receipt: &TransactionReceipt, node: &T::Blockchain, mode: &SolcMode, + tracing_result: &CallFrame, ) -> anyhow::Result<()> { let span = tracing::info_span!("Handling input expectations"); let _guard = span.enter(); @@ -356,29 +432,13 @@ where } } - // Note: we need to do assertions and checks on the output of the last call and this isn't - // available in the receipt. The only way to get this information is through tracing on the - // node. - let tracing_result = node - .trace_transaction( - execution_receipt, - GethDebugTracingOptions { - tracer: Some(GethDebugTracerType::BuiltInTracer( - GethDebugBuiltInTracerType::CallTracer, - )), - ..Default::default() - }, - )? - .try_into_call_frame() - .expect("Impossible - we requested a callframe trace so we must get it back"); - for expectation in expectations.iter() { self.handle_input_expectation_item( case_idx, execution_receipt, node, expectation, - &tracing_result, + tracing_result, mode, )?; } @@ -405,14 +465,24 @@ where } } - let deployed_contracts = self.deployed_contracts(case_idx); + let deployed_contracts = self + .deployed_contracts + .entry(case_idx) + .or_insert_with(|| self.deployed_libraries.clone()); + let variables = self.variables.entry(case_idx).or_default(); let chain_state_provider = node; // Handling the receipt state assertion. let expected = !expectation.exception; let actual = execution_receipt.status(); if actual != expected { - tracing::error!(expected, actual, "Transaction status assertion failed",); + tracing::error!( + expected, + actual, + ?execution_receipt, + ?tracing_result, + "Transaction status assertion failed" + ); anyhow::bail!( "Transaction status assertion failed - Expected {expected} but got {actual}", ); @@ -422,7 +492,12 @@ where if let Some(ref expected_calldata) = expectation.return_data { let expected = expected_calldata; let actual = &tracing_result.output.as_ref().unwrap_or_default(); - if !expected.is_equivalent(actual, deployed_contracts, chain_state_provider)? { + if !expected.is_equivalent( + actual, + deployed_contracts, + &*variables, + chain_state_provider, + )? { tracing::error!( ?execution_receipt, ?expected, @@ -451,17 +526,12 @@ where { // Handling the emitter assertion. if let Some(ref expected_address) = expected_event.address { - let expected = if let Some(contract_instance) = expected_address - .strip_suffix(".address") - .map(ContractInstance::new) - { - deployed_contracts - .get(&contract_instance) - .map(|(address, _)| *address) - } else { - Address::from_str(expected_address).ok() - } - .context("Failed to get the address of the event")?; + let expected = Address::from_slice( + Calldata::new_compound([expected_address]) + .calldata(deployed_contracts, &*variables, node)? + .get(12..32) + .expect("Can't fail"), + ); let actual = actual_event.address(); if actual != expected { tracing::error!( @@ -486,6 +556,7 @@ where if !expected.is_equivalent( &actual.0, deployed_contracts, + &*variables, chain_state_provider, )? { tracing::error!( @@ -503,7 +574,12 @@ where // Handling the values assertion. let expected = &expected_event.values; let actual = &actual_event.data().data; - if !expected.is_equivalent(&actual.0, deployed_contracts, chain_state_provider)? { + if !expected.is_equivalent( + &actual.0, + deployed_contracts, + &*variables, + chain_state_provider, + )? { tracing::error!( ?execution_receipt, ?expected, @@ -541,19 +617,6 @@ where Ok((execution_receipt, trace, diff)) } - fn deployed_contracts( - &mut self, - case_idx: impl Into>, - ) -> &mut HashMap { - match case_idx.into() { - Some(case_idx) => self - .deployed_contracts - .entry(case_idx) - .or_insert_with(|| self.deployed_libraries.clone()), - None => &mut self.deployed_libraries, - } - } - /// Gets the information of a deployed contract or library from the state. If it's found to not /// be deployed then it will be deployed. /// @@ -572,17 +635,16 @@ where ) -> anyhow::Result<(Address, JsonAbi, Option)> { let case_idx = case_idx.into(); - if let Some((address, abi)) = self.deployed_libraries.get(contract_instance) { - return Ok((*address, abi.clone(), None)); - } - if let Some(case_idx) = case_idx { - if let Some((address, abi)) = self + let deployed_contracts = match case_idx { + Some(case_idx) => self .deployed_contracts - .get(&case_idx) - .and_then(|contracts| contracts.get(contract_instance)) - { - return Ok((*address, abi.clone(), None)); - } + .entry(case_idx) + .or_insert_with(|| self.deployed_libraries.clone()), + None => &mut self.deployed_libraries, + }; + + if let Some((address, abi)) = deployed_contracts.get(contract_instance) { + return Ok((*address, abi.clone(), None)); } let Some(ContractPathAndIdent { @@ -628,7 +690,7 @@ where }; if let Some(calldata) = calldata { - let calldata = calldata.calldata(self.deployed_contracts(case_idx), node)?; + let calldata = calldata.calldata(deployed_contracts, None, node)?; code.extend(calldata); } @@ -663,8 +725,7 @@ where "Deployed contract" ); - self.deployed_contracts(case_idx) - .insert(contract_instance.clone(), (address, abi.clone())); + deployed_contracts.insert(contract_instance.clone(), (address, abi.clone())); Ok((address, abi, Some(receipt))) } diff --git a/crates/format/src/input.rs b/crates/format/src/input.rs index 7e71f41..8031553 100644 --- a/crates/format/src/input.rs +++ b/crates/format/src/input.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use alloy::{ eips::BlockNumberOrTag, + hex::ToHexExt, json_abi::JsonAbi, network::TransactionBuilder, primitives::{Address, Bytes, U256}, @@ -30,6 +31,7 @@ pub struct Input { pub expected: Option, pub value: Option, pub storage: Option>, + pub variable_assignments: Option, } #[derive(Clone, Debug, Deserialize, Eq, PartialEq)] @@ -136,6 +138,8 @@ enum Operation { BitwiseAnd, BitwiseOr, BitwiseXor, + ShiftLeft, + ShiftRight, } /// Specify how the contract is called. @@ -164,6 +168,14 @@ define_wrapper_type!( pub struct EtherValue(U256); ); +#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq)] +pub struct VariableAssignments { + /// A vector of the variable names to assign to the return data. + /// + /// Example: `UniswapV3PoolAddress` + pub return_data: Vec, +} + impl Input { pub const fn default_caller() -> Address { Address(FixedBytes(alloy::hex!( @@ -186,16 +198,17 @@ impl Input { .ok_or_else(|| anyhow::anyhow!("instance {instance:?} not deployed")) } - pub fn encoded_input( - &self, + pub fn encoded_input<'a>( + &'a self, deployed_contracts: &HashMap, + variables: impl Into>> + Clone, chain_state_provider: &impl ResolverApi, ) -> anyhow::Result { match self.method { Method::Deployer | Method::Fallback => { - let calldata = self - .calldata - .calldata(deployed_contracts, chain_state_provider)?; + let calldata = + self.calldata + .calldata(deployed_contracts, variables, chain_state_provider)?; Ok(calldata.into()) } @@ -244,6 +257,7 @@ impl Input { self.calldata.calldata_into_slice( &mut calldata, deployed_contracts, + variables, chain_state_provider, )?; @@ -253,12 +267,13 @@ impl Input { } /// Parse this input into a legacy transaction. - pub fn legacy_transaction( - &self, + pub fn legacy_transaction<'a>( + &'a self, deployed_contracts: &HashMap, + variables: impl Into>> + Clone, chain_state_provider: &impl ResolverApi, ) -> anyhow::Result { - let input_data = self.encoded_input(deployed_contracts, chain_state_provider)?; + let input_data = self.encoded_input(deployed_contracts, variables, chain_state_provider)?; let transaction_request = TransactionRequest::default().from(self.caller).value( self.value .map(|value| value.into_inner()) @@ -336,20 +351,27 @@ impl Calldata { } } - pub fn calldata( - &self, + pub fn calldata<'a>( + &'a self, deployed_contracts: &HashMap, + variables: impl Into>> + Clone, chain_state_provider: &impl ResolverApi, ) -> anyhow::Result> { let mut buffer = Vec::::with_capacity(self.size_requirement()); - self.calldata_into_slice(&mut buffer, deployed_contracts, chain_state_provider)?; + self.calldata_into_slice( + &mut buffer, + deployed_contracts, + variables, + chain_state_provider, + )?; Ok(buffer) } - pub fn calldata_into_slice( - &self, + pub fn calldata_into_slice<'a>( + &'a self, buffer: &mut Vec, deployed_contracts: &HashMap, + variables: impl Into>> + Clone, chain_state_provider: &impl ResolverApi, ) -> anyhow::Result<()> { match self { @@ -358,7 +380,7 @@ impl Calldata { } Calldata::Compound(items) => { for (arg_idx, arg) in items.iter().enumerate() { - match arg.resolve(deployed_contracts, chain_state_provider) { + match arg.resolve(deployed_contracts, variables.clone(), chain_state_provider) { Ok(resolved) => { buffer.extend(resolved.to_be_bytes::<32>()); } @@ -381,10 +403,11 @@ impl Calldata { } /// Checks if this [`Calldata`] is equivalent to the passed calldata bytes. - pub fn is_equivalent( - &self, + pub fn is_equivalent<'a>( + &'a self, other: &[u8], deployed_contracts: &HashMap, + variables: impl Into>> + Clone, chain_state_provider: &impl ResolverApi, ) -> anyhow::Result { match self { @@ -407,7 +430,8 @@ impl Calldata { std::borrow::Cow::Borrowed(other) }; - let this = this.resolve(deployed_contracts, chain_state_provider)?; + let this = + this.resolve(deployed_contracts, variables.clone(), chain_state_provider)?; let other = U256::from_be_slice(&other); if this != other { return Ok(false); @@ -420,16 +444,17 @@ impl Calldata { } impl CalldataItem { - fn resolve( - &self, + fn resolve<'a>( + &'a self, deployed_contracts: &HashMap, + variables: impl Into>> + Clone, chain_state_provider: &impl ResolverApi, ) -> anyhow::Result { let mut stack = Vec::>::new(); for token in self .calldata_tokens() - .map(|token| token.resolve(deployed_contracts, chain_state_provider)) + .map(|token| token.resolve(deployed_contracts, variables.clone(), chain_state_provider)) { let token = token?; let new_token = match token { @@ -452,8 +477,14 @@ impl CalldataItem { Operation::BitwiseAnd => Some(left_operand & right_operand), Operation::BitwiseOr => Some(left_operand | right_operand), Operation::BitwiseXor => Some(left_operand ^ right_operand), + Operation::ShiftLeft => { + Some(left_operand << usize::try_from(right_operand)?) + } + Operation::ShiftRight => { + Some(left_operand >> usize::try_from(right_operand)?) + } } - .context("Invalid calldata arithmetic operation")?; + .context("Invalid calldata arithmetic operation - Invalid operation")?; CalldataToken::Item(result) } @@ -464,8 +495,17 @@ impl CalldataItem { match stack.as_slice() { // Empty stack means that we got an empty compound calldata which we resolve to zero. [] => Ok(U256::ZERO), - [CalldataToken::Item(item)] => Ok(*item), - _ => Err(anyhow::anyhow!("Invalid calldata arithmetic operation")), + [CalldataToken::Item(item)] => { + tracing::debug!( + original = self.0, + resolved = item.to_be_bytes::<32>().encode_hex(), + "Resolved a Calldata item" + ); + Ok(*item) + } + _ => Err(anyhow::anyhow!( + "Invalid calldata arithmetic operation - Invalid stack" + )), } } @@ -478,6 +518,8 @@ impl CalldataItem { "&" => CalldataToken::Operation(Operation::BitwiseAnd), "|" => CalldataToken::Operation(Operation::BitwiseOr), "^" => CalldataToken::Operation(Operation::BitwiseXor), + "<<" => CalldataToken::Operation(Operation::ShiftLeft), + ">>" => CalldataToken::Operation(Operation::ShiftRight), _ => CalldataToken::Item(item), }) } @@ -494,6 +536,7 @@ impl CalldataToken { const BLOCK_HASH_VARIABLE_PREFIX: &str = "$BLOCK_HASH"; const BLOCK_NUMBER_VARIABLE: &str = "$BLOCK_NUMBER"; const BLOCK_TIMESTAMP_VARIABLE: &str = "$BLOCK_TIMESTAMP"; + const VARIABLE_PREFIX: &str = "$VARIABLE:"; fn into_item(self) -> Option { match self { @@ -512,9 +555,10 @@ impl> CalldataToken { /// This piece of code is taken from the matter-labs-tester repository which is licensed under /// MIT or Apache. The original source code can be found here: /// https://github.com/matter-labs/era-compiler-tester/blob/0ed598a27f6eceee7008deab3ff2311075a2ec69/compiler_tester/src/test/case/input/value.rs#L43-L146 - fn resolve( + fn resolve<'a>( self, deployed_contracts: &HashMap, + variables: impl Into>> + Clone, chain_state_provider: &impl ResolverApi, ) -> anyhow::Result> { match self { @@ -579,6 +623,16 @@ impl> CalldataToken { let timestamp = chain_state_provider.block_timestamp(BlockNumberOrTag::Latest)?; Ok(U256::from(timestamp)) + } else if let Some(variable_name) = item.strip_prefix(Self::VARIABLE_PREFIX) { + let Some(variables) = variables.into() else { + anyhow::bail!( + "Variable resolution required but no variables were passed in" + ); + }; + let Some(variable) = variables.get(variable_name) else { + anyhow::bail!("No variable found with the name {}", variable_name) + }; + Ok(*variable) } else { Ok(U256::from_str_radix(item, 10) .map_err(|error| anyhow::anyhow!("Invalid decimal literal: {}", error))?) @@ -699,7 +753,9 @@ mod tests { (Address::ZERO, parsed_abi), ); - let encoded = input.encoded_input(&contracts, &MockResolver).unwrap(); + let encoded = input + .encoded_input(&contracts, None, &MockResolver) + .unwrap(); assert!(encoded.0.starts_with(&selector)); type T = (u64,); @@ -741,7 +797,9 @@ mod tests { (Address::ZERO, parsed_abi), ); - let encoded = input.encoded_input(&contracts, &MockResolver).unwrap(); + let encoded = input + .encoded_input(&contracts, None, &MockResolver) + .unwrap(); assert!(encoded.0.starts_with(&selector)); type T = (alloy_primitives::Address,); @@ -786,7 +844,9 @@ mod tests { (Address::ZERO, parsed_abi), ); - let encoded = input.encoded_input(&contracts, &MockResolver).unwrap(); + let encoded = input + .encoded_input(&contracts, None, &MockResolver) + .unwrap(); assert!(encoded.0.starts_with(&selector)); type T = (alloy_primitives::Address,); @@ -802,7 +862,7 @@ mod tests { deployed_contracts: &HashMap, chain_state_provider: &impl ResolverApi, ) -> anyhow::Result { - CalldataItem::new(input).resolve(deployed_contracts, chain_state_provider) + CalldataItem::new(input).resolve(deployed_contracts, None, chain_state_provider) } #[test]