Remove everything not needed by substrate

Also rename to wasm-instrument
This commit is contained in:
Alexander Theißen
2022-01-08 20:50:18 +02:00
parent b22696aaa5
commit e882111f92
33 changed files with 418 additions and 4720 deletions
+578
View File
@@ -0,0 +1,578 @@
use super::resolve_func_type;
use alloc::vec::Vec;
use parity_wasm::elements::{self, BlockType, Type};
#[cfg(feature = "sign_ext")]
use parity_wasm::elements::SignExtInstruction;
/// Control stack frame.
#[derive(Debug)]
struct Frame {
/// Stack becomes polymorphic only after an instruction that
/// never passes control further was executed.
is_polymorphic: bool,
/// Count of values which will be pushed after the exit
/// from the current block.
end_arity: u32,
/// Count of values which should be poped upon a branch to
/// this frame.
///
/// This might be diffirent from `end_arity` since branch
/// to the loop header can't take any values.
branch_arity: u32,
/// Stack height before entering in the block.
start_height: u32,
}
/// This is a compound stack that abstracts tracking height of the value stack
/// and manipulation of the control stack.
struct Stack {
height: u32,
control_stack: Vec<Frame>,
}
impl Stack {
fn new() -> Stack {
Stack { height: 0, control_stack: Vec::new() }
}
/// Returns current height of the value stack.
fn height(&self) -> u32 {
self.height
}
/// Returns a reference to a frame by specified depth relative to the top of
/// control stack.
fn frame(&self, rel_depth: u32) -> Result<&Frame, &'static str> {
let control_stack_height: usize = self.control_stack.len();
let last_idx = control_stack_height.checked_sub(1).ok_or("control stack is empty")?;
let idx = last_idx.checked_sub(rel_depth as usize).ok_or("control stack out-of-bounds")?;
Ok(&self.control_stack[idx])
}
/// Mark successive instructions as unreachable.
///
/// This effectively makes stack polymorphic.
fn mark_unreachable(&mut self) -> Result<(), &'static str> {
let top_frame = self.control_stack.last_mut().ok_or("stack must be non-empty")?;
top_frame.is_polymorphic = true;
Ok(())
}
/// Push control frame into the control stack.
fn push_frame(&mut self, frame: Frame) {
self.control_stack.push(frame);
}
/// Pop control frame from the control stack.
///
/// Returns `Err` if the control stack is empty.
fn pop_frame(&mut self) -> Result<Frame, &'static str> {
self.control_stack.pop().ok_or("stack must be non-empty")
}
/// Truncate the height of value stack to the specified height.
fn trunc(&mut self, new_height: u32) {
self.height = new_height;
}
/// Push specified number of values into the value stack.
///
/// Returns `Err` if the height overflow usize value.
fn push_values(&mut self, value_count: u32) -> Result<(), &'static str> {
self.height = self.height.checked_add(value_count).ok_or("stack overflow")?;
Ok(())
}
/// Pop specified number of values from the value stack.
///
/// Returns `Err` if the stack happen to be negative value after
/// values popped.
fn pop_values(&mut self, value_count: u32) -> Result<(), &'static str> {
if value_count == 0 {
return Ok(())
}
{
let top_frame = self.frame(0)?;
if self.height == top_frame.start_height {
// It is an error to pop more values than was pushed in the current frame
// (ie pop values pushed in the parent frame), unless the frame became
// polymorphic.
return if top_frame.is_polymorphic {
Ok(())
} else {
return Err("trying to pop more values than pushed")
}
}
}
self.height = self.height.checked_sub(value_count).ok_or("stack underflow")?;
Ok(())
}
}
/// This function expects the function to be validated.
pub fn compute(func_idx: u32, module: &elements::Module) -> Result<u32, &'static str> {
use parity_wasm::elements::Instruction::*;
let func_section = module.function_section().ok_or("No function section")?;
let code_section = module.code_section().ok_or("No code section")?;
let type_section = module.type_section().ok_or("No type section")?;
// Get a signature and a body of the specified function.
let func_sig_idx = func_section
.entries()
.get(func_idx as usize)
.ok_or("Function is not found in func section")?
.type_ref();
let Type::Function(func_signature) = type_section
.types()
.get(func_sig_idx as usize)
.ok_or("Function is not found in func section")?;
let body = code_section
.bodies()
.get(func_idx as usize)
.ok_or("Function body for the index isn't found")?;
let instructions = body.code();
let mut stack = Stack::new();
let mut max_height: u32 = 0;
let mut pc = 0;
// Add implicit frame for the function. Breaks to this frame and execution of
// the last end should deal with this frame.
let func_arity = func_signature.results().len() as u32;
stack.push_frame(Frame {
is_polymorphic: false,
end_arity: func_arity,
branch_arity: func_arity,
start_height: 0,
});
loop {
if pc >= instructions.elements().len() {
break
}
// If current value stack is higher than maximal height observed so far,
// save the new height.
// However, we don't increase maximal value in unreachable code.
if stack.height() > max_height && !stack.frame(0)?.is_polymorphic {
max_height = stack.height();
}
let opcode = &instructions.elements()[pc];
match opcode {
Nop => {},
Block(ty) | Loop(ty) | If(ty) => {
let end_arity = if *ty == BlockType::NoResult { 0 } else { 1 };
let branch_arity = if let Loop(_) = *opcode { 0 } else { end_arity };
if let If(_) = *opcode {
stack.pop_values(1)?;
}
let height = stack.height();
stack.push_frame(Frame {
is_polymorphic: false,
end_arity,
branch_arity,
start_height: height,
});
},
Else => {
// The frame at the top should be pushed by `If`. So we leave
// it as is.
},
End => {
let frame = stack.pop_frame()?;
stack.trunc(frame.start_height);
stack.push_values(frame.end_arity)?;
},
Unreachable => {
stack.mark_unreachable()?;
},
Br(target) => {
// Pop values for the destination block result.
let target_arity = stack.frame(*target)?.branch_arity;
stack.pop_values(target_arity)?;
// This instruction unconditionally transfers control to the specified block,
// thus all instruction until the end of the current block is deemed unreachable
stack.mark_unreachable()?;
},
BrIf(target) => {
// Pop values for the destination block result.
let target_arity = stack.frame(*target)?.branch_arity;
stack.pop_values(target_arity)?;
// Pop condition value.
stack.pop_values(1)?;
// Push values back.
stack.push_values(target_arity)?;
},
BrTable(br_table_data) => {
let arity_of_default = stack.frame(br_table_data.default)?.branch_arity;
// Check that all jump targets have an equal arities.
for target in &*br_table_data.table {
let arity = stack.frame(*target)?.branch_arity;
if arity != arity_of_default {
return Err("Arity of all jump-targets must be equal")
}
}
// Because all jump targets have an equal arities, we can just take arity of
// the default branch.
stack.pop_values(arity_of_default)?;
// This instruction doesn't let control flow to go further, since the control flow
// should take either one of branches depending on the value or the default branch.
stack.mark_unreachable()?;
},
Return => {
// Pop return values of the function. Mark successive instructions as unreachable
// since this instruction doesn't let control flow to go further.
stack.pop_values(func_arity)?;
stack.mark_unreachable()?;
},
Call(idx) => {
let ty = resolve_func_type(*idx, module)?;
// Pop values for arguments of the function.
stack.pop_values(ty.params().len() as u32)?;
// Push result of the function execution to the stack.
let callee_arity = ty.results().len() as u32;
stack.push_values(callee_arity)?;
},
CallIndirect(x, _) => {
let Type::Function(ty) =
type_section.types().get(*x as usize).ok_or("Type not found")?;
// Pop the offset into the function table.
stack.pop_values(1)?;
// Pop values for arguments of the function.
stack.pop_values(ty.params().len() as u32)?;
// Push result of the function execution to the stack.
let callee_arity = ty.results().len() as u32;
stack.push_values(callee_arity)?;
},
Drop => {
stack.pop_values(1)?;
},
Select => {
// Pop two values and one condition.
stack.pop_values(2)?;
stack.pop_values(1)?;
// Push the selected value.
stack.push_values(1)?;
},
GetLocal(_) => {
stack.push_values(1)?;
},
SetLocal(_) => {
stack.pop_values(1)?;
},
TeeLocal(_) => {
// This instruction pops and pushes the value, so
// effectively it doesn't modify the stack height.
stack.pop_values(1)?;
stack.push_values(1)?;
},
GetGlobal(_) => {
stack.push_values(1)?;
},
SetGlobal(_) => {
stack.pop_values(1)?;
},
I32Load(_, _) |
I64Load(_, _) |
F32Load(_, _) |
F64Load(_, _) |
I32Load8S(_, _) |
I32Load8U(_, _) |
I32Load16S(_, _) |
I32Load16U(_, _) |
I64Load8S(_, _) |
I64Load8U(_, _) |
I64Load16S(_, _) |
I64Load16U(_, _) |
I64Load32S(_, _) |
I64Load32U(_, _) => {
// These instructions pop the address and pushes the result,
// which effictively don't modify the stack height.
stack.pop_values(1)?;
stack.push_values(1)?;
},
I32Store(_, _) |
I64Store(_, _) |
F32Store(_, _) |
F64Store(_, _) |
I32Store8(_, _) |
I32Store16(_, _) |
I64Store8(_, _) |
I64Store16(_, _) |
I64Store32(_, _) => {
// These instructions pop the address and the value.
stack.pop_values(2)?;
},
CurrentMemory(_) => {
// Pushes current memory size
stack.push_values(1)?;
},
GrowMemory(_) => {
// Grow memory takes the value of pages to grow and pushes
stack.pop_values(1)?;
stack.push_values(1)?;
},
I32Const(_) | I64Const(_) | F32Const(_) | F64Const(_) => {
// These instructions just push the single literal value onto the stack.
stack.push_values(1)?;
},
I32Eqz | I64Eqz => {
// These instructions pop the value and compare it against zero, and pushes
// the result of the comparison.
stack.pop_values(1)?;
stack.push_values(1)?;
},
I32Eq | I32Ne | I32LtS | I32LtU | I32GtS | I32GtU | I32LeS | I32LeU | I32GeS |
I32GeU | I64Eq | I64Ne | I64LtS | I64LtU | I64GtS | I64GtU | I64LeS | I64LeU |
I64GeS | I64GeU | F32Eq | F32Ne | F32Lt | F32Gt | F32Le | F32Ge | F64Eq | F64Ne |
F64Lt | F64Gt | F64Le | F64Ge => {
// Comparison operations take two operands and produce one result.
stack.pop_values(2)?;
stack.push_values(1)?;
},
I32Clz | I32Ctz | I32Popcnt | I64Clz | I64Ctz | I64Popcnt | F32Abs | F32Neg |
F32Ceil | F32Floor | F32Trunc | F32Nearest | F32Sqrt | F64Abs | F64Neg | F64Ceil |
F64Floor | F64Trunc | F64Nearest | F64Sqrt => {
// Unary operators take one operand and produce one result.
stack.pop_values(1)?;
stack.push_values(1)?;
},
I32Add | I32Sub | I32Mul | I32DivS | I32DivU | I32RemS | I32RemU | I32And | I32Or |
I32Xor | I32Shl | I32ShrS | I32ShrU | I32Rotl | I32Rotr | I64Add | I64Sub |
I64Mul | I64DivS | I64DivU | I64RemS | I64RemU | I64And | I64Or | I64Xor | I64Shl |
I64ShrS | I64ShrU | I64Rotl | I64Rotr | F32Add | F32Sub | F32Mul | F32Div |
F32Min | F32Max | F32Copysign | F64Add | F64Sub | F64Mul | F64Div | F64Min |
F64Max | F64Copysign => {
// Binary operators take two operands and produce one result.
stack.pop_values(2)?;
stack.push_values(1)?;
},
I32WrapI64 | I32TruncSF32 | I32TruncUF32 | I32TruncSF64 | I32TruncUF64 |
I64ExtendSI32 | I64ExtendUI32 | I64TruncSF32 | I64TruncUF32 | I64TruncSF64 |
I64TruncUF64 | F32ConvertSI32 | F32ConvertUI32 | F32ConvertSI64 | F32ConvertUI64 |
F32DemoteF64 | F64ConvertSI32 | F64ConvertUI32 | F64ConvertSI64 | F64ConvertUI64 |
F64PromoteF32 | I32ReinterpretF32 | I64ReinterpretF64 | F32ReinterpretI32 |
F64ReinterpretI64 => {
// Conversion operators take one value and produce one result.
stack.pop_values(1)?;
stack.push_values(1)?;
},
#[cfg(feature = "sign_ext")]
SignExt(SignExtInstruction::I32Extend8S) |
SignExt(SignExtInstruction::I32Extend16S) |
SignExt(SignExtInstruction::I64Extend8S) |
SignExt(SignExtInstruction::I64Extend16S) |
SignExt(SignExtInstruction::I64Extend32S) => {
stack.pop_values(1)?;
stack.push_values(1)?;
},
}
pc += 1;
}
Ok(max_height)
}
#[cfg(test)]
mod tests {
use super::*;
use parity_wasm::elements;
fn parse_wat(source: &str) -> elements::Module {
elements::deserialize_buffer(&wabt::wat2wasm(source).expect("Failed to wat2wasm"))
.expect("Failed to deserialize the module")
}
#[test]
fn simple_test() {
let module = parse_wat(
r#"
(module
(func
i32.const 1
i32.const 2
i32.const 3
drop
drop
drop
)
)
"#,
);
let height = compute(0, &module).unwrap();
assert_eq!(height, 3);
}
#[test]
fn implicit_and_explicit_return() {
let module = parse_wat(
r#"
(module
(func (result i32)
i32.const 0
return
)
)
"#,
);
let height = compute(0, &module).unwrap();
assert_eq!(height, 1);
}
#[test]
fn dont_count_in_unreachable() {
let module = parse_wat(
r#"
(module
(memory 0)
(func (result i32)
unreachable
grow_memory
)
)
"#,
);
let height = compute(0, &module).unwrap();
assert_eq!(height, 0);
}
#[test]
fn yet_another_test() {
const SOURCE: &str = r#"
(module
(memory 0)
(func
;; Push two values and then pop them.
;; This will make max depth to be equal to 2.
i32.const 0
i32.const 1
drop
drop
;; Code after `unreachable` shouldn't have an effect
;; on the max depth.
unreachable
i32.const 0
i32.const 1
i32.const 2
)
)
"#;
let module = elements::deserialize_buffer(
wabt::Wat2Wasm::new()
.validate(false)
.convert(SOURCE)
.expect("Failed to wat2wasm")
.as_ref(),
)
.expect("Failed to deserialize the module");
let height = compute(0, &module).unwrap();
assert_eq!(height, 2);
}
#[test]
fn call_indirect() {
let module = parse_wat(
r#"
(module
(table $ptr 1 1 funcref)
(elem $ptr (i32.const 0) func 1)
(func $main
(call_indirect (i32.const 0))
(call_indirect (i32.const 0))
(call_indirect (i32.const 0))
)
(func $callee
i64.const 42
drop
)
)
"#,
);
let height = compute(0, &module).unwrap();
assert_eq!(height, 1);
}
#[test]
fn breaks() {
let module = parse_wat(
r#"
(module
(func $main
block (result i32)
block (result i32)
i32.const 99
br 1
end
end
drop
)
)
"#,
);
let height = compute(0, &module).unwrap();
assert_eq!(height, 1);
}
#[test]
fn if_else_works() {
let module = parse_wat(
r#"
(module
(func $main
i32.const 7
i32.const 1
if (result i32)
i32.const 42
else
i32.const 99
end
i32.const 97
drop
drop
drop
)
)
"#,
);
let height = compute(0, &module).unwrap();
assert_eq!(height, 3);
}
}
+383
View File
@@ -0,0 +1,383 @@
//! Contains the code for the stack height limiter instrumentation.
use alloc::{vec, vec::Vec};
use core::mem;
use parity_wasm::{
builder,
elements::{self, Instruction, Instructions, Type},
};
/// Macro to generate preamble and postamble.
macro_rules! instrument_call {
($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
use $crate::parity_wasm::elements::Instruction::*;
[
// stack_height += stack_cost(F)
GetGlobal($stack_height_global_idx),
I32Const($callee_stack_cost),
I32Add,
SetGlobal($stack_height_global_idx),
// if stack_counter > LIMIT: unreachable
GetGlobal($stack_height_global_idx),
I32Const($stack_limit as i32),
I32GtU,
If(elements::BlockType::NoResult),
Unreachable,
End,
// Original call
Call($callee_idx),
// stack_height -= stack_cost(F)
GetGlobal($stack_height_global_idx),
I32Const($callee_stack_cost),
I32Sub,
SetGlobal($stack_height_global_idx),
]
}};
}
mod max_height;
mod thunk;
pub struct Context {
stack_height_global_idx: u32,
func_stack_costs: Vec<u32>,
stack_limit: u32,
}
impl Context {
/// Returns index in a global index space of a stack_height global variable.
fn stack_height_global_idx(&self) -> u32 {
self.stack_height_global_idx
}
/// Returns `stack_cost` for `func_idx`.
fn stack_cost(&self, func_idx: u32) -> Option<u32> {
self.func_stack_costs.get(func_idx as usize).cloned()
}
/// Returns stack limit specified by the rules.
fn stack_limit(&self) -> u32 {
self.stack_limit
}
}
/// Inject the instumentation that makes stack overflows deterministic, by introducing
/// an upper bound of the stack size.
///
/// This pass introduces a global mutable variable to track stack height,
/// and instruments all calls with preamble and postamble.
///
/// Stack height is increased prior the call. Otherwise, the check would
/// be made after the stack frame is allocated.
///
/// The preamble is inserted before the call. It increments
/// the global stack height variable with statically determined "stack cost"
/// of the callee. If after the increment the stack height exceeds
/// the limit (specified by the `rules`) then execution traps.
/// Otherwise, the call is executed.
///
/// The postamble is inserted after the call. The purpose of the postamble is to decrease
/// the stack height by the "stack cost" of the callee function.
///
/// Note, that we can't instrument all possible ways to return from the function. The simplest
/// example would be a trap issued by the host function.
/// That means stack height global won't be equal to zero upon the next execution after such trap.
///
/// # Thunks
///
/// Because stack height is increased prior the call few problems arises:
///
/// - Stack height isn't increased upon an entry to the first function, i.e. exported function.
/// - Start function is executed externally (similar to exported functions).
/// - It is statically unknown what function will be invoked in an indirect call.
///
/// The solution for this problems is to generate a intermediate functions, called 'thunks', which
/// will increase before and decrease the stack height after the call to original function, and
/// then make exported function and table entries, start section to point to a corresponding thunks.
///
/// # Stack cost
///
/// Stack cost of the function is calculated as a sum of it's locals
/// and the maximal height of the value stack.
///
/// All values are treated equally, as they have the same size.
///
/// The rationale is that this makes it possible to use the following very naive wasm executor:
///
/// - values are implemented by a union, so each value takes a size equal to the size of the largest
/// possible value type this union can hold. (In MVP it is 8 bytes)
/// - each value from the value stack is placed on the native stack.
/// - each local variable and function argument is placed on the native stack.
/// - arguments pushed by the caller are copied into callee stack rather than shared between the
/// frames.
/// - upon entry into the function entire stack frame is allocated.
pub fn inject(
mut module: elements::Module,
stack_limit: u32,
) -> Result<elements::Module, &'static str> {
let mut ctx = Context {
stack_height_global_idx: generate_stack_height_global(&mut module),
func_stack_costs: compute_stack_costs(&module)?,
stack_limit,
};
instrument_functions(&mut ctx, &mut module)?;
let module = thunk::generate_thunks(&mut ctx, module)?;
Ok(module)
}
/// Generate a new global that will be used for tracking current stack height.
fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
let global_entry = builder::global()
.value_type()
.i32()
.mutable()
.init_expr(Instruction::I32Const(0))
.build();
// Try to find an existing global section.
for section in module.sections_mut() {
if let elements::Section::Global(gs) = section {
gs.entries_mut().push(global_entry);
return (gs.entries().len() as u32) - 1
}
}
// Existing section not found, create one!
module
.sections_mut()
.push(elements::Section::Global(elements::GlobalSection::with_entries(vec![global_entry])));
0
}
/// Calculate stack costs for all functions.
///
/// Returns a vector with a stack cost for each function, including imports.
fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, &'static str> {
let func_imports = module.import_count(elements::ImportCountType::Function);
// TODO: optimize!
(0..module.functions_space())
.map(|func_idx| {
if func_idx < func_imports {
// We can't calculate stack_cost of the import functions.
Ok(0)
} else {
compute_stack_cost(func_idx as u32, module)
}
})
.collect()
}
/// Stack cost of the given *defined* function is the sum of it's locals count (that is,
/// number of arguments plus number of local variables) and the maximal stack
/// height.
fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, &'static str> {
// To calculate the cost of a function we need to convert index from
// function index space to defined function spaces.
let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
let defined_func_idx = func_idx
.checked_sub(func_imports)
.ok_or("This should be a index of a defined function")?;
let code_section =
module.code_section().ok_or("Due to validation code section should exists")?;
let body = &code_section
.bodies()
.get(defined_func_idx as usize)
.ok_or("Function body is out of bounds")?;
let mut locals_count: u32 = 0;
for local_group in body.locals() {
locals_count =
locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?;
}
let max_stack_height = max_height::compute(defined_func_idx, module)?;
locals_count
.checked_add(max_stack_height)
.ok_or("Overflow in adding locals_count and max_stack_height")
}
fn instrument_functions(
ctx: &mut Context,
module: &mut elements::Module,
) -> Result<(), &'static str> {
for section in module.sections_mut() {
if let elements::Section::Code(code_section) = section {
for func_body in code_section.bodies_mut() {
let opcodes = func_body.code_mut();
instrument_function(ctx, opcodes)?;
}
}
}
Ok(())
}
/// This function searches `call` instructions and wrap each call
/// with preamble and postamble.
///
/// Before:
///
/// ```text
/// get_local 0
/// get_local 1
/// call 228
/// drop
/// ```
///
/// After:
///
/// ```text
/// get_local 0
/// get_local 1
///
/// < ... preamble ... >
///
/// call 228
///
/// < .. postamble ... >
///
/// drop
/// ```
fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), &'static str> {
use Instruction::*;
struct InstrumentCall {
offset: usize,
callee: u32,
cost: u32,
}
let calls: Vec<_> = func
.elements()
.iter()
.enumerate()
.filter_map(|(offset, instruction)| {
if let Call(callee) = instruction {
ctx.stack_cost(*callee).and_then(|cost| {
if cost > 0 {
Some(InstrumentCall { callee: *callee, offset, cost })
} else {
None
}
})
} else {
None
}
})
.collect();
// The `instrumented_call!` contains the call itself. This is why we need to subtract one.
let len = func.elements().len() + calls.len() * (instrument_call!(0, 0, 0, 0).len() - 1);
let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
let new_instrs = func.elements_mut();
let mut calls = calls.into_iter().peekable();
for (original_pos, instr) in original_instrs.into_iter().enumerate() {
// whether there is some call instruction at this position that needs to be instrumented
let did_instrument = if let Some(call) = calls.peek() {
if call.offset == original_pos {
let new_seq = instrument_call!(
call.callee,
call.cost as i32,
ctx.stack_height_global_idx(),
ctx.stack_limit()
);
new_instrs.extend(new_seq);
true
} else {
false
}
} else {
false
};
if did_instrument {
calls.next();
} else {
new_instrs.push(instr);
}
}
if calls.next().is_some() {
return Err("Not all calls were used")
}
Ok(())
}
fn resolve_func_type(
func_idx: u32,
module: &elements::Module,
) -> Result<&elements::FunctionType, &'static str> {
let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
let func_imports = module.import_count(elements::ImportCountType::Function);
let sig_idx = if func_idx < func_imports as u32 {
module
.import_section()
.expect("function import count is not zero; import section must exists; qed")
.entries()
.iter()
.filter_map(|entry| match entry.external() {
elements::External::Function(idx) => Some(*idx),
_ => None,
})
.nth(func_idx as usize)
.expect(
"func_idx is less than function imports count;
nth function import must be `Some`;
qed",
)
} else {
functions
.get(func_idx as usize - func_imports)
.ok_or("Function at the specified index is not defined")?
.type_ref()
};
let Type::Function(ty) = types
.get(sig_idx as usize)
.ok_or("The signature as specified by a function isn't defined")?;
Ok(ty)
}
#[cfg(test)]
mod tests {
use super::*;
use parity_wasm::elements;
fn parse_wat(source: &str) -> elements::Module {
elements::deserialize_buffer(&wabt::wat2wasm(source).expect("Failed to wat2wasm"))
.expect("Failed to deserialize the module")
}
fn validate_module(module: elements::Module) {
let binary = elements::serialize(module).expect("Failed to serialize");
wabt::Module::read_binary(&binary, &Default::default())
.expect("Wabt failed to read final binary")
.validate()
.expect("Invalid module");
}
#[test]
fn test_with_params_and_result() {
let module = parse_wat(
r#"
(module
(func (export "i32.add") (param i32 i32) (result i32)
get_local 0
get_local 1
i32.add
)
)
"#,
);
let module = inject(module, 1024).expect("Failed to inject stack counter");
validate_module(module);
}
}
+139
View File
@@ -0,0 +1,139 @@
#[cfg(not(features = "std"))]
use alloc::collections::BTreeMap as Map;
use alloc::vec::Vec;
use parity_wasm::{
builder,
elements::{self, FunctionType, Internal},
};
#[cfg(features = "std")]
use std::collections::HashMap as Map;
use super::{resolve_func_type, Context};
struct Thunk {
signature: FunctionType,
// Index in function space of this thunk.
idx: Option<u32>,
callee_stack_cost: u32,
}
pub fn generate_thunks(
ctx: &mut Context,
module: elements::Module,
) -> Result<elements::Module, &'static str> {
// First, we need to collect all function indices that should be replaced by thunks
let mut replacement_map: Map<u32, Thunk> = {
let exports = module.export_section().map(|es| es.entries()).unwrap_or(&[]);
let elem_segments = module.elements_section().map(|es| es.entries()).unwrap_or(&[]);
let start_func_idx = module.start_section();
let exported_func_indices = exports.iter().filter_map(|entry| match entry.internal() {
Internal::Function(function_idx) => Some(*function_idx),
_ => None,
});
let table_func_indices =
elem_segments.iter().flat_map(|segment| segment.members()).cloned();
// Replacement map is at least export section size.
let mut replacement_map: Map<u32, Thunk> = Map::new();
for func_idx in exported_func_indices
.chain(table_func_indices)
.chain(start_func_idx.into_iter())
{
let callee_stack_cost = ctx.stack_cost(func_idx).ok_or("function index isn't found")?;
// Don't generate a thunk if stack_cost of a callee is zero.
if callee_stack_cost != 0 {
replacement_map.insert(
func_idx,
Thunk {
signature: resolve_func_type(func_idx, &module)?.clone(),
idx: None,
callee_stack_cost,
},
);
}
}
replacement_map
};
// Then, we generate a thunk for each original function.
// Save current func_idx
let mut next_func_idx = module.functions_space() as u32;
let mut mbuilder = builder::from_module(module);
for (func_idx, thunk) in replacement_map.iter_mut() {
let instrumented_call = instrument_call!(
*func_idx,
thunk.callee_stack_cost as i32,
ctx.stack_height_global_idx(),
ctx.stack_limit()
);
// Thunk body consist of:
// - argument pushing
// - instrumented call
// - end
let mut thunk_body: Vec<elements::Instruction> =
Vec::with_capacity(thunk.signature.params().len() + instrumented_call.len() + 1);
for (arg_idx, _) in thunk.signature.params().iter().enumerate() {
thunk_body.push(elements::Instruction::GetLocal(arg_idx as u32));
}
thunk_body.extend(instrumented_call.iter().cloned());
thunk_body.push(elements::Instruction::End);
// TODO: Don't generate a signature, but find an existing one.
mbuilder = mbuilder
.function()
// Signature of the thunk should match the original function signature.
.signature()
.with_params(thunk.signature.params().to_vec())
.with_results(thunk.signature.results().to_vec())
.build()
.body()
.with_instructions(elements::Instructions::new(thunk_body))
.build()
.build();
thunk.idx = Some(next_func_idx);
next_func_idx += 1;
}
let mut module = mbuilder.build();
// And finally, fixup thunks in export and table sections.
// Fixup original function index to a index of a thunk generated earlier.
let fixup = |function_idx: &mut u32| {
// Check whether this function is in replacement_map, since
// we can skip thunk generation (e.g. if stack_cost of function is 0).
if let Some(thunk) = replacement_map.get(function_idx) {
*function_idx =
thunk.idx.expect("At this point an index must be assigned to each thunk");
}
};
for section in module.sections_mut() {
match section {
elements::Section::Export(export_section) =>
for entry in export_section.entries_mut() {
if let Internal::Function(function_idx) = entry.internal_mut() {
fixup(function_idx)
}
},
elements::Section::Element(elem_section) =>
for segment in elem_section.entries_mut() {
for function_idx in segment.members_mut() {
fixup(function_idx)
}
},
elements::Section::Start(start_idx) => fixup(start_idx),
_ => {},
}
}
Ok(module)
}