6 Commits

Author SHA1 Message Date
Sergei Shulepov 310db5f92c Improve performance of stack height injection
Specifically,

- avoid using `import_count` and `function_space` since those are slow.
- Avoid using the `builder` module since it requires recreation of the module essentially and is not efficient.
- That also allows us to avoid creation of redundant signatures for thunks.
2022-10-19 17:56:29 +02:00
dependabot[bot] 4c1d47a618 Update wasmparser requirement from 0.90 to 0.92 (#33)
Updates the requirements on [wasmparser](https://github.com/bytecodealliance/wasm-tools) to permit the latest version.
- [Release notes](https://github.com/bytecodealliance/wasm-tools/releases)
- [Commits](https://github.com/bytecodealliance/wasm-tools/compare/wasmparser-0.90.0...wasmparser-0.92.0)

---
updated-dependencies:
- dependency-name: wasmparser
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-10-10 11:22:57 +02:00
dependabot[bot] 840af19d4b Update criterion requirement from 0.3 to 0.4 (#31)
Updates the requirements on [criterion](https://github.com/bheisler/criterion.rs) to permit the latest version.
- [Release notes](https://github.com/bheisler/criterion.rs/releases)
- [Changelog](https://github.com/bheisler/criterion.rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/bheisler/criterion.rs/compare/0.3.0...0.4.0)

---
updated-dependencies:
- dependency-name: criterion
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-09-14 12:00:09 +02:00
Alexander Theißen 0229f865b6 Use u64 for gas counter (#30)
* Use `u64` for gas counter

* Update doc
2022-09-11 14:36:06 +02:00
dependabot[bot] c2db4b8365 Update wasmparser requirement from 0.89 to 0.90 (#28)
Updates the requirements on [wasmparser](https://github.com/bytecodealliance/wasm-tools) to permit the latest version.
- [Release notes](https://github.com/bytecodealliance/wasm-tools/releases)
- [Commits](https://github.com/bytecodealliance/wasm-tools/compare/wasmparser-0.89.0...wasmparser-0.90.0)

---
updated-dependencies:
- dependency-name: wasmparser
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-08-24 08:59:46 +02:00
dependabot[bot] 1ade161da4 Update wasmparser requirement from 0.88 to 0.89 (#27)
Updates the requirements on [wasmparser](https://github.com/bytecodealliance/wasm-tools) to permit the latest version.
- [Release notes](https://github.com/bytecodealliance/wasm-tools/releases)
- [Commits](https://github.com/bytecodealliance/wasm-tools/compare/wasmparser-0.88.0...wasmparser-0.89.0)

---
updated-dependencies:
- dependency-name: wasmparser
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-08-17 14:08:30 +02:00
13 changed files with 487 additions and 509 deletions
+10
View File
@@ -16,7 +16,17 @@ The interface provided to smart contracts will adhere to semver with one excepti
major version bumps will be backwards compatible with regard to already deployed contracts. major version bumps will be backwards compatible with regard to already deployed contracts.
In other words: Upgrading this pallet will not break pre-existing contracts. In other words: Upgrading this pallet will not break pre-existing contracts.
## [v0.3.0]
### Changed
- Use 64bit arithmetic for per-block gas counter
[#30](https://github.com/paritytech/wasm-instrument/pull/30)
## [v0.2.0] 2022-06-06 ## [v0.2.0] 2022-06-06
### Changed
- Adjust debug information (if already parsed) when injecting gas metering - Adjust debug information (if already parsed) when injecting gas metering
[#16](https://github.com/paritytech/wasm-instrument/pull/16) [#16](https://github.com/paritytech/wasm-instrument/pull/16)
+4 -4
View File
@@ -1,6 +1,6 @@
[package] [package]
name = "wasm-instrument" name = "wasm-instrument"
version = "0.2.0" version = "0.3.0"
edition = "2021" edition = "2021"
rust-version = "1.56.1" rust-version = "1.56.1"
authors = ["Parity Technologies <admin@parity.io>"] authors = ["Parity Technologies <admin@parity.io>"]
@@ -22,15 +22,15 @@ codegen-units = 1
[dependencies] [dependencies]
parity-wasm = { version = "0.45", default-features = false } parity-wasm = { version = "0.45", default-features = false }
log = "0.4"
[dev-dependencies] [dev-dependencies]
binaryen = "0.12" binaryen = "0.12"
criterion = "0.3" criterion = "0.4"
diff = "0.1" diff = "0.1"
pretty_assertions = "1"
rand = "0.8" rand = "0.8"
wat = "1" wat = "1"
wasmparser = "0.88" wasmparser = "0.92"
wasmprinter = "0.2" wasmprinter = "0.2"
[features] [features]
+49 -38
View File
@@ -105,7 +105,7 @@ impl Rules for ConstantCostRules {
/// imported gas metering function. /// imported gas metering function.
/// ///
/// The output module imports a function "gas" from the specified module with type signature /// The output module imports a function "gas" from the specified module with type signature
/// [i32] -> []. The argument is the amount of gas required to continue execution. The external /// [i64] -> []. The argument is the amount of gas required to continue execution. The external
/// function is meant to keep track of the total amount of gas used and trap or otherwise halt /// function is meant to keep track of the total amount of gas used and trap or otherwise halt
/// execution of the runtime if the gas usage exceeds some allowed limit. /// execution of the runtime if the gas usage exceeds some allowed limit.
/// ///
@@ -144,7 +144,7 @@ pub fn inject<R: Rules>(
// Injecting gas counting external // Injecting gas counting external
let mut mbuilder = builder::from_module(module); let mut mbuilder = builder::from_module(module);
let import_sig = let import_sig =
mbuilder.push_signature(builder::signature().with_param(ValueType::I32).build_sig()); mbuilder.push_signature(builder::signature().with_param(ValueType::I64).build_sig());
mbuilder.push_import( mbuilder.push_import(
builder::import() builder::import()
@@ -284,7 +284,7 @@ struct MeteredBlock {
/// Index of the first instruction (aka `Opcode`) in the block. /// Index of the first instruction (aka `Opcode`) in the block.
start_pos: usize, start_pos: usize,
/// Sum of costs of all instructions until end of the block. /// Sum of costs of all instructions until end of the block.
cost: u32, cost: u64,
} }
/// Counter is used to manage state during the gas metering algorithm implemented by /// Counter is used to manage state during the gas metering algorithm implemented by
@@ -375,7 +375,8 @@ impl Counter {
.expect("last_index is greater than 0; last_index is stack size - 1; qed"); .expect("last_index is greater than 0; last_index is stack size - 1; qed");
let prev_metered_block = &mut prev_control_block.active_metered_block; let prev_metered_block = &mut prev_control_block.active_metered_block;
if closing_metered_block.start_pos == prev_metered_block.start_pos { if closing_metered_block.start_pos == prev_metered_block.start_pos {
prev_metered_block.cost += closing_metered_block.cost; prev_metered_block.cost =
prev_metered_block.cost.checked_add(closing_metered_block.cost).ok_or(())?;
return Ok(()) return Ok(())
} }
} }
@@ -425,7 +426,7 @@ impl Counter {
/// Increment the cost of the current block by the specified value. /// Increment the cost of the current block by the specified value.
fn increment(&mut self, val: u32) -> Result<(), ()> { fn increment(&mut self, val: u32) -> Result<(), ()> {
let top_block = self.active_metered_block()?; let top_block = self.active_metered_block()?;
top_block.cost = top_block.cost.checked_add(val).ok_or(())?; top_block.cost = top_block.cost.checked_add(val.into()).ok_or(())?;
Ok(()) Ok(())
} }
} }
@@ -465,8 +466,9 @@ fn add_grow_counter<R: Rules>(
.with_instructions(elements::Instructions::new(vec![ .with_instructions(elements::Instructions::new(vec![
GetLocal(0), GetLocal(0),
GetLocal(0), GetLocal(0),
I32Const(cost as i32), I64ExtendUI32,
I32Mul, I64Const(i64::from(cost)),
I64Mul,
// todo: there should be strong guarantee that it does not return anything on // todo: there should be strong guarantee that it does not return anything on
// stack? // stack?
Call(gas_func), Call(gas_func),
@@ -583,7 +585,7 @@ fn insert_metering_calls(
// If there the next block starts at this position, inject metering instructions. // If there the next block starts at this position, inject metering instructions.
let used_block = if let Some(block) = block_iter.peek() { let used_block = if let Some(block) = block_iter.peek() {
if block.start_pos == original_pos { if block.start_pos == original_pos {
new_instrs.push(I32Const(block.cost as i32)); new_instrs.push(I64Const(block.cost as i64));
new_instrs.push(Call(gas_func)); new_instrs.push(Call(gas_func));
true true
} else { } else {
@@ -612,6 +614,7 @@ fn insert_metering_calls(
mod tests { mod tests {
use super::*; use super::*;
use parity_wasm::{builder, elements, elements::Instruction::*, serialize}; use parity_wasm::{builder, elements, elements::Instruction::*, serialize};
use pretty_assertions::assert_eq;
fn get_function_body( fn get_function_body(
module: &elements::Module, module: &elements::Module,
@@ -639,12 +642,20 @@ mod tests {
assert_eq!( assert_eq!(
get_function_body(&injected_module, 0).unwrap(), get_function_body(&injected_module, 0).unwrap(),
&vec![I32Const(2), Call(0), GetGlobal(0), Call(2), End][..] &vec![I64Const(2), Call(0), GetGlobal(0), Call(2), End][..]
); );
assert_eq!( assert_eq!(
get_function_body(&injected_module, 1).unwrap(), get_function_body(&injected_module, 1).unwrap(),
&vec![GetLocal(0), GetLocal(0), I32Const(10000), I32Mul, Call(0), GrowMemory(0), End,] &vec![
[..] GetLocal(0),
GetLocal(0),
I64ExtendUI32,
I64Const(10000),
I64Mul,
Call(0),
GrowMemory(0),
End,
][..]
); );
let binary = serialize(injected_module).expect("serialization failed"); let binary = serialize(injected_module).expect("serialization failed");
@@ -667,7 +678,7 @@ mod tests {
assert_eq!( assert_eq!(
get_function_body(&injected_module, 0).unwrap(), get_function_body(&injected_module, 0).unwrap(),
&vec![I32Const(2), Call(0), GetGlobal(0), GrowMemory(0), End][..] &vec![I64Const(2), Call(0), GetGlobal(0), GrowMemory(0), End][..]
); );
assert_eq!(injected_module.functions_space(), 2); assert_eq!(injected_module.functions_space(), 2);
@@ -719,17 +730,17 @@ mod tests {
assert_eq!( assert_eq!(
get_function_body(&injected_module, 1).unwrap(), get_function_body(&injected_module, 1).unwrap(),
&vec![ &vec![
I32Const(3), I64Const(3),
Call(0), Call(0),
Call(1), Call(1),
If(elements::BlockType::NoResult), If(elements::BlockType::NoResult),
I32Const(3), I64Const(3),
Call(0), Call(0),
Call(1), Call(1),
Call(1), Call(1),
Call(1), Call(1),
Else, Else,
I32Const(2), I64Const(2),
Call(0), Call(0),
Call(1), Call(1),
Call(1), Call(1),
@@ -775,7 +786,7 @@ mod tests {
expected = r#" expected = r#"
(module (module
(func (result i32) (func (result i32)
(call 0 (i32.const 1)) (call 0 (i64.const 1))
(get_global 0))) (get_global 0)))
"# "#
} }
@@ -795,7 +806,7 @@ mod tests {
expected = r#" expected = r#"
(module (module
(func (result i32) (func (result i32)
(call 0 (i32.const 6)) (call 0 (i64.const 6))
(get_global 0) (get_global 0)
(block (block
(get_global 0) (get_global 0)
@@ -824,16 +835,16 @@ mod tests {
expected = r#" expected = r#"
(module (module
(func (result i32) (func (result i32)
(call 0 (i32.const 3)) (call 0 (i64.const 3))
(get_global 0) (get_global 0)
(if (if
(then (then
(call 0 (i32.const 3)) (call 0 (i64.const 3))
(get_global 0) (get_global 0)
(get_global 0) (get_global 0)
(get_global 0)) (get_global 0))
(else (else
(call 0 (i32.const 2)) (call 0 (i64.const 2))
(get_global 0) (get_global 0)
(get_global 0))) (get_global 0)))
(get_global 0))) (get_global 0)))
@@ -857,13 +868,13 @@ mod tests {
expected = r#" expected = r#"
(module (module
(func (result i32) (func (result i32)
(call 0 (i32.const 6)) (call 0 (i64.const 6))
(get_global 0) (get_global 0)
(block (block
(get_global 0) (get_global 0)
(drop) (drop)
(br 0) (br 0)
(call 0 (i32.const 2)) (call 0 (i64.const 2))
(get_global 0) (get_global 0)
(drop)) (drop))
(get_global 0))) (get_global 0)))
@@ -891,18 +902,18 @@ mod tests {
expected = r#" expected = r#"
(module (module
(func (result i32) (func (result i32)
(call 0 (i32.const 5)) (call 0 (i64.const 5))
(get_global 0) (get_global 0)
(block (block
(get_global 0) (get_global 0)
(if (if
(then (then
(call 0 (i32.const 4)) (call 0 (i64.const 4))
(get_global 0) (get_global 0)
(get_global 0) (get_global 0)
(drop) (drop)
(br_if 1))) (br_if 1)))
(call 0 (i32.const 2)) (call 0 (i64.const 2))
(get_global 0) (get_global 0)
(drop)) (drop))
(get_global 0))) (get_global 0)))
@@ -933,18 +944,18 @@ mod tests {
expected = r#" expected = r#"
(module (module
(func (result i32) (func (result i32)
(call 0 (i32.const 3)) (call 0 (i64.const 3))
(get_global 0) (get_global 0)
(loop (loop
(call 0 (i32.const 4)) (call 0 (i64.const 4))
(get_global 0) (get_global 0)
(if (if
(then (then
(call 0 (i32.const 2)) (call 0 (i64.const 2))
(get_global 0) (get_global 0)
(br_if 0)) (br_if 0))
(else (else
(call 0 (i32.const 4)) (call 0 (i64.const 4))
(get_global 0) (get_global 0)
(get_global 0) (get_global 0)
(drop) (drop)
@@ -969,13 +980,13 @@ mod tests {
expected = r#" expected = r#"
(module (module
(func (result i32) (func (result i32)
(call 0 (i32.const 2)) (call 0 (i64.const 2))
(get_global 0) (get_global 0)
(if (if
(then (then
(call 0 (i32.const 1)) (call 0 (i64.const 1))
(return))) (return)))
(call 0 (i32.const 1)) (call 0 (i64.const 1))
(get_global 0))) (get_global 0)))
"# "#
} }
@@ -998,18 +1009,18 @@ mod tests {
expected = r#" expected = r#"
(module (module
(func (result i32) (func (result i32)
(call 0 (i32.const 5)) (call 0 (i64.const 5))
(get_global 0) (get_global 0)
(block (block
(get_global 0) (get_global 0)
(if (if
(then (then
(call 0 (i32.const 1)) (call 0 (i64.const 1))
(br 1)) (br 1))
(else (else
(call 0 (i32.const 1)) (call 0 (i64.const 1))
(br 0))) (br 0)))
(call 0 (i32.const 2)) (call 0 (i64.const 2))
(get_global 0) (get_global 0)
(drop)) (drop))
(get_global 0))) (get_global 0)))
@@ -1031,9 +1042,9 @@ mod tests {
expected = r#" expected = r#"
(module (module
(func (func
(call 0 (i32.const 2)) (call 0 (i64.const 2))
(loop (loop
(call 0 (i32.const 1)) (call 0 (i64.const 1))
(br 0) (br 0)
) )
unreachable unreachable
+7 -7
View File
@@ -23,10 +23,10 @@ struct ControlFlowNode {
first_instr_pos: Option<usize>, first_instr_pos: Option<usize>,
/// The actual gas cost of executing all instructions in the basic block. /// The actual gas cost of executing all instructions in the basic block.
actual_cost: u32, actual_cost: u64,
/// The amount of gas charged by the injected metering instructions within this basic block. /// The amount of gas charged by the injected metering instructions within this basic block.
charged_cost: u32, charged_cost: u64,
/// Whether there are any other nodes in the graph that loop back to this one. Every cycle in /// Whether there are any other nodes in the graph that loop back to this one. Every cycle in
/// the control flow graph contains at least one node with this flag set. /// the control flow graph contains at least one node with this flag set.
@@ -68,10 +68,10 @@ impl ControlFlowGraph {
} }
fn increment_actual_cost(&mut self, node_id: NodeId, cost: u32) { fn increment_actual_cost(&mut self, node_id: NodeId, cost: u32) {
self.get_node_mut(node_id).actual_cost += cost; self.get_node_mut(node_id).actual_cost += u64::from(cost);
} }
fn increment_charged_cost(&mut self, node_id: NodeId, cost: u32) { fn increment_charged_cost(&mut self, node_id: NodeId, cost: u64) {
self.get_node_mut(node_id).charged_cost += cost; self.get_node_mut(node_id).charged_cost += cost;
} }
@@ -267,9 +267,9 @@ fn validate_graph_gas_costs(graph: &ControlFlowGraph) -> bool {
fn visit( fn visit(
graph: &ControlFlowGraph, graph: &ControlFlowGraph,
node_id: NodeId, node_id: NodeId,
mut total_actual: u32, mut total_actual: u64,
mut total_charged: u32, mut total_charged: u64,
loop_costs: &mut Map<NodeId, (u32, u32)>, loop_costs: &mut Map<NodeId, (u64, u64)>,
) -> bool { ) -> bool {
let node = graph.get_node(node_id); let node = graph.get_node(node_id);
+1 -3
View File
@@ -1,8 +1,6 @@
#![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(not(feature = "std"), no_std)]
extern crate alloc; extern crate alloc;
#[macro_use]
extern crate log;
mod export_globals; mod export_globals;
pub mod gas_metering; pub mod gas_metering;
@@ -10,4 +8,4 @@ mod stack_limiter;
pub use export_globals::export_mutable_globals; pub use export_globals::export_mutable_globals;
pub use parity_wasm; pub use parity_wasm;
pub use stack_limiter::{compute_stack_cost, inject as inject_stack_limiter}; pub use stack_limiter::inject as inject_stack_limiter;
+205 -306
View File
@@ -1,36 +1,23 @@
use super::resolve_func_type;
use alloc::vec::Vec; use alloc::vec::Vec;
use parity_wasm::elements::{self, BlockType, Type}; use parity_wasm::elements::{self, BlockType, Type};
#[cfg(feature = "sign_ext")] #[cfg(feature = "sign_ext")]
use parity_wasm::elements::SignExtInstruction; use parity_wasm::elements::SignExtInstruction;
use super::Context;
// The cost in stack items that should be charged per call of a function. This is // The cost in stack items that should be charged per call of a function. This is
// is a static cost that is added to each function call. This makes sense because even // is a static cost that is added to each function call. This makes sense because even
// if a function does not use any parameters or locals some stack space on the host // if a function does not use any parameters or locals some stack space on the host
// machine might be consumed to hold some context. // machine might be consumed to hold some context.
const ACTIVATION_FRAME_COST: u32 = 1; const ACTIVATION_FRAME_COST: u32 = 2;
#[derive(Debug,PartialEq,Default,Clone,Copy)]
pub struct StackHeightStats {
pub activation_cost: u32,
pub max_height: u32,
pub max_control_height: u32,
pub locals_count: u32,
pub params_count: u32,
pub blocks_count: u32,
pub condbr_count: u32,
pub push_count: u32,
pub local_set_count: u32,
pub opcode_count: u32,
pub total_cost: u32,
}
/// Control stack frame. /// Control stack frame.
#[derive(Debug)] #[derive(Debug)]
struct Frame { struct Frame {
/// Counts the nesting level of unreachable code. 0 if currently processed code is reachable /// Stack becomes polymorphic only after an instruction that
unreachable_depth: u32, /// never passes control further was executed.
is_polymorphic: bool,
/// Count of values which will be pushed after the exit /// Count of values which will be pushed after the exit
/// from the current block. /// from the current block.
@@ -56,7 +43,7 @@ struct Stack {
impl Stack { impl Stack {
fn new() -> Stack { fn new() -> Stack {
Stack { height: 0, control_stack: Vec::new() } Stack { height: ACTIVATION_FRAME_COST, control_stack: Vec::new() }
} }
/// Returns current height of the value stack. /// Returns current height of the value stack.
@@ -64,10 +51,6 @@ impl Stack {
self.height self.height
} }
fn control_height(&self) -> u32 {
self.control_stack.len() as u32
}
/// Returns a reference to a frame by specified depth relative to the top of /// Returns a reference to a frame by specified depth relative to the top of
/// control stack. /// control stack.
fn frame(&self, rel_depth: u32) -> Result<&Frame, &'static str> { fn frame(&self, rel_depth: u32) -> Result<&Frame, &'static str> {
@@ -78,27 +61,14 @@ impl Stack {
} }
/// Mark successive instructions as unreachable. /// Mark successive instructions as unreachable.
///
/// This effectively makes stack polymorphic.
fn mark_unreachable(&mut self) -> Result<(), &'static str> { fn mark_unreachable(&mut self) -> Result<(), &'static str> {
let top_frame = self.control_stack.last_mut().ok_or("control stack must be non-empty")?; let top_frame = self.control_stack.last_mut().ok_or("stack must be non-empty")?;
top_frame.unreachable_depth = 1; top_frame.is_polymorphic = true;
Ok(()) Ok(())
} }
/// Increase nesting level of unreachable code
fn push_unreachable(&mut self) -> Result<(), &'static str> {
let top_frame = self.control_stack.last_mut().ok_or("control stack must be non-empty")?;
top_frame.unreachable_depth += 1;
Ok(())
}
/// Decrease nesting level of unrechable code (probably making it reachable)
fn pop_unreachable(&mut self) -> Result<u32, &'static str> {
let top_frame = self.control_stack.last_mut().ok_or("control stack must be non-empty")?;
top_frame.unreachable_depth =
top_frame.unreachable_depth.checked_sub(1).ok_or("unreachable code underflow")?;
Ok(top_frame.unreachable_depth)
}
/// Push control frame into the control stack. /// Push control frame into the control stack.
fn push_frame(&mut self, frame: Frame) { fn push_frame(&mut self, frame: Frame) {
self.control_stack.push(frame); self.control_stack.push(frame);
@@ -132,6 +102,19 @@ impl Stack {
if value_count == 0 { if value_count == 0 {
return Ok(()) return Ok(())
} }
{
let top_frame = self.frame(0)?;
if self.height == top_frame.start_height {
// It is an error to pop more values than was pushed in the current frame
// (ie pop values pushed in the parent frame), unless the frame became
// polymorphic.
return if top_frame.is_polymorphic {
Ok(())
} else {
return Err("trying to pop more values than pushed")
}
}
}
self.height = self.height.checked_sub(value_count).ok_or("stack underflow")?; self.height = self.height.checked_sub(value_count).ok_or("stack underflow")?;
@@ -140,11 +123,13 @@ impl Stack {
} }
/// This function expects the function to be validated. /// This function expects the function to be validated.
pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightStats, &'static str> { pub fn compute(
func_idx: u32,
ctx: &Context,
module: &elements::Module,
) -> Result<u32, &'static str> {
use parity_wasm::elements::Instruction::*; use parity_wasm::elements::Instruction::*;
trace!("Processing function index {}", func_idx);
let func_section = module.function_section().ok_or("No function section")?; let func_section = module.function_section().ok_or("No function section")?;
let code_section = module.code_section().ok_or("No code section")?; let code_section = module.code_section().ok_or("No code section")?;
let type_section = module.type_section().ok_or("No type section")?; let type_section = module.type_section().ok_or("No type section")?;
@@ -167,66 +152,31 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
let mut stack = Stack::new(); let mut stack = Stack::new();
let mut max_height: u32 = 0; let mut max_height: u32 = 0;
let mut max_control_height: u32 = 0; let mut pc = 0;
let mut locals_count: u32 = 0;
for local_group in body.locals() {
locals_count =
locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?;
}
let params_count = func_signature.params().len() as u32;
let mut blocks_count = 0u32;
let mut condbr_count = 0u32;
let mut push_count = 0u32;
let mut local_set_count = 0u32;
// Add implicit frame for the function. Breaks to this frame and execution of // Add implicit frame for the function. Breaks to this frame and execution of
// the last end should deal with this frame. // the last end should deal with this frame.
let func_arity = func_signature.results().len() as u32; let func_arity = func_signature.results().len() as u32;
stack.push_frame(Frame { stack.push_frame(Frame {
unreachable_depth: 0, is_polymorphic: false,
end_arity: func_arity, end_arity: func_arity,
branch_arity: func_arity, branch_arity: func_arity,
start_height: 0, start_height: 0,
}); });
for opcode in instructions.elements() { loop {
if stack.frame(0)?.unreachable_depth > 0 { if pc >= instructions.elements().len() {
match opcode { break
Block(_) | Loop(_) | If(_) => {
trace!("Entering unreachable block {:?}", opcode);
stack.push_unreachable()?;
continue
},
Else => {
let depth = stack.pop_unreachable()?;
if depth == 0 {
trace!("Transiting from unreachable If body to reachable Else block");
} else {
trace!("Processing unreachable Else");
stack.push_unreachable()?;
continue
}
},
End => {
let depth = stack.pop_unreachable()?;
if depth == 0 {
trace!("Exiting unreachable code");
} else {
trace!("Exiting unreachable block");
continue
}
},
_ => {
trace!("Skipping unreachable instruction {:?}", opcode);
continue
},
}
} }
assert_eq!(stack.frame(0)?.unreachable_depth, 0); // If current value stack is higher than maximal height observed so far,
trace!("Processing opcode {:?}", opcode); // save the new height.
// However, we don't increase maximal value in unreachable code.
if stack.height() > max_height && !stack.frame(0)?.is_polymorphic {
max_height = stack.height();
}
let opcode = &instructions.elements()[pc];
match opcode { match opcode {
Nop => {}, Nop => {},
@@ -236,18 +186,17 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
if let If(_) = *opcode { if let If(_) = *opcode {
stack.pop_values(1)?; stack.pop_values(1)?;
} }
let height = stack.height();
stack.push_frame(Frame { stack.push_frame(Frame {
unreachable_depth: 0, is_polymorphic: false,
end_arity, end_arity,
branch_arity, branch_arity,
start_height: stack.height(), start_height: height,
}); });
blocks_count += 1;
}, },
Else => { Else => {
let frame = stack.pop_frame()?; // The frame at the top should be pushed by `If`. So we leave
stack.trunc(frame.start_height); // it as is.
stack.push_frame(frame);
}, },
End => { End => {
let frame = stack.pop_frame()?; let frame = stack.pop_frame()?;
@@ -276,12 +225,18 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
// Push values back. // Push values back.
stack.push_values(target_arity)?; stack.push_values(target_arity)?;
condbr_count += 1;
}, },
BrTable(br_table_data) => { BrTable(br_table_data) => {
let arity_of_default = stack.frame(br_table_data.default)?.branch_arity; let arity_of_default = stack.frame(br_table_data.default)?.branch_arity;
// Check that all jump targets have an equal arities.
for target in &*br_table_data.table {
let arity = stack.frame(*target)?.branch_arity;
if arity != arity_of_default {
return Err("Arity of all jump-targets must be equal")
}
}
// Because all jump targets have an equal arities, we can just take arity of // Because all jump targets have an equal arities, we can just take arity of
// the default branch. // the default branch.
stack.pop_values(arity_of_default)?; stack.pop_values(arity_of_default)?;
@@ -289,8 +244,6 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
// This instruction doesn't let control flow to go further, since the control flow // This instruction doesn't let control flow to go further, since the control flow
// should take either one of branches depending on the value or the default branch. // should take either one of branches depending on the value or the default branch.
stack.mark_unreachable()?; stack.mark_unreachable()?;
condbr_count += 1;
}, },
Return => { Return => {
// Pop return values of the function. Mark successive instructions as unreachable // Pop return values of the function. Mark successive instructions as unreachable
@@ -298,8 +251,10 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
stack.pop_values(func_arity)?; stack.pop_values(func_arity)?;
stack.mark_unreachable()?; stack.mark_unreachable()?;
}, },
Call(idx) => { Call(fn_idx) => {
let ty = resolve_func_type(*idx, module)?; let ty_idx = ctx.func_type(*fn_idx).ok_or("function idx is not found in the func types list")?;
let Type::Function(ty) =
type_section.types().get(ty_idx as usize).ok_or("Type not found")?;
// Pop values for arguments of the function. // Pop values for arguments of the function.
stack.pop_values(ty.params().len() as u32)?; stack.pop_values(ty.params().len() as u32)?;
@@ -332,26 +287,21 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
// Push the selected value. // Push the selected value.
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
GetLocal(_) => { GetLocal(_) => {
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
SetLocal(_) => { SetLocal(_) => {
stack.pop_values(1)?; stack.pop_values(1)?;
local_set_count += 1;
}, },
TeeLocal(_) => { TeeLocal(_) => {
// This instruction pops and pushes the value, so // This instruction pops and pushes the value, so
// effectively it doesn't modify the stack height. // effectively it doesn't modify the stack height.
stack.pop_values(1)?; stack.pop_values(1)?;
stack.push_values(1)?; stack.push_values(1)?;
local_set_count += 1;
}, },
GetGlobal(_) => { GetGlobal(_) => {
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
SetGlobal(_) => { SetGlobal(_) => {
stack.pop_values(1)?; stack.pop_values(1)?;
@@ -374,7 +324,6 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
// which effictively don't modify the stack height. // which effictively don't modify the stack height.
stack.pop_values(1)?; stack.pop_values(1)?;
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
I32Store(_, _) | I32Store(_, _) |
@@ -393,19 +342,16 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
CurrentMemory(_) => { CurrentMemory(_) => {
// Pushes current memory size // Pushes current memory size
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
GrowMemory(_) => { GrowMemory(_) => {
// Grow memory takes the value of pages to grow and pushes // Grow memory takes the value of pages to grow and pushes
stack.pop_values(1)?; stack.pop_values(1)?;
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
I32Const(_) | I64Const(_) | F32Const(_) | F64Const(_) => { I32Const(_) | I64Const(_) | F32Const(_) | F64Const(_) => {
// These instructions just push the single literal value onto the stack. // These instructions just push the single literal value onto the stack.
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
I32Eqz | I64Eqz => { I32Eqz | I64Eqz => {
@@ -413,7 +359,6 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
// the result of the comparison. // the result of the comparison.
stack.pop_values(1)?; stack.pop_values(1)?;
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
I32Eq | I32Ne | I32LtS | I32LtU | I32GtS | I32GtU | I32LeS | I32LeU | I32GeS | I32Eq | I32Ne | I32LtS | I32LtU | I32GtS | I32GtU | I32LeS | I32LeU | I32GeS |
@@ -423,7 +368,6 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
// Comparison operations take two operands and produce one result. // Comparison operations take two operands and produce one result.
stack.pop_values(2)?; stack.pop_values(2)?;
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
I32Clz | I32Ctz | I32Popcnt | I64Clz | I64Ctz | I64Popcnt | F32Abs | F32Neg | I32Clz | I32Ctz | I32Popcnt | I64Clz | I64Ctz | I64Popcnt | F32Abs | F32Neg |
@@ -432,7 +376,6 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
// Unary operators take one operand and produce one result. // Unary operators take one operand and produce one result.
stack.pop_values(1)?; stack.pop_values(1)?;
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
I32Add | I32Sub | I32Mul | I32DivS | I32DivU | I32RemS | I32RemU | I32And | I32Or | I32Add | I32Sub | I32Mul | I32DivS | I32DivU | I32RemS | I32RemU | I32And | I32Or |
@@ -444,7 +387,6 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
// Binary operators take two operands and produce one result. // Binary operators take two operands and produce one result.
stack.pop_values(2)?; stack.pop_values(2)?;
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
I32WrapI64 | I32TruncSF32 | I32TruncUF32 | I32TruncSF64 | I32TruncUF64 | I32WrapI64 | I32TruncSF32 | I32TruncUF32 | I32TruncSF64 | I32TruncUF64 |
@@ -456,7 +398,6 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
// Conversion operators take one value and produce one result. // Conversion operators take one value and produce one result.
stack.pop_values(1)?; stack.pop_values(1)?;
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
#[cfg(feature = "sign_ext")] #[cfg(feature = "sign_ext")]
@@ -467,213 +408,171 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<StackHeightSt
SignExt(SignExtInstruction::I64Extend32S) => { SignExt(SignExtInstruction::I64Extend32S) => {
stack.pop_values(1)?; stack.pop_values(1)?;
stack.push_values(1)?; stack.push_values(1)?;
push_count += 1;
}, },
} }
pc += 1;
// If current value/control stack is higher than maximal height observed so far,
// save the new height.
if stack.height() > max_height {
max_height = stack.height();
}
if stack.control_height() > max_control_height {
max_control_height = stack.control_height();
}
trace!(
" Stack height: {}, control stack height: {}",
stack.height(),
stack.control_height()
);
} }
assert_eq!(stack.control_height(), 0); Ok(max_height)
assert_eq!(stack.height(), func_signature.results().len() as u32);
let res = StackHeightStats {
activation_cost: ACTIVATION_FRAME_COST,
max_height: max_height,
max_control_height: max_control_height,
locals_count: locals_count,
params_count: params_count,
blocks_count: blocks_count,
condbr_count: condbr_count,
push_count: push_count,
local_set_count: local_set_count,
opcode_count: instructions.elements().len() as u32,
// total_cost: (11.749 * params_count as f64 - 0.4888 * locals_count as f64 + 14.8169 * max_height as f64 - 5.1594 * max_control_height as f64 - 24.4941) as u32
total_cost: ACTIVATION_FRAME_COST + 2 * max_height + max_control_height + locals_count + 2 * params_count,
};
trace!("Result: {:?}", res);
Ok(res)
} }
// #[cfg(test)] #[cfg(test)]
// mod tests { mod tests {
// use super::*; use super::ACTIVATION_FRAME_COST;
// use parity_wasm::elements; use crate::stack_limiter::prepare_context;
use parity_wasm::elements;
// fn parse_wat(source: &str) -> elements::Module { fn parse_wat(source: &str) -> elements::Module {
// elements::deserialize_buffer(&wat::parse_str(source).expect("Failed to wat2wasm")) elements::deserialize_buffer(&wat::parse_str(source).expect("Failed to wat2wasm"))
// .expect("Failed to deserialize the module") .expect("Failed to deserialize the module")
// } }
// #[test] fn test_compute(func_idx: u32, source: &str) -> u32 {
// fn simple_test() { let module = parse_wat(source);
// let module = parse_wat( let ctx = prepare_context(&module, 0).unwrap();
// r#" ctx.stack_cost(func_idx).unwrap()
// (module }
// (func
// i32.const 1
// i32.const 2
// i32.const 3
// drop
// drop
// drop
// )
// )
// "#,
// );
// let height = compute(0, &module).unwrap(); #[test]
// assert_eq!(height, ACTIVATION_FRAME_COST + 3 + 1 + 0 + 0); fn simple_test() {
// } let module = r#"
(module
(func
i32.const 1
i32.const 2
i32.const 3
drop
drop
drop
)
)
"#;
// #[test] let height = test_compute(0, module);
// fn implicit_and_explicit_return() { assert_eq!(height, 3 + ACTIVATION_FRAME_COST);
// let module = parse_wat( }
// r#"
// (module
// (func (result i32)
// i32.const 0
// return
// )
// )
// "#,
// );
// let height = compute(0, &module).unwrap(); #[test]
// assert_eq!(height, ACTIVATION_FRAME_COST + 1 + 1 + 0 + 0); fn implicit_and_explicit_return() {
// } let module = r#"
(module
(func (result i32)
i32.const 0
return
)
)
"#;
// #[test] let height = test_compute(0, module);
// fn dont_count_in_unreachable() { assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
// let module = parse_wat( }
// r#"
// (module
// (memory 0)
// (func (result i32)
// unreachable
// grow_memory
// )
// )
// "#,
// );
// let height = compute(0, &module).unwrap(); #[test]
// assert_eq!(height, ACTIVATION_FRAME_COST + 0 + 1 + 0 + 0); fn dont_count_in_unreachable() {
// } let module = r#"
(module
(memory 0)
(func (result i32)
unreachable
grow_memory
)
)
"#;
// #[test] let height = test_compute(0, module);
// fn yet_another_test() { assert_eq!(height, ACTIVATION_FRAME_COST);
// let module = parse_wat( }
// r#"
// (module
// (memory 0)
// (func
// ;; Push two values and then pop them.
// ;; This will make max depth to be equal to 2.
// i32.const 0
// i32.const 1
// drop
// drop
// ;; Code after `unreachable` shouldn't have an effect #[test]
// ;; on the max depth. fn yet_another_test() {
// unreachable let module = r#"
// i32.const 0 (module
// i32.const 1 (memory 0)
// i32.const 2 (func
// ) ;; Push two values and then pop them.
// ) ;; This will make max depth to be equal to 2.
// "#, i32.const 0
// ); i32.const 1
drop
drop
// let height = compute(0, &module).unwrap(); ;; Code after `unreachable` shouldn't have an effect
// assert_eq!(height, 2 + ACTIVATION_FRAME_COST); ;; on the max depth.
// } unreachable
i32.const 0
i32.const 1
i32.const 2
)
)
"#;
// #[test] let height = test_compute(0, module);
// fn call_indirect() { assert_eq!(height, 2 + ACTIVATION_FRAME_COST);
// let module = parse_wat( }
// r#"
// (module
// (table $ptr 1 1 funcref)
// (elem $ptr (i32.const 0) func 1)
// (func $main
// (call_indirect (i32.const 0))
// (call_indirect (i32.const 0))
// (call_indirect (i32.const 0))
// )
// (func $callee
// i64.const 42
// drop
// )
// )
// "#,
// );
// let height = compute(0, &module).unwrap(); #[test]
// assert_eq!(height, 1 + ACTIVATION_FRAME_COST); fn call_indirect() {
// } let module = r#"
(module
(table $ptr 1 1 funcref)
(elem $ptr (i32.const 0) func 1)
(func $main
(call_indirect (i32.const 0))
(call_indirect (i32.const 0))
(call_indirect (i32.const 0))
)
(func $callee
i64.const 42
drop
)
)
"#;
// #[test] let height = test_compute(0, module);
// fn breaks() { assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
// let module = parse_wat( }
// r#"
// (module
// (func $main
// block (result i32)
// block (result i32)
// i32.const 99
// br 1
// end
// end
// drop
// )
// )
// "#,
// );
// let height = compute(0, &module).unwrap(); #[test]
// assert_eq!(height, 1 + ACTIVATION_FRAME_COST); fn breaks() {
// } let module = r#"
(module
(func $main
block (result i32)
block (result i32)
i32.const 99
br 1
end
end
drop
)
)
"#;
// #[test] let height = test_compute(0, module);
// fn if_else_works() { assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
// let module = parse_wat( }
// r#"
// (module
// (func $main
// i32.const 7
// i32.const 1
// if (result i32)
// i32.const 42
// else
// i32.const 99
// end
// i32.const 97
// drop
// drop
// drop
// )
// )
// "#,
// );
// let height = compute(0, &module).unwrap(); #[test]
// assert_eq!(height, 3 + ACTIVATION_FRAME_COST); fn if_else_works() {
// } let module = r#"
// } (module
(func $main
i32.const 7
i32.const 1
if (result i32)
i32.const 42
else
i32.const 99
end
i32.const 97
drop
drop
drop
)
)
"#;
let height = test_compute(0, module);
assert_eq!(height, 3 + ACTIVATION_FRAME_COST);
}
}
+136 -91
View File
@@ -6,7 +6,6 @@ use parity_wasm::{
builder, builder,
elements::{self, Instruction, Instructions, Type}, elements::{self, Instruction, Instructions, Type},
}; };
pub use max_height::StackHeightStats;
/// Macro to generate preamble and postamble. /// Macro to generate preamble and postamble.
macro_rules! instrument_call { macro_rules! instrument_call {
@@ -40,8 +39,16 @@ mod max_height;
mod thunk; mod thunk;
pub struct Context { pub struct Context {
/// Number of functions that the module imports. Required to convert defined functions indicies
/// into the global function index space.
func_imports: u32,
/// For each function in the function space this vector stores the respective type index.
func_types: Vec<u32>,
/// The index of the global variable that contains the current stack height.
stack_height_global_idx: u32, stack_height_global_idx: u32,
func_stack_costs: Vec<StackHeightStats>, /// Logical stack costs for each function in the function space. Imported functions have cost
/// of 0.
func_stack_costs: Vec<u32>,
stack_limit: u32, stack_limit: u32,
} }
@@ -53,11 +60,12 @@ impl Context {
/// Returns `stack_cost` for `func_idx`. /// Returns `stack_cost` for `func_idx`.
fn stack_cost(&self, func_idx: u32) -> Option<u32> { fn stack_cost(&self, func_idx: u32) -> Option<u32> {
if let Some(stats) = self.func_stack_costs.get(func_idx as usize) { self.func_stack_costs.get(func_idx as usize).cloned()
Some(stats.total_cost) }
} else {
None /// Returns a reference to the function type index given by the index into the function space.
} fn func_type(&self, func_idx: u32) -> Option<u32> {
self.func_types.get(func_idx as usize).copied()
} }
/// Returns stack limit specified by the rules. /// Returns stack limit specified by the rules.
@@ -120,20 +128,106 @@ pub fn inject(
mut module: elements::Module, mut module: elements::Module,
stack_limit: u32, stack_limit: u32,
) -> Result<elements::Module, &'static str> { ) -> Result<elements::Module, &'static str> {
let mut ctx = Context { let mut ctx = prepare_context(&module, stack_limit)?;
stack_height_global_idx: generate_stack_height_global(&mut module),
func_stack_costs: compute_stack_costs(&module)?,
stack_limit,
};
instrument_functions(&mut ctx, &mut module)?; generate_stack_height_global(&mut ctx.stack_height_global_idx, &mut module)?;
instrument_functions(&ctx, &mut module)?;
let module = thunk::generate_thunks(&mut ctx, module)?; let module = thunk::generate_thunks(&mut ctx, module)?;
Ok(module) Ok(module)
} }
fn prepare_context(module: &elements::Module, stack_limit: u32) -> Result<Context, &'static str> {
let mut ctx = Context {
func_imports: module.import_count(elements::ImportCountType::Function) as u32,
func_types: Vec::new(),
stack_height_global_idx: 0,
func_stack_costs: Vec::new(),
stack_limit,
};
collect_func_types(&mut ctx, &module)?;
compute_stack_costs(&mut ctx, &module)?;
Ok(ctx)
}
fn collect_func_types(ctx: &mut Context, module: &elements::Module) -> Result<(), &'static str> {
let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
let imports = module.import_section().map(|is| is.entries()).unwrap_or(&[]);
let ensure_ty = |sig_idx: u32| -> Result<(), &'static str> {
let Type::Function(_) = types
.get(sig_idx as usize)
.ok_or("The signature as specified by a function isn't defined")?;
Ok(())
};
for import in imports {
if let elements::External::Function(sig_idx) = import.external() {
ensure_ty(*sig_idx)?;
ctx.func_types.push(*sig_idx);
}
}
for def_func_idx in functions {
ensure_ty(def_func_idx.type_ref())?;
ctx.func_types.push(def_func_idx.type_ref());
}
Ok(())
}
/// Calculate stack costs for all functions in the function space.
///
/// The function space consists of the imported functions followed by defined functions.
/// All imported functions assumed to have the cost of 0.
fn compute_stack_costs(ctx: &mut Context, module: &elements::Module) -> Result<(), &'static str> {
for _ in 0..ctx.func_imports {
ctx.func_stack_costs.push(0);
}
let def_func_n = module.function_section().map(|fs| fs.entries().len()).unwrap_or(0) as u32;
for def_func_idx in 0..def_func_n {
let cost = compute_stack_cost(def_func_idx, ctx, module)?;
ctx.func_stack_costs.push(cost);
}
Ok(())
}
/// Computes the stack cost of a given function. The function is specified by its index in the
/// declared function space.
///
/// Stack cost of a given function is the sum of it's locals count (that is,
/// number of arguments plus number of local variables) and the maximal stack
/// height.
fn compute_stack_cost(
def_func_idx: u32,
ctx: &Context,
module: &elements::Module,
) -> Result<u32, &'static str> {
let code_section =
module.code_section().ok_or("Due to validation code section should exists")?;
let body = &code_section
.bodies()
.get(def_func_idx as usize)
.ok_or("Function body is out of bounds")?;
let mut locals_count: u32 = 0;
for local_group in body.locals() {
locals_count =
locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?;
}
let max_stack_height = max_height::compute(def_func_idx, ctx, module)?;
locals_count
.checked_add(max_stack_height)
.ok_or("Overflow in adding locals_count and max_stack_height")
}
/// Generate a new global that will be used for tracking current stack height. /// Generate a new global that will be used for tracking current stack height.
fn generate_stack_height_global(module: &mut elements::Module) -> u32 { fn generate_stack_height_global(
stack_height_global_idx: &mut u32,
module: &mut elements::Module,
) -> Result<(), &'static str> {
let global_entry = builder::global() let global_entry = builder::global()
.value_type() .value_type()
.i32() .i32()
@@ -145,54 +239,41 @@ fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
for section in module.sections_mut() { for section in module.sections_mut() {
if let elements::Section::Global(gs) = section { if let elements::Section::Global(gs) = section {
gs.entries_mut().push(global_entry); gs.entries_mut().push(global_entry);
return (gs.entries().len() as u32) - 1 *stack_height_global_idx = (gs.entries().len() as u32) - 1;
return Ok(());
} }
} }
// Existing section not found, create one! // Existing section not found, create one!
module //
.sections_mut() // It's a bit tricky since the sections have a strict prescribed order.
.push(elements::Section::Global(elements::GlobalSection::with_entries(vec![global_entry]))); let global_section = elements::GlobalSection::with_entries(vec![global_entry]);
0 let prec_index = module
} .sections()
.iter()
/// Calculate stack costs for all functions. .rposition(|section| {
/// use elements::Section::*;
/// Returns a vector with a stack cost for each function, including imports. match section {
fn compute_stack_costs(module: &elements::Module) -> Result<Vec<StackHeightStats>, &'static str> { Type(_) | Import(_) | Function(_) | Table(_) | Memory(_) => true,
let func_imports = module.import_count(elements::ImportCountType::Function); _ => false,
// TODO: optimize!
(0..module.functions_space())
.map(|func_idx| {
if func_idx < func_imports {
// We can't calculate stack_cost of the import functions.
Ok(Default::default())
} else {
compute_stack_cost(func_idx as u32, module)
} }
}) })
.collect() .ok_or("generate stack height global hasn't found any preceding sections")?;
// now `prec_index` points to the last section preceding the `global_section`. It's guaranteed that at least
// one of those functions is present. Therefore, the candidate position for the global section is the following
// one. However, technically, custom sections could occupy any place between the well-known sections.
//
// Now, regarding `+1` here. `insert` panics iff `index > len`. `prec_index + 1` can only be equal to `len`.
module
.sections_mut()
.insert(prec_index + 1, elements::Section::Global(global_section));
// First entry in the brand new globals section.
*stack_height_global_idx = 0;
Ok(())
} }
/// Stack cost of the given *defined* function is the sum of it's locals count (that is, fn instrument_functions(ctx: &Context, module: &mut elements::Module) -> Result<(), &'static str> {
/// number of arguments plus number of local variables) and the maximal stack
/// height.
pub fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<StackHeightStats, &'static str> {
// To calculate the cost of a function we need to convert index from
// function index space to defined function spaces.
let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
let defined_func_idx = func_idx
.checked_sub(func_imports)
.ok_or("This should be a index of a defined function")?;
max_height::compute(defined_func_idx, module)
}
fn instrument_functions(
ctx: &mut Context,
module: &mut elements::Module,
) -> Result<(), &'static str> {
for section in module.sections_mut() { for section in module.sections_mut() {
if let elements::Section::Code(code_section) = section { if let elements::Section::Code(code_section) = section {
for func_body in code_section.bodies_mut() { for func_body in code_section.bodies_mut() {
@@ -230,7 +311,7 @@ fn instrument_functions(
/// ///
/// drop /// drop
/// ``` /// ```
fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), &'static str> { fn instrument_function(ctx: &Context, func: &mut Instructions) -> Result<(), &'static str> {
use Instruction::*; use Instruction::*;
struct InstrumentCall { struct InstrumentCall {
@@ -297,42 +378,6 @@ fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(),
Ok(()) Ok(())
} }
fn resolve_func_type(
func_idx: u32,
module: &elements::Module,
) -> Result<&elements::FunctionType, &'static str> {
let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
let func_imports = module.import_count(elements::ImportCountType::Function);
let sig_idx = if func_idx < func_imports as u32 {
module
.import_section()
.expect("function import count is not zero; import section must exists; qed")
.entries()
.iter()
.filter_map(|entry| match entry.external() {
elements::External::Function(idx) => Some(*idx),
_ => None,
})
.nth(func_idx as usize)
.expect(
"func_idx is less than function imports count;
nth function import must be `Some`;
qed",
)
} else {
functions
.get(func_idx as usize - func_imports)
.ok_or("Function at the specified index is not defined")?
.type_ref()
};
let Type::Function(ty) = types
.get(sig_idx as usize)
.ok_or("The signature as specified by a function isn't defined")?;
Ok(ty)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
+58 -43
View File
@@ -1,17 +1,17 @@
#[cfg(not(features = "std"))] #[cfg(not(features = "std"))]
use alloc::collections::BTreeMap as Map; use alloc::collections::BTreeMap as Map;
use alloc::vec::Vec; use alloc::vec::Vec;
use parity_wasm::{ use parity_wasm::elements::{self, Internal};
builder,
elements::{self, FunctionType, Internal},
};
#[cfg(features = "std")] #[cfg(features = "std")]
use std::collections::HashMap as Map; use std::collections::HashMap as Map;
use super::{resolve_func_type, Context}; use super::Context;
struct Thunk { struct Thunk {
signature: FunctionType, /// The index of the signature in the type section.
type_idx: u32,
/// The number of parameters the function has.
param_num: u32,
// Index in function space of this thunk. // Index in function space of this thunk.
idx: Option<u32>, idx: Option<u32>,
callee_stack_cost: u32, callee_stack_cost: u32,
@@ -19,10 +19,11 @@ struct Thunk {
pub fn generate_thunks( pub fn generate_thunks(
ctx: &mut Context, ctx: &mut Context,
module: elements::Module, mut module: elements::Module,
) -> Result<elements::Module, &'static str> { ) -> Result<elements::Module, &'static str> {
// First, we need to collect all function indices that should be replaced by thunks // First, we need to collect all function indices that should be replaced by thunks
let mut replacement_map: Map<u32, Thunk> = { let mut replacement_map: Map<u32, Thunk> = {
let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
let exports = module.export_section().map(|es| es.entries()).unwrap_or(&[]); let exports = module.export_section().map(|es| es.entries()).unwrap_or(&[]);
let elem_segments = module.elements_section().map(|es| es.entries()).unwrap_or(&[]); let elem_segments = module.elements_section().map(|es| es.entries()).unwrap_or(&[]);
let start_func_idx = module.start_section(); let start_func_idx = module.start_section();
@@ -45,14 +46,12 @@ pub fn generate_thunks(
// Don't generate a thunk if stack_cost of a callee is zero. // Don't generate a thunk if stack_cost of a callee is zero.
if callee_stack_cost != 0 { if callee_stack_cost != 0 {
replacement_map.insert( let type_idx = ctx.func_type(func_idx).ok_or("type idx for thunk not found")?;
func_idx, let elements::Type::Function(func_ty) =
Thunk { types.get(type_idx as usize).ok_or("sig for thunk is not found")?;
signature: resolve_func_type(func_idx, &module)?.clone(), let param_num = func_ty.params().len() as u32;
idx: None, replacement_map
callee_stack_cost, .insert(func_idx, Thunk { type_idx, param_num, idx: None, callee_stack_cost });
},
);
} }
} }
@@ -61,10 +60,6 @@ pub fn generate_thunks(
// Then, we generate a thunk for each original function. // Then, we generate a thunk for each original function.
// Save current func_idx
let mut next_func_idx = module.functions_space() as u32;
let mut mbuilder = builder::from_module(module);
for (func_idx, thunk) in replacement_map.iter_mut() { for (func_idx, thunk) in replacement_map.iter_mut() {
let instrumented_call = instrument_call!( let instrumented_call = instrument_call!(
*func_idx, *func_idx,
@@ -77,32 +72,23 @@ pub fn generate_thunks(
// - instrumented call // - instrumented call
// - end // - end
let mut thunk_body: Vec<elements::Instruction> = let mut thunk_body: Vec<elements::Instruction> =
Vec::with_capacity(thunk.signature.params().len() + instrumented_call.len() + 1); Vec::with_capacity(thunk.param_num as usize + instrumented_call.len() + 1);
for (arg_idx, _) in thunk.signature.params().iter().enumerate() { for arg_idx in 0..thunk.param_num {
thunk_body.push(elements::Instruction::GetLocal(arg_idx as u32)); thunk_body.push(elements::Instruction::GetLocal(arg_idx));
} }
thunk_body.extend_from_slice(&instrumented_call); thunk_body.extend_from_slice(&instrumented_call);
thunk_body.push(elements::Instruction::End); thunk_body.push(elements::Instruction::End);
// TODO: Don't generate a signature, but find an existing one. let func_idx = insert_function(
ctx,
mbuilder = mbuilder &mut module,
.function() thunk.type_idx,
// Signature of the thunk should match the original function signature. Vec::new(), // No declared local variables.
.signature() elements::Instructions::new(thunk_body),
.with_params(thunk.signature.params().to_vec()) )?;
.with_results(thunk.signature.results().to_vec()) thunk.idx = Some(func_idx);
.build()
.body()
.with_instructions(elements::Instructions::new(thunk_body))
.build()
.build();
thunk.idx = Some(next_func_idx);
next_func_idx += 1;
} }
let mut module = mbuilder.build();
// And finally, fixup thunks in export and table sections. // And finally, fixup thunks in export and table sections.
@@ -118,18 +104,20 @@ pub fn generate_thunks(
for section in module.sections_mut() { for section in module.sections_mut() {
match section { match section {
elements::Section::Export(export_section) => elements::Section::Export(export_section) => {
for entry in export_section.entries_mut() { for entry in export_section.entries_mut() {
if let Internal::Function(function_idx) = entry.internal_mut() { if let Internal::Function(function_idx) = entry.internal_mut() {
fixup(function_idx) fixup(function_idx)
} }
}, }
elements::Section::Element(elem_section) => },
elements::Section::Element(elem_section) => {
for segment in elem_section.entries_mut() { for segment in elem_section.entries_mut() {
for function_idx in segment.members_mut() { for function_idx in segment.members_mut() {
fixup(function_idx) fixup(function_idx)
} }
}, }
},
elements::Section::Start(start_idx) => fixup(start_idx), elements::Section::Start(start_idx) => fixup(start_idx),
_ => {}, _ => {},
} }
@@ -137,3 +125,30 @@ pub fn generate_thunks(
Ok(module) Ok(module)
} }
/// Inserts a new function into the module and returns it's index in the function space.
///
/// Specifically, inserts entires into the function section and the code section.
fn insert_function(
ctx: &Context,
module: &mut elements::Module,
type_idx: u32,
locals: Vec<elements::Local>,
insns: elements::Instructions,
) -> Result<u32, &'static str> {
let funcs = module
.function_section_mut()
.ok_or("insert function no function section")?
.entries_mut();
let new_func_idx = ctx
.func_imports
.checked_add(funcs.len() as u32)
.ok_or("insert function func idx overflow")?;
funcs.push(elements::Func::new(type_idx));
let func_bodies =
module.code_section_mut().ok_or("insert function no code section")?.bodies_mut();
func_bodies.push(elements::FuncBody::new(locals, insns));
Ok(new_func_idx)
}
+3 -3
View File
@@ -1,10 +1,10 @@
(module (module
(type (;0;) (func (result i32))) (type (;0;) (func (result i32)))
(type (;1;) (func (param i32))) (type (;1;) (func (param i64)))
(import "env" "gas" (func (;0;) (type 1))) (import "env" "gas" (func (;0;) (type 1)))
(func $fibonacci_with_break (;1;) (type 0) (result i32) (func $fibonacci_with_break (;1;) (type 0) (result i32)
(local i32 i32) (local i32 i32)
i32.const 13 i64.const 13
call 0 call 0
block ;; label = @1 block ;; label = @1
i32.const 0 i32.const 0
@@ -18,7 +18,7 @@
local.set 1 local.set 1
i32.const 1 i32.const 1
br_if 0 (;@1;) br_if 0 (;@1;)
i32.const 5 i64.const 5
call 0 call 0
local.get 0 local.get 0
local.get 1 local.get 1
+3 -3
View File
@@ -1,10 +1,10 @@
(module (module
(type (;0;) (func (param i32 i32) (result i32))) (type (;0;) (func (param i32 i32) (result i32)))
(type (;1;) (func (param i32))) (type (;1;) (func (param i64)))
(import "env" "gas" (func (;0;) (type 1))) (import "env" "gas" (func (;0;) (type 1)))
(func $add_locals (;1;) (type 0) (param $x i32) (param $y i32) (result i32) (func $add_locals (;1;) (type 0) (param $x i32) (param $y i32) (result i32)
(local i32) (local i32)
i32.const 5 i64.const 5
call 0 call 0
local.get $x local.get $x
local.get $y local.get $y
@@ -13,7 +13,7 @@
local.get 2 local.get 2
) )
(func $add (;2;) (type 0) (param i32 i32) (result i32) (func $add (;2;) (type 0) (param i32 i32) (result i32)
i32.const 3 i64.const 3
call 0 call 0
local.get 0 local.get 0
local.get 1 local.get 1
+4 -4
View File
@@ -1,19 +1,19 @@
(module (module
(type (;0;) (func (param i32) (result i32))) (type (;0;) (func (param i32) (result i32)))
(type (;1;) (func (param i32))) (type (;1;) (func (param i64)))
(import "env" "gas" (func (;0;) (type 1))) (import "env" "gas" (func (;0;) (type 1)))
(func (;1;) (type 0) (param i32) (result i32) (func (;1;) (type 0) (param i32) (result i32)
i32.const 2 i64.const 2
call 0 call 0
i32.const 1 i32.const 1
if (result i32) ;; label = @1 if (result i32) ;; label = @1
i32.const 3 i64.const 3
call 0 call 0
local.get 0 local.get 0
i32.const 1 i32.const 1
i32.add i32.add
else else
i32.const 2 i64.const 2
call 0 call 0
local.get 0 local.get 0
i32.popcnt i32.popcnt
+5 -5
View File
@@ -1,16 +1,16 @@
(module (module
(type (;0;) (func)) (type (;0;) (func))
(type (;1;) (func (param i32))) (type (;1;) (func (param i64)))
(import "env" "gas" (func (;0;) (type 1))) (import "env" "gas" (func (;0;) (type 1)))
(func (;1;) (type 0) (func (;1;) (type 0)
i32.const 2 i64.const 2
call 0 call 0
i32.const 1 i32.const 1
if ;; label = @1 if ;; label = @1
i32.const 1 i64.const 1
call 0 call 0
loop ;; label = @2 loop ;; label = @2
i32.const 2 i64.const 2
call 0 call 0
i32.const 123 i32.const 123
drop drop
@@ -18,7 +18,7 @@
end end
) )
(func (;2;) (type 0) (func (;2;) (type 0)
i32.const 1 i64.const 1
call 0 call 0
block ;; label = @1 block ;; label = @1
end end
+2 -2
View File
@@ -1,12 +1,12 @@
(module (module
(type (;0;) (func (param i32 i32))) (type (;0;) (func (param i32 i32)))
(type (;1;) (func)) (type (;1;) (func))
(type (;2;) (func (param i32))) (type (;2;) (func (param i64)))
(import "env" "ext_return" (func $ext_return (;0;) (type 0))) (import "env" "ext_return" (func $ext_return (;0;) (type 0)))
(import "env" "memory" (memory (;0;) 1 1)) (import "env" "memory" (memory (;0;) 1 1))
(import "env" "gas" (func (;1;) (type 2))) (import "env" "gas" (func (;1;) (type 2)))
(func $start (;2;) (type 1) (func $start (;2;) (type 1)
i32.const 4 i64.const 4
call 1 call 1
i32.const 8 i32.const 8
i32.const 4 i32.const 4