diff --git a/crates/integration/contracts/DivisionArithmetics.sol b/crates/integration/contracts/DivisionArithmetics.sol new file mode 100644 index 0000000..972c6fd --- /dev/null +++ b/crates/integration/contracts/DivisionArithmetics.sol @@ -0,0 +1,25 @@ +contract DivisionArithmetics { + function div(uint n, uint d) public pure returns (uint q) { + assembly { + q := div(n, d) + } + } + + function sdiv(int n, int d) public pure returns (int q) { + assembly { + q := sdiv(n, d) + } + } + + function mod(uint n, uint d) public pure returns (uint r) { + assembly { + r := mod(n, d) + } + } + + function smod(int n, int d) public pure returns (int r) { + assembly { + r := smod(n, d) + } + } +} diff --git a/crates/integration/src/cases.rs b/crates/integration/src/cases.rs index 3611122..69b340b 100644 --- a/crates/integration/src/cases.rs +++ b/crates/integration/src/cases.rs @@ -1,4 +1,4 @@ -use alloy_primitives::U256; +use alloy_primitives::{I256, U256}; use alloy_sol_types::{sol, SolCall}; #[derive(Clone)] @@ -83,6 +83,18 @@ sol!( } ); +sol!( + contract DivisionArithmetics { + function div(uint n, uint d) public pure returns (uint q); + + function sdiv(int n, int d) public pure returns (int q); + + function mod(uint n, uint d) public pure returns (uint r); + + function smod(int n, int d) public pure returns (int r); + } +); + impl Contract { pub fn baseline() -> Self { let code = include_str!("../contracts/Baseline.sol"); @@ -226,6 +238,50 @@ impl Contract { calldata: Context::callerCall::new(()).abi_encode(), } } + + pub fn division_arithmetics_div(n: U256, d: U256) -> Self { + let code = include_str!("../contracts/DivisionArithmetics.sol"); + let name = "DivisionArithmetics"; + + Self { + evm_runtime: crate::compile_evm_bin_runtime(name, code), + pvm_runtime: crate::compile_blob(name, code), + calldata: DivisionArithmetics::divCall::new((n, d)).abi_encode(), + } + } + + pub fn division_arithmetics_sdiv(n: I256, d: I256) -> Self { + let code = include_str!("../contracts/DivisionArithmetics.sol"); + let name = "DivisionArithmetics"; + + Self { + evm_runtime: crate::compile_evm_bin_runtime(name, code), + pvm_runtime: crate::compile_blob(name, code), + calldata: DivisionArithmetics::sdivCall::new((n, d)).abi_encode(), + } + } + + pub fn division_arithmetics_mod(n: U256, d: U256) -> Self { + let code = include_str!("../contracts/DivisionArithmetics.sol"); + let name = "DivisionArithmetics"; + + Self { + evm_runtime: crate::compile_evm_bin_runtime(name, code), + pvm_runtime: crate::compile_blob(name, code), + calldata: DivisionArithmetics::modCall::new((n, d)).abi_encode(), + } + } + + pub fn division_arithmetics_smod(n: I256, d: I256) -> Self { + let code = include_str!("../contracts/DivisionArithmetics.sol"); + let name = "DivisionArithmetics"; + + Self { + evm_runtime: crate::compile_evm_bin_runtime(name, code), + pvm_runtime: crate::compile_blob(name, code), + calldata: DivisionArithmetics::smodCall::new((n, d)).abi_encode(), + } + } } #[cfg(test)] diff --git a/crates/integration/src/tests.rs b/crates/integration/src/tests.rs index 78d784c..f5042cc 100644 --- a/crates/integration/src/tests.rs +++ b/crates/integration/src/tests.rs @@ -296,3 +296,99 @@ fn caller() { let expected = Address::from(&mock_runtime::State::CALLER); assert_eq!(received, expected); } + +#[test] +fn unsigned_division() { + let one = U256::from(1); + let two = U256::from(2); + let five = U256::from(5); + for (n, d, q) in [ + (five, five, one), + (five, one, five), + (U256::ZERO, U256::MAX, U256::ZERO), + (five, two, two), + (one, U256::ZERO, U256::ZERO), + ] { + let state = assert_success(Contract::division_arithmetics_div(n, d), true); + let received = U256::from_be_bytes::<32>(state.output.data.try_into().unwrap()); + assert_eq!(received, q); + } +} + +#[test] +fn signed_division() { + let one = I256::try_from(1).unwrap(); + let two = I256::try_from(2).unwrap(); + let minus_two = I256::try_from(-2).unwrap(); + let five = I256::try_from(5).unwrap(); + let minus_five = I256::try_from(-5).unwrap(); + for (n, d, q) in [ + (five, five, one), + (five, one, five), + (I256::ZERO, I256::MAX, I256::ZERO), + (I256::ZERO, I256::MINUS_ONE, I256::ZERO), + (five, two, two), + (five, I256::MINUS_ONE, minus_five), + (I256::MINUS_ONE, minus_two, I256::ZERO), + (minus_five, minus_five, one), + (minus_five, two, minus_two), + (I256::MINUS_ONE, I256::MIN, I256::ZERO), + (one, I256::ZERO, I256::ZERO), + ] { + let state = assert_success(Contract::division_arithmetics_sdiv(n, d), true); + let received = I256::from_be_bytes::<32>(state.output.data.try_into().unwrap()); + assert_eq!(received, q); + } +} + +#[test] +fn unsigned_remainder() { + let one = U256::from(1); + let two = U256::from(2); + let five = U256::from(5); + for (n, d, q) in [ + (five, five, U256::ZERO), + (five, one, U256::ZERO), + (U256::ZERO, U256::MAX, U256::ZERO), + (U256::MAX, U256::MAX, U256::ZERO), + (five, two, one), + (two, five, two), + (U256::MAX, U256::ZERO, U256::ZERO), + ] { + let state = assert_success(Contract::division_arithmetics_mod(n, d), true); + let received = U256::from_be_bytes::<32>(state.output.data.try_into().unwrap()); + assert_eq!(received, q); + } +} + +#[test] +fn signed_remainder() { + let one = I256::try_from(1).unwrap(); + let two = I256::try_from(2).unwrap(); + let minus_two = I256::try_from(-2).unwrap(); + let five = I256::try_from(5).unwrap(); + let minus_five = I256::try_from(-5).unwrap(); + for (n, d, q) in [ + (five, five, I256::ZERO), + (five, one, I256::ZERO), + (I256::ZERO, I256::MAX, I256::ZERO), + (I256::MAX, I256::MAX, I256::ZERO), + (five, two, one), + (two, five, two), + (five, minus_five, I256::ZERO), + (five, I256::MINUS_ONE, I256::ZERO), + (five, minus_two, one), + (minus_five, two, I256::MINUS_ONE), + (minus_two, five, minus_two), + (minus_five, minus_five, I256::ZERO), + (minus_five, I256::MINUS_ONE, I256::ZERO), + (minus_five, minus_two, I256::MINUS_ONE), + (minus_two, minus_five, minus_two), + (I256::MIN, I256::MINUS_ONE, I256::ZERO), + (I256::ZERO, I256::ZERO, I256::ZERO), + ] { + let state = assert_success(Contract::division_arithmetics_smod(n, d), true); + let received = I256::from_be_bytes::<32>(state.output.data.try_into().unwrap()); + assert_eq!(received, q); + } +} diff --git a/crates/llvm-context/src/polkavm/context/function/llvm_runtime.rs b/crates/llvm-context/src/polkavm/context/function/llvm_runtime.rs index 523a92b..aecb76f 100644 --- a/crates/llvm-context/src/polkavm/context/function/llvm_runtime.rs +++ b/crates/llvm-context/src/polkavm/context/function/llvm_runtime.rs @@ -16,15 +16,6 @@ pub struct LLVMRuntime<'ctx> { /// The LLVM exception throwing function. pub cxa_throw: FunctionDeclaration<'ctx>, - /// The corresponding LLVM runtime function. - pub div: FunctionDeclaration<'ctx>, - /// The corresponding LLVM runtime function. - pub sdiv: FunctionDeclaration<'ctx>, - /// The corresponding LLVM runtime function. - pub r#mod: FunctionDeclaration<'ctx>, - /// The corresponding LLVM runtime function. - pub smod: FunctionDeclaration<'ctx>, - /// The corresponding LLVM runtime function. pub shl: FunctionDeclaration<'ctx>, /// The corresponding LLVM runtime function. @@ -82,18 +73,6 @@ impl<'ctx> LLVMRuntime<'ctx> { /// The LLVM exception throwing function name. pub const FUNCTION_CXA_THROW: &'static str = "__cxa_throw"; - /// The corresponding runtime function name. - pub const FUNCTION_DIV: &'static str = "__div"; - - /// The corresponding runtime function name. - pub const FUNCTION_SDIV: &'static str = "__sdiv"; - - /// The corresponding runtime function name. - pub const FUNCTION_MOD: &'static str = "__mod"; - - /// The corresponding runtime function name. - pub const FUNCTION_SMOD: &'static str = "__smod"; - /// The corresponding runtime function name. pub const FUNCTION_SHL: &'static str = "__shl"; @@ -184,82 +163,6 @@ impl<'ctx> LLVMRuntime<'ctx> { ); Function::set_cxa_throw_attributes(llvm, cxa_throw); - let div = Self::declare( - module, - Self::FUNCTION_DIV, - llvm.custom_width_int_type(revive_common::BIT_LENGTH_WORD as u32) - .fn_type( - vec![ - llvm.custom_width_int_type(revive_common::BIT_LENGTH_WORD as u32) - .as_basic_type_enum() - .into(); - 2 - ] - .as_slice(), - false, - ), - Some(inkwell::module::Linkage::External), - ); - Function::set_default_attributes(llvm, div, optimizer); - Function::set_pure_function_attributes(llvm, div); - - let r#mod = Self::declare( - module, - Self::FUNCTION_MOD, - llvm.custom_width_int_type(revive_common::BIT_LENGTH_WORD as u32) - .fn_type( - vec![ - llvm.custom_width_int_type(revive_common::BIT_LENGTH_WORD as u32) - .as_basic_type_enum() - .into(); - 2 - ] - .as_slice(), - false, - ), - Some(inkwell::module::Linkage::External), - ); - Function::set_default_attributes(llvm, r#mod, optimizer); - Function::set_pure_function_attributes(llvm, r#mod); - - let sdiv = Self::declare( - module, - Self::FUNCTION_SDIV, - llvm.custom_width_int_type(revive_common::BIT_LENGTH_WORD as u32) - .fn_type( - vec![ - llvm.custom_width_int_type(revive_common::BIT_LENGTH_WORD as u32) - .as_basic_type_enum() - .into(); - 2 - ] - .as_slice(), - false, - ), - Some(inkwell::module::Linkage::External), - ); - Function::set_default_attributes(llvm, sdiv, optimizer); - Function::set_pure_function_attributes(llvm, sdiv); - - let smod = Self::declare( - module, - Self::FUNCTION_SMOD, - llvm.custom_width_int_type(revive_common::BIT_LENGTH_WORD as u32) - .fn_type( - vec![ - llvm.custom_width_int_type(revive_common::BIT_LENGTH_WORD as u32) - .as_basic_type_enum() - .into(); - 2 - ] - .as_slice(), - false, - ), - Some(inkwell::module::Linkage::External), - ); - Function::set_default_attributes(llvm, smod, optimizer); - Function::set_pure_function_attributes(llvm, smod); - let shl = Self::declare( module, Self::FUNCTION_SHL, @@ -554,11 +457,6 @@ impl<'ctx> LLVMRuntime<'ctx> { personality, cxa_throw, - div, - sdiv, - r#mod, - smod, - shl, shr, sar, diff --git a/crates/llvm-context/src/polkavm/evm/arithmetic.rs b/crates/llvm-context/src/polkavm/evm/arithmetic.rs index 9990ec3..ed808bd 100644 --- a/crates/llvm-context/src/polkavm/evm/arithmetic.rs +++ b/crates/llvm-context/src/polkavm/evm/arithmetic.rs @@ -59,10 +59,11 @@ pub fn division<'ctx, D>( where D: Dependency + Clone, { - Ok(context - .builder() - .build_int_unsigned_div(operand_1, operand_2, "udiv")? - .into()) + WrappedDivision::new(context, operand_2)?.with(|| { + Ok(context + .builder() + .build_int_unsigned_div(operand_1, operand_2, "DIV")?) + }) } /// Translates the arithmetic remainder. @@ -74,16 +75,11 @@ pub fn remainder<'ctx, D>( where D: Dependency + Clone, { - Ok(context - .build_call( - context.llvm_runtime().r#mod, - &[ - operand_1.as_basic_value_enum(), - operand_2.as_basic_value_enum(), - ], - "add_mod_call", - ) - .expect("Always exists")) + WrappedDivision::new(context, operand_2)?.with(|| { + Ok(context + .builder() + .build_int_unsigned_rem(operand_1, operand_2, "MOD")?) + }) } /// Translates the signed arithmetic division. @@ -98,16 +94,59 @@ pub fn division_signed<'ctx, D>( where D: Dependency + Clone, { - Ok(context - .build_call( - context.llvm_runtime().sdiv, - &[ - operand_1.as_basic_value_enum(), - operand_2.as_basic_value_enum(), - ], - "add_mod_call", - ) - .expect("Always exists")) + WrappedDivision::new(context, operand_2)?.with(|| { + let block_no_overflow = context.append_basic_block("no_overflow"); + let block_operand_1_overflow = context.append_basic_block("operand_1_overflow"); + let block_select_quotient = context.append_basic_block("block_select_quotient"); + + let max_uint = context.builder().build_int_z_extend( + context + .integer_type(revive_common::BIT_LENGTH_WORD - 1) + .const_all_ones(), + context.word_type(), + "constant_zext_max_uint", + )?; + let is_operand_1_overflow = context.builder().build_int_compare( + inkwell::IntPredicate::EQ, + operand_1, + context.builder().build_int_neg(max_uint, "min_uint")?, + "is_operand_1_overflow", + )?; + context.build_conditional_branch( + is_operand_1_overflow, + block_operand_1_overflow, + block_no_overflow, + )?; + + context.set_basic_block(block_operand_1_overflow); + let is_operand_2_overflow = context.builder().build_int_compare( + inkwell::IntPredicate::EQ, + operand_2, + context.word_type().const_all_ones(), + "is_operand_2_overflow", + )?; + context.build_conditional_branch( + is_operand_2_overflow, + block_select_quotient, + block_no_overflow, + )?; + + context.set_basic_block(block_no_overflow); + let quotient = context + .builder() + .build_int_signed_div(operand_1, operand_2, "SDIV")?; + context.build_unconditional_branch(block_select_quotient); + + context.set_basic_block(block_select_quotient); + let phi_value = context + .builder() + .build_phi(context.word_type(), "phi_quotient")?; + phi_value.add_incoming(&[ + ("ient.as_basic_value_enum(), block_no_overflow), + (&operand_1, block_operand_1_overflow), + ]); + Ok(phi_value.as_basic_value().into_int_value()) + }) } /// Translates the signed arithmetic remainder. @@ -119,14 +158,73 @@ pub fn remainder_signed<'ctx, D>( where D: Dependency + Clone, { - Ok(context - .build_call( - context.llvm_runtime().smod, - &[ - operand_1.as_basic_value_enum(), - operand_2.as_basic_value_enum(), - ], - "add_mod_call", - ) - .expect("Always exists")) + WrappedDivision::new(context, operand_2)?.with(|| { + Ok(context + .builder() + .build_int_signed_rem(operand_1, operand_2, "SMOD")?) + }) +} + +/// Helper to wrap division operations so that zero will be returned +/// if the denominator is zero (see also Ethereum YP Appendix H.2). +struct WrappedDivision<'a, 'ctx, D: Dependency + Clone> { + context: &'a Context<'ctx, D>, + block_origin: inkwell::basic_block::BasicBlock<'ctx>, + block_calculate: inkwell::basic_block::BasicBlock<'ctx>, + block_select: inkwell::basic_block::BasicBlock<'ctx>, +} + +impl<'a, 'ctx, D: Dependency + Clone> WrappedDivision<'a, 'ctx, D> { + /// Create a new wrapped division (inserts a switch on the denominator). + fn new( + context: &'a Context<'ctx, D>, + denominator: inkwell::values::IntValue<'ctx>, + ) -> anyhow::Result { + assert_eq!( + denominator.get_type().get_bit_width(), + revive_common::BIT_LENGTH_WORD as u32 + ); + + let block_calculate = context.append_basic_block("calculate"); + let block_select = context.append_basic_block("select"); + context.builder().build_switch( + denominator, + block_calculate, + &[(context.word_const(0), block_select)], + )?; + + Ok(Self { + context, + block_origin: context.basic_block(), + block_calculate, + block_select, + }) + } + + /// Insert code to calculate the operation. + /// + /// The closure is expected to calculate and return the quotient. + /// + /// The returned value is either the calculated quotient or zero, selected at runtime. + fn with(self, f: F) -> anyhow::Result> + where + F: FnOnce() -> anyhow::Result, + T: inkwell::values::IntMathValue<'ctx>, + { + self.context.set_basic_block(self.block_calculate); + let calculated_value = f()?.as_basic_value_enum(); + let calculated_value_incoming_block = self.context.basic_block(); + self.context.build_unconditional_branch(self.block_select); + + self.context.set_basic_block(self.block_select); + let phi_value = self + .context + .builder() + .build_phi(self.context.word_type(), "phi_result")?; + phi_value.add_incoming(&[ + (&self.context.word_const(0), self.block_origin), + (&calculated_value, calculated_value_incoming_block), + ]); + Ok(phi_value.as_basic_value()) + } } diff --git a/crates/stdlib/stdlib.ll b/crates/stdlib/stdlib.ll index 745c969..8f0dc3f 100644 --- a/crates/stdlib/stdlib.ll +++ b/crates/stdlib/stdlib.ll @@ -249,42 +249,6 @@ entry: } -define i256 @__mod(i256 %arg1, i256 %arg2) #0 { -entry: - %is_divider_zero = icmp eq i256 %arg2, 0 - br i1 %is_divider_zero, label %return, label %remainder - -remainder: - %rem_res = urem i256 %arg1, %arg2 - br label %return - -return: - %res = phi i256 [ 0, %entry ], [ %rem_res, %remainder ] - ret i256 %res -} - -define i256 @__smod(i256 %arg1, i256 %arg2) #0 { -entry: - %is_divider_zero = icmp eq i256 %arg2, 0 - br i1 %is_divider_zero, label %return, label %division_overflow - -division_overflow: - %is_divided_int_min = icmp eq i256 %arg1, -57896044618658097711785492504343953926634992332820282019728792003956564819968 - %is_minus_one = icmp eq i256 %arg2, -1 - %is_overflow = and i1 %is_divided_int_min, %is_minus_one - br i1 %is_overflow, label %return, label %remainder - -remainder: - %rem_res = srem i256 %arg1, %arg2 - br label %return - -return: - %res = phi i256 [ 0, %entry ], [ 0, %division_overflow ], [ %rem_res, %remainder ] - ret i256 %res -} - - - attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn } attributes #1 = { argmemonly readonly nofree null_pointer_is_valid } attributes #2 = { argmemonly mustprogress nofree norecurse nosync nounwind willreturn null_pointer_is_valid }