contracts: Allow ChainExtension::call() to access &mut self (#11874)

* Give chain extensions the ability to store some temporary values

* Update frame/contracts/src/wasm/runtime.rs

Co-authored-by: Hernando Castano <HCastano@users.noreply.github.com>

* Rename func_id -> id

* Replace `id` param by two functions on `env`

Co-authored-by: Hernando Castano <HCastano@users.noreply.github.com>
This commit is contained in:
Alexander Theißen
2022-07-25 17:48:01 +02:00
committed by GitHub
parent 626140454d
commit c470e9d11d
6 changed files with 254 additions and 56 deletions
+113 -37
View File
@@ -118,10 +118,17 @@ pub struct TestExtension {
last_seen_inputs: (u32, u32, u32, u32),
}
#[derive(Default)]
pub struct RevertingExtension;
#[derive(Default)]
pub struct DisabledExtension;
#[derive(Default)]
pub struct TempStorageExtension {
storage: u32,
}
impl TestExtension {
fn disable() {
TEST_EXTENSION.with(|e| e.borrow_mut().enabled = false)
@@ -143,18 +150,20 @@ impl Default for TestExtension {
}
impl ChainExtension<Test> for TestExtension {
fn call<E>(func_id: u32, env: Environment<E, InitState>) -> ExtensionResult<RetVal>
fn call<E>(&mut self, env: Environment<E, InitState>) -> ExtensionResult<RetVal>
where
E: Ext<T = Test>,
<E::T as SysConfig>::AccountId: UncheckedFrom<<E::T as SysConfig>::Hash> + AsRef<[u8]>,
{
let func_id = env.func_id();
let id = env.ext_id() as u32 | func_id as u32;
match func_id {
0 => {
let mut env = env.buf_in_buf_out();
let input = env.read(8)?;
env.write(&input, false, None)?;
TEST_EXTENSION.with(|e| e.borrow_mut().last_seen_buffer = input);
Ok(RetVal::Converging(func_id))
Ok(RetVal::Converging(id))
},
1 => {
let env = env.only_in();
@@ -162,17 +171,17 @@ impl ChainExtension<Test> for TestExtension {
e.borrow_mut().last_seen_inputs =
(env.val0(), env.val1(), env.val2(), env.val3())
});
Ok(RetVal::Converging(func_id))
Ok(RetVal::Converging(id))
},
2 => {
let mut env = env.buf_in_buf_out();
let weight = env.read(5)?[4].into();
env.charge_weight(weight)?;
Ok(RetVal::Converging(func_id))
Ok(RetVal::Converging(id))
},
3 => Ok(RetVal::Diverging { flags: ReturnFlags::REVERT, data: vec![42, 99] }),
_ => {
panic!("Passed unknown func_id to test chain extension: {}", func_id);
panic!("Passed unknown id to test chain extension: {}", func_id);
},
}
}
@@ -187,7 +196,7 @@ impl RegisteredChainExtension<Test> for TestExtension {
}
impl ChainExtension<Test> for RevertingExtension {
fn call<E>(_func_id: u32, _env: Environment<E, InitState>) -> ExtensionResult<RetVal>
fn call<E>(&mut self, _env: Environment<E, InitState>) -> ExtensionResult<RetVal>
where
E: Ext<T = Test>,
<E::T as SysConfig>::AccountId: UncheckedFrom<<E::T as SysConfig>::Hash> + AsRef<[u8]>,
@@ -205,7 +214,7 @@ impl RegisteredChainExtension<Test> for RevertingExtension {
}
impl ChainExtension<Test> for DisabledExtension {
fn call<E>(_func_id: u32, _env: Environment<E, InitState>) -> ExtensionResult<RetVal>
fn call<E>(&mut self, _env: Environment<E, InitState>) -> ExtensionResult<RetVal>
where
E: Ext<T = Test>,
<E::T as SysConfig>::AccountId: UncheckedFrom<<E::T as SysConfig>::Hash> + AsRef<[u8]>,
@@ -222,6 +231,37 @@ impl RegisteredChainExtension<Test> for DisabledExtension {
const ID: u16 = 2;
}
impl ChainExtension<Test> for TempStorageExtension {
fn call<E>(&mut self, env: Environment<E, InitState>) -> ExtensionResult<RetVal>
where
E: Ext<T = Test>,
<E::T as SysConfig>::AccountId: UncheckedFrom<<E::T as SysConfig>::Hash> + AsRef<[u8]>,
{
let func_id = env.func_id();
match func_id {
0 => self.storage = 42,
1 => assert_eq!(self.storage, 42, "Storage is preserved inside the same call."),
2 => {
assert_eq!(self.storage, 0, "Storage is different for different calls.");
self.storage = 99;
},
3 => assert_eq!(self.storage, 99, "Storage is preserved inside the same call."),
_ => {
panic!("Passed unknown id to test chain extension: {}", func_id);
},
}
Ok(RetVal::Converging(0))
}
fn enabled() -> bool {
TEST_EXTENSION.with(|e| e.borrow().enabled)
}
}
impl RegisteredChainExtension<Test> for TempStorageExtension {
const ID: u16 = 3;
}
parameter_types! {
pub BlockWeights: frame_system::limits::BlockWeights =
frame_system::limits::BlockWeights::simple_max(2 * WEIGHT_PER_SECOND);
@@ -325,7 +365,8 @@ impl Config for Test {
type CallStack = [Frame<Self>; 31];
type WeightPrice = Self;
type WeightInfo = ();
type ChainExtension = (TestExtension, DisabledExtension, RevertingExtension);
type ChainExtension =
(TestExtension, DisabledExtension, RevertingExtension, TempStorageExtension);
type DeletionQueueDepth = ConstU32<1024>;
type DeletionWeightLimit = ConstU64<500_000_000_000>;
type Schedule = MySchedule;
@@ -396,6 +437,29 @@ fn initialize_block(number: u64) {
System::initialize(&number, &[0u8; 32].into(), &Default::default());
}
struct ExtensionInput<'a> {
extension_id: u16,
func_id: u16,
extra: &'a [u8],
}
impl<'a> ExtensionInput<'a> {
fn to_vec(&self) -> Vec<u8> {
((self.extension_id as u32) << 16 | (self.func_id as u32))
.to_le_bytes()
.iter()
.chain(self.extra)
.cloned()
.collect()
}
}
impl<'a> From<ExtensionInput<'a>> for Vec<u8> {
fn from(input: ExtensionInput) -> Vec<u8> {
input.to_vec()
}
}
// Perform a call to a plain account.
// The actual transfer fails because we can only call contracts.
// Then we check that at least the base costs where charged (no runtime gas costs.)
@@ -1567,23 +1631,6 @@ fn disabled_chain_extension_errors_on_call() {
#[test]
fn chain_extension_works() {
struct Input<'a> {
extension_id: u16,
func_id: u16,
extra: &'a [u8],
}
impl<'a> From<Input<'a>> for Vec<u8> {
fn from(input: Input) -> Vec<u8> {
((input.extension_id as u32) << 16 | (input.func_id as u32))
.to_le_bytes()
.iter()
.chain(input.extra)
.cloned()
.collect()
}
}
let (code, hash) = compile_module::<Test>("chain_extension").unwrap();
ExtBuilder::default().existential_deposit(50).build().execute_with(|| {
let min_balance = <Test as Config>::Currency::minimum_balance();
@@ -1599,12 +1646,8 @@ fn chain_extension_works() {
),);
let addr = Contracts::contract_address(&ALICE, &hash, &[]);
// The contract takes a up to 2 byte buffer where the first byte passed is used as
// as func_id to the chain extension which behaves differently based on the
// func_id.
// 0 = read input buffer and pass it through as output
let input: Vec<u8> = Input { extension_id: 0, func_id: 0, extra: &[99] }.into();
let input: Vec<u8> = ExtensionInput { extension_id: 0, func_id: 0, extra: &[99] }.into();
let result =
Contracts::bare_call(ALICE, addr.clone(), 0, GAS_LIMIT, None, input.clone(), false);
assert_eq!(TestExtension::last_seen_buffer(), input);
@@ -1617,7 +1660,7 @@ fn chain_extension_works() {
0,
GAS_LIMIT,
None,
Input { extension_id: 0, func_id: 1, extra: &[] }.into(),
ExtensionInput { extension_id: 0, func_id: 1, extra: &[] }.into(),
false,
)
.result
@@ -1632,7 +1675,7 @@ fn chain_extension_works() {
0,
GAS_LIMIT,
None,
Input { extension_id: 0, func_id: 2, extra: &[0] }.into(),
ExtensionInput { extension_id: 0, func_id: 2, extra: &[0] }.into(),
false,
);
assert_ok!(result.result);
@@ -1643,7 +1686,7 @@ fn chain_extension_works() {
0,
GAS_LIMIT,
None,
Input { extension_id: 0, func_id: 2, extra: &[42] }.into(),
ExtensionInput { extension_id: 0, func_id: 2, extra: &[42] }.into(),
false,
);
assert_ok!(result.result);
@@ -1654,7 +1697,7 @@ fn chain_extension_works() {
0,
GAS_LIMIT,
None,
Input { extension_id: 0, func_id: 2, extra: &[95] }.into(),
ExtensionInput { extension_id: 0, func_id: 2, extra: &[95] }.into(),
false,
);
assert_ok!(result.result);
@@ -1667,7 +1710,7 @@ fn chain_extension_works() {
0,
GAS_LIMIT,
None,
Input { extension_id: 0, func_id: 3, extra: &[] }.into(),
ExtensionInput { extension_id: 0, func_id: 3, extra: &[] }.into(),
false,
)
.result
@@ -1684,7 +1727,7 @@ fn chain_extension_works() {
0,
GAS_LIMIT,
None,
Input { extension_id: 1, func_id: 0, extra: &[] }.into(),
ExtensionInput { extension_id: 1, func_id: 0, extra: &[] }.into(),
false,
)
.result
@@ -1701,13 +1744,46 @@ fn chain_extension_works() {
0,
GAS_LIMIT,
None,
Input { extension_id: 2, func_id: 0, extra: &[] }.into(),
ExtensionInput { extension_id: 2, func_id: 0, extra: &[] }.into(),
),
Error::<Test>::NoChainExtension,
);
});
}
#[test]
fn chain_extension_temp_storage_works() {
let (code, hash) = compile_module::<Test>("chain_extension_temp_storage").unwrap();
ExtBuilder::default().existential_deposit(50).build().execute_with(|| {
let min_balance = <Test as Config>::Currency::minimum_balance();
let _ = Balances::deposit_creating(&ALICE, 1000 * min_balance);
assert_ok!(Contracts::instantiate_with_code(
Origin::signed(ALICE),
min_balance * 100,
GAS_LIMIT,
None,
code,
vec![],
vec![],
),);
let addr = Contracts::contract_address(&ALICE, &hash, &[]);
// Call func 0 and func 1 back to back.
let stop_recursion = 0u8;
let mut input: Vec<u8> = ExtensionInput { extension_id: 3, func_id: 0, extra: &[] }.into();
input.extend_from_slice(
ExtensionInput { extension_id: 3, func_id: 1, extra: &[stop_recursion] }
.to_vec()
.as_ref(),
);
assert_ok!(
Contracts::bare_call(ALICE, addr.clone(), 0, GAS_LIMIT, None, input.clone(), false)
.result
);
})
}
#[test]
fn lazy_removal_works() {
let (code, hash) = compile_module::<Test>("self_destruct").unwrap();