diff --git a/assets/test_metadata.json b/assets/test_metadata.json index 6d54584..7fca971 100644 --- a/assets/test_metadata.json +++ b/assets/test_metadata.json @@ -8,6 +8,10 @@ { "name": "first", "inputs": [ + { + "address": "0xdeadbeef00000000000000000000000000000042", + "expected_balance": "1233" + }, { "instance": "WBTC_1", "method": "#deployer", diff --git a/crates/common/src/iterators/files_with_extension_iterator.rs b/crates/common/src/iterators/files_with_extension_iterator.rs index 1892c5f..ac81103 100644 --- a/crates/common/src/iterators/files_with_extension_iterator.rs +++ b/crates/common/src/iterators/files_with_extension_iterator.rs @@ -1,4 +1,8 @@ -use std::{borrow::Cow, collections::HashSet, path::PathBuf}; +use std::{ + borrow::Cow, + collections::HashSet, + path::{Path, PathBuf}, +}; /// An iterator that finds files of a certain extension in the provided directory. You can think of /// this a glob pattern similar to: `${path}/**/*.md` @@ -18,10 +22,10 @@ pub struct FilesWithExtensionIterator { } impl FilesWithExtensionIterator { - pub fn new(root_directory: PathBuf) -> Self { + pub fn new(root_directory: impl AsRef) -> Self { Self { allowed_extensions: Default::default(), - directories_to_search: vec![root_directory], + directories_to_search: vec![root_directory.as_ref().to_path_buf()], files_matching_allowed_extensions: Default::default(), } } diff --git a/crates/core/src/driver/mod.rs b/crates/core/src/driver/mod.rs index c070d38..32b13c5 100644 --- a/crates/core/src/driver/mod.rs +++ b/crates/core/src/driver/mod.rs @@ -26,9 +26,11 @@ use revive_dt_format::traits::{ResolutionContext, ResolverApi}; use semver::Version; use revive_dt_format::case::{Case, CaseIdx}; -use revive_dt_format::input::{Calldata, EtherValue, Expected, ExpectedOutput, Method}; +use revive_dt_format::input::{ + BalanceAssertion, Calldata, EtherValue, Expected, ExpectedOutput, Input, Method, +}; use revive_dt_format::metadata::{ContractInstance, ContractPathAndIdent}; -use revive_dt_format::{input::Input, metadata::Metadata}; +use revive_dt_format::{input::Step, metadata::Metadata}; use revive_dt_node::Node; use revive_dt_node_interaction::EthereumNode; use tracing::Instrument; @@ -70,6 +72,27 @@ where } } + pub async fn handle_step( + &mut self, + metadata: &Metadata, + case_idx: CaseIdx, + step: &Step, + node: &T::Blockchain, + ) -> anyhow::Result { + match step { + Step::FunctionCall(input) => { + let (receipt, geth_trace, diff_mode) = + self.handle_input(metadata, case_idx, input, node).await?; + Ok(StepOutput::FunctionCall(receipt, geth_trace, diff_mode)) + } + Step::BalanceAssertion(balance_assertion) => { + self.handle_balance_assertion(metadata, case_idx, balance_assertion, node) + .await?; + Ok(StepOutput::BalanceAssertion) + } + } + } + pub async fn handle_input( &mut self, metadata: &Metadata, @@ -78,7 +101,7 @@ where node: &T::Blockchain, ) -> anyhow::Result<(TransactionReceipt, GethTrace, DiffMode)> { let deployment_receipts = self - .handle_contract_deployment(metadata, case_idx, input, node) + .handle_input_contract_deployment(metadata, case_idx, input, node) .await?; let execution_receipt = self .handle_input_execution(input, deployment_receipts, node) @@ -93,8 +116,21 @@ where .await } + pub async fn handle_balance_assertion( + &mut self, + metadata: &Metadata, + _: CaseIdx, + balance_assertion: &BalanceAssertion, + node: &T::Blockchain, + ) -> anyhow::Result<()> { + self.handle_balance_assertion_contract_deployment(metadata, balance_assertion, node) + .await?; + + Ok(()) + } + /// Handles the contract deployment for a given input performing it if it needs to be performed. - async fn handle_contract_deployment( + async fn handle_input_contract_deployment( &mut self, metadata: &Metadata, case_idx: CaseIdx, @@ -462,6 +498,65 @@ where Ok((execution_receipt, trace, diff)) } + pub async fn handle_balance_assertion_contract_deployment( + &mut self, + metadata: &Metadata, + balance_assertion: &BalanceAssertion, + node: &T::Blockchain, + ) -> anyhow::Result<()> { + let Some(instance) = balance_assertion + .address + .strip_prefix(".address") + .map(ContractInstance::new) + else { + return Ok(()); + }; + self.get_or_deploy_contract_instance( + &instance, + metadata, + Input::default_caller(), + None, + None, + node, + ) + .await?; + Ok(()) + } + + pub async fn handle_balance_assertion_execution( + &mut self, + BalanceAssertion { + address: address_string, + expected_balance: amount, + }: &BalanceAssertion, + node: &T::Blockchain, + ) -> anyhow::Result<()> { + let address = Address::from_slice( + Calldata::new_compound([address_string]) + .calldata(node, self.default_resolution_context()) + .await? + .get(12..32) + .expect("Can't fail"), + ); + + let balance = node.balance_of(address).await?; + + let expected = *amount; + let actual = balance; + if expected != actual { + tracing::error!(%expected, %actual, %address, "Balance assertion failed"); + anyhow::bail!( + "Balance assertion failed - Expected {} but got {} for {} resolved to {}", + expected, + actual, + address_string, + address, + ) + } + + Ok(()) + } + /// Gets the information of a deployed contract or library from the state. If it's found to not /// be deployed then it will be deployed. /// @@ -651,38 +746,53 @@ where return Ok(0); } - let mut inputs_executed = 0; - for (input_idx, input) in self.case.inputs_iterator().enumerate() { - let tracing_span = tracing::info_span!("Handling input", input_idx); + let mut steps_executed = 0; + for (step_idx, step) in self.case.steps_iterator().enumerate() { + let tracing_span = tracing::info_span!("Handling input", step_idx); - let (leader_receipt, _, leader_diff) = self + let leader_step_output = self .leader_state - .handle_input(self.metadata, self.case_idx, &input, self.leader_node) + .handle_step(self.metadata, self.case_idx, &step, self.leader_node) .instrument(tracing_span.clone()) .await?; - let (follower_receipt, _, follower_diff) = self + let follower_step_output = self .follower_state - .handle_input(self.metadata, self.case_idx, &input, self.follower_node) + .handle_step(self.metadata, self.case_idx, &step, self.follower_node) .instrument(tracing_span) .await?; + match (leader_step_output, follower_step_output) { + ( + StepOutput::FunctionCall(leader_receipt, _, leader_diff), + StepOutput::FunctionCall(follower_receipt, _, follower_diff), + ) => { + if leader_diff == follower_diff { + tracing::debug!("State diffs match between leader and follower."); + } else { + tracing::debug!("State diffs mismatch between leader and follower."); + Self::trace_diff_mode("Leader", &leader_diff); + Self::trace_diff_mode("Follower", &follower_diff); + } - if leader_diff == follower_diff { - tracing::debug!("State diffs match between leader and follower."); - } else { - tracing::debug!("State diffs mismatch between leader and follower."); - Self::trace_diff_mode("Leader", &leader_diff); - Self::trace_diff_mode("Follower", &follower_diff); + if leader_receipt.logs() != follower_receipt.logs() { + tracing::debug!("Log/event mismatch between leader and follower."); + tracing::trace!("Leader logs: {:?}", leader_receipt.logs()); + tracing::trace!("Follower logs: {:?}", follower_receipt.logs()); + } + } + (StepOutput::BalanceAssertion, StepOutput::BalanceAssertion) => {} + _ => unreachable!("The two step outputs can not be of a different kind"), } - if leader_receipt.logs() != follower_receipt.logs() { - tracing::debug!("Log/event mismatch between leader and follower."); - tracing::trace!("Leader logs: {:?}", leader_receipt.logs()); - tracing::trace!("Follower logs: {:?}", follower_receipt.logs()); - } - - inputs_executed += 1; + steps_executed += 1; } - Ok(inputs_executed) + Ok(steps_executed) } } + +#[derive(Clone, Debug)] +#[allow(clippy::large_enum_variant)] +pub enum StepOutput { + FunctionCall(TransactionReceipt, GethTrace, DiffMode), + BalanceAssertion, +} diff --git a/crates/core/src/main.rs b/crates/core/src/main.rs index 4221b8d..70512a5 100644 --- a/crates/core/src/main.rs +++ b/crates/core/src/main.rs @@ -32,7 +32,7 @@ use revive_dt_core::{ use revive_dt_format::{ case::{Case, CaseIdx}, corpus::Corpus, - input::Input, + input::{Input, Step}, metadata::{ContractInstance, ContractPathAndIdent, Metadata, MetadataFile}, mode::SolcMode, }; @@ -446,9 +446,12 @@ where // doing the deployments from different accounts and therefore we're not slowed down by // the nonce. let deployer_address = case - .inputs + .steps .iter() - .map(|input| input.caller) + .filter_map(|step| match step { + Step::FunctionCall(input) => Some(input.caller), + Step::BalanceAssertion(..) => None, + }) .next() .unwrap_or(Input::default_caller()); let leader_tx = TransactionBuilder::::with_deploy_code( diff --git a/crates/format/src/case.rs b/crates/format/src/case.rs index 59572f0..26e803e 100644 --- a/crates/format/src/case.rs +++ b/crates/format/src/case.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use revive_dt_common::macros::define_wrapper_type; use crate::{ - input::{Expected, Input}, + input::{Expected, Step}, mode::Mode, }; @@ -12,21 +12,27 @@ pub struct Case { pub name: Option, pub comment: Option, pub modes: Option>, - pub inputs: Vec, + #[serde(rename = "inputs")] + pub steps: Vec, pub group: Option, pub expected: Option, pub ignore: Option, } impl Case { - pub fn inputs_iterator(&self) -> impl Iterator { - let inputs_len = self.inputs.len(); - self.inputs + #[allow(irrefutable_let_patterns)] + pub fn steps_iterator(&self) -> impl Iterator { + let steps_len = self.steps.len(); + self.steps .clone() .into_iter() .enumerate() - .map(move |(idx, mut input)| { - if idx + 1 == inputs_len { + .map(move |(idx, mut step)| { + let Step::FunctionCall(ref mut input) = step else { + return step; + }; + + if idx + 1 == steps_len { if input.expected.is_none() { input.expected = self.expected.clone(); } @@ -36,9 +42,9 @@ impl Case { // the case? What are we supposed to do with that final expected field on the // case? - input + step } else { - input + step } }) } diff --git a/crates/format/src/input.rs b/crates/format/src/input.rs index 810ec31..cc96341 100644 --- a/crates/format/src/input.rs +++ b/crates/format/src/input.rs @@ -17,6 +17,19 @@ use revive_dt_common::macros::define_wrapper_type; use crate::traits::ResolverApi; use crate::{metadata::ContractInstance, traits::ResolutionContext}; +/// A test step. +/// +/// A test step can be anything. It could be an invocation to a function, an assertion, or any other +/// action that needs to be run or executed on the nodes used in the tests. +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[serde(untagged)] +pub enum Step { + /// A function call or an invocation to some function on some smart contract. + FunctionCall(Box), + /// A step for performing a balance assertion on some account or contract. + BalanceAssertion(Box), +} + #[derive(Clone, Debug, Default, Serialize, Deserialize, Eq, PartialEq)] pub struct Input { #[serde(default = "Input::default_caller")] @@ -33,6 +46,20 @@ pub struct Input { pub variable_assignments: Option, } +#[derive(Clone, Debug, Default, Serialize, Deserialize, Eq, PartialEq)] +pub struct BalanceAssertion { + /// The address that the balance assertion should be done on. + /// + /// This is a string which will be resolved into an address when being processed. Therefore, + /// this could be a normal hex address, a variable such as `Test.address`, or perhaps even a + /// full on variable like `$VARIABLE:Uniswap`. It follows the same resolution rules that are + /// followed in the calldata. + pub address: String, + + /// The amount of balance to assert that the account or contract has. + pub expected_balance: U256, +} + #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(untagged)] pub enum Expected { diff --git a/crates/node-interaction/src/lib.rs b/crates/node-interaction/src/lib.rs index 791ba4b..b052c7a 100644 --- a/crates/node-interaction/src/lib.rs +++ b/crates/node-interaction/src/lib.rs @@ -1,5 +1,6 @@ //! This crate implements all node interactions. +use alloy::primitives::{Address, U256}; use alloy::rpc::types::trace::geth::{DiffMode, GethDebugTracingOptions, GethTrace}; use alloy::rpc::types::{TransactionReceipt, TransactionRequest}; use anyhow::Result; @@ -21,4 +22,7 @@ pub trait EthereumNode { /// Returns the state diff of the transaction hash in the [TransactionReceipt]. fn state_diff(&self, receipt: &TransactionReceipt) -> impl Future>; + + /// Returns the balance of the provided [`Address`] back. + fn balance_of(&self, address: Address) -> impl Future>; } diff --git a/crates/node/src/geth.rs b/crates/node/src/geth.rs index 5ce7921..b3618ba 100644 --- a/crates/node/src/geth.rs +++ b/crates/node/src/geth.rs @@ -371,6 +371,15 @@ impl EthereumNode for GethNode { _ => anyhow::bail!("expected a diff mode trace"), } } + + #[tracing::instrument(skip_all, fields(geth_node_id = self.id))] + async fn balance_of(&self, address: Address) -> anyhow::Result { + self.provider() + .await? + .get_balance(address) + .await + .map_err(Into::into) + } } impl ResolverApi for GethNode { diff --git a/crates/node/src/kitchensink.rs b/crates/node/src/kitchensink.rs index 2300e05..805b71f 100644 --- a/crates/node/src/kitchensink.rs +++ b/crates/node/src/kitchensink.rs @@ -428,6 +428,15 @@ impl EthereumNode for KitchensinkNode { _ => anyhow::bail!("expected a diff mode trace"), } } + + #[tracing::instrument(skip_all, fields(kitchensink_node_id = self.id))] + async fn balance_of(&self, address: Address) -> anyhow::Result { + self.provider() + .await? + .get_balance(address) + .await + .map_err(Into::into) + } } impl ResolverApi for KitchensinkNode {