diff --git a/Cargo.toml b/Cargo.toml index 0db70b0..c2d137b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,8 @@ byteorder = "1" [dev-dependencies] tempdir = "0.3" -wabt = "0.1" +wabt = "0.2" +diff = "0.1.11" [lib] @@ -33,3 +34,7 @@ path = "gas/src/main.rs" [[bin]] name = "wasm-build" path = "build/src/main.rs" + +[[bin]] +name = "wasm-stack-height" +path = "stack_height/src/main.rs" diff --git a/src/lib.rs b/src/lib.rs index adce5c4..5e3aec5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,8 @@ mod ext; mod pack; mod runtime_type; +pub mod stack_height; + pub use optimizer::{optimize, Error as OptimizerError}; pub use gas::inject_gas_counter; pub use logger::init_log; diff --git a/src/rules.rs b/src/rules.rs index 31e7cf3..50d85cd 100644 --- a/src/rules.rs +++ b/src/rules.rs @@ -281,7 +281,7 @@ impl Default for Set { impl Set { pub fn new(regular: u32, entries: HashMap) -> Self { - Set { regular: regular, entries: entries, grow: 0, } + Set { regular: regular, entries: entries, grow: 0 } } pub fn process(&self, opcode: &elements::Opcode) -> Result { @@ -308,4 +308,4 @@ impl Set { self.entries.insert(InstructionType::FloatConversion, Metering::Forbidden); self } -} \ No newline at end of file +} diff --git a/src/stack_height/max_height.rs b/src/stack_height/max_height.rs new file mode 100644 index 0000000..03515bc --- /dev/null +++ b/src/stack_height/max_height.rs @@ -0,0 +1,524 @@ +use parity_wasm::elements::{self, BlockType, Type}; +use super::{resolve_func_type, Error}; + +/// Control stack frame. +#[derive(Debug)] +struct Frame { + /// Stack becomes polymorphic only after an instruction that + /// never passes control further was executed. + is_polymorphic: bool, + + /// Count of values which will be pushed after the exit + /// from the current block. + end_arity: u32, + + /// Count of values which should be poped upon a branch to + /// this frame. 
+ /// + /// This might be diffirent from `end_arity` since branch + /// to the loop header can't take any values. + branch_arity: u32, + + /// Stack height before entering in the block. + start_height: u32, +} + +/// This is a compound stack that abstracts tracking height of the value stack +/// and manipulation of the control stack. +struct Stack { + height: u32, + control_stack: Vec, +} + +impl Stack { + fn new() -> Stack { + Stack { + height: 0, + control_stack: Vec::new(), + } + } + + /// Returns current height of the value stack. + fn height(&self) -> u32 { + self.height + } + + /// Returns a reference to a frame by specified depth relative to the top of + /// control stack. + fn frame(&self, rel_depth: u32) -> Result<&Frame, Error> { + let control_stack_height: usize = self.control_stack.len(); + let last_idx = control_stack_height + .checked_sub(1) + .ok_or_else(|| Error("control stack is empty".into()))?; + let idx = last_idx + .checked_sub(rel_depth as usize) + .ok_or_else(|| Error("control stack out-of-bounds".into()))?; + Ok(&self.control_stack[idx]) + } + + /// Mark successive instructions as unreachable. + /// + /// This effectively makes stack polymorphic. + fn mark_unreachable(&mut self) -> Result<(), Error> { + trace!("unreachable"); + let top_frame = self.control_stack + .last_mut() + .ok_or_else(|| Error("stack must be non-empty".into()))?; + top_frame.is_polymorphic = true; + Ok(()) + } + + /// Push control frame into the control stack. + fn push_frame(&mut self, frame: Frame) { + trace!("push_frame: {:?}", frame); + self.control_stack.push(frame); + } + + /// Pop control frame from the control stack. + /// + /// Returns `Err` if the control stack is empty. + fn pop_frame(&mut self) -> Result { + trace!("pop_frame: {:?}", self.control_stack.last()); + Ok(self.control_stack + .pop() + .ok_or_else(|| Error("stack must be non-empty".into()))?) + } + + /// Truncate the height of value stack to the specified height. 
+ fn trunc(&mut self, new_height: u32) { + trace!("trunc: {}", new_height); + self.height = new_height; + } + + /// Push specified number of values into the value stack. + /// + /// Returns `Err` if the height overflow usize value. + fn push_values(&mut self, value_count: u32) -> Result<(), Error> { + trace!("push: {}", value_count); + self.height = self.height + .checked_add(value_count) + .ok_or_else(|| Error("stack overflow".into()))?; + Ok(()) + } + + /// Pop specified number of values from the value stack. + /// + /// Returns `Err` if the stack happen to be negative value after + /// values popped. + fn pop_values(&mut self, value_count: u32) -> Result<(), Error> { + trace!("pop: {}", value_count); + if value_count == 0 { + return Ok(()); + } + { + let top_frame = self.frame(0)?; + if self.height == top_frame.start_height { + // It is an error to pop more values than was pushed in the current frame + // (ie pop values pushed in the parent frame), unless the frame became + // polymorphic. + return if top_frame.is_polymorphic { + Ok(()) + } else { + Err(Error("trying to pop more values than pushed".into())) + } + } + } + + self.height = self.height + .checked_sub(value_count) + .ok_or_else(|| Error("stack underflow".into()))?; + + Ok(()) + } +} + +/// This function expects the function to be validated. +pub(crate) fn compute(func_idx: u32, module: &elements::Module) -> Result { + use parity_wasm::elements::Opcode::*; + + let func_section = module + .function_section() + .ok_or_else(|| Error("No function section".into()))?; + let code_section = module + .code_section() + .ok_or_else(|| Error("No code section".into()))?; + let type_section = module + .type_section() + .ok_or_else(|| Error("No type section".into()))?; + + + trace!("func_idx: {}", func_idx); + + // Get a signature and a body of the specified function. + let func_sig_idx = func_section + .entries() + .get(func_idx as usize) + .ok_or_else(|| Error("Function is not found in func section".into()))? 
+ .type_ref(); + let Type::Function(ref func_signature) = *type_section + .types() + .get(func_sig_idx as usize) + .ok_or_else(|| Error("Function is not found in func section".into()))?; + let body = code_section + .bodies() + .get(func_idx as usize) + .ok_or_else(|| Error("Function body for the index isn't found".into()))?; + let opcodes = body.code(); + + let mut stack = Stack::new(); + let mut max_height: u32 = 0; + let mut pc = 0; + + // Add implicit frame for the function. Breaks to this frame and execution of + // the last end should deal with this frame. + let func_arity: u32 = if func_signature.return_type().is_some() { + 1 + } else { + 0 + }; + stack.push_frame(Frame { + is_polymorphic: false, + end_arity: func_arity, + branch_arity: func_arity, + start_height: 0, + }); + + loop { + if pc >= opcodes.elements().len() { + break; + } + + // If current value stack is higher than maximal height observed so far, + // save the new height. + // However, we don't increase maximal value in unreachable code. + if stack.height() > max_height && !stack.frame(0)?.is_polymorphic { + max_height = stack.height(); + } + + let opcode = &opcodes.elements()[pc]; + trace!("{:?}", opcode); + + match *opcode { + Nop => {} + Block(ty) | Loop(ty) | If(ty) => { + let end_arity = if ty == BlockType::NoResult { 0 } else { 1 }; + let branch_arity = if let Loop(_) = *opcode { 0 } else { end_arity }; + let height = stack.height(); + stack.push_frame(Frame { + is_polymorphic: false, + end_arity, + branch_arity, + start_height: height, + }); + } + Else => { + // The frame at the top should be pushed by `If`. So we leave + // it as is. + } + End => { + let frame = stack.pop_frame()?; + stack.trunc(frame.start_height); + stack.push_values(frame.end_arity)?; + } + Unreachable => { + stack.mark_unreachable()?; + } + Br(target) => { + // Pop values for the destination block result. 
+ let target_arity = stack.frame(target)?.branch_arity; + stack.pop_values(target_arity)?; + + // This instruction unconditionally transfers control to the specified block, + // thus all instruction until the end of the current block is deemed unreachable + stack.mark_unreachable()?; + } + BrIf(target) => { + // Pop values for the destination block result. + let target_arity = stack.frame(target)?.branch_arity; + stack.pop_values(target_arity)?; + + // Pop condition value. + stack.pop_values(1)?; + + // Push values back. + stack.push_values(target_arity)?; + } + BrTable(ref targets, default_target) => { + let arity_of_default = stack.frame(default_target)?.branch_arity; + + // Check that all jump targets have an equal arities. + for target in targets.iter() { + let arity = stack.frame(*target)?.branch_arity; + if arity != arity_of_default { + return Err(Error( + "Arity of all jump-targets must be equal".into() + )) + } + } + + // Because all jump targets have an equal arities, we can just take arity of + // the default branch. + stack.pop_values(arity_of_default)?; + + // This instruction doesn't let control flow to go further, since the control flow + // should take either one of branches depending on the value or the default branch. + stack.mark_unreachable()?; + } + Return => { + // Pop return values of the function. Mark successive instructions as unreachable + // since this instruction doesn't let control flow to go further. + stack.pop_values(func_arity)?; + stack.mark_unreachable()?; + } + Call(idx) => { + let ty = resolve_func_type(idx, module)?; + + // Pop values for arguments of the function. + stack.pop_values(ty.params().len() as u32)?; + + // Push result of the function execution to the stack. 
+ let callee_arity = if ty.return_type().is_some() { 1 } else { 0 }; + stack.push_values(callee_arity)?; + } + CallIndirect(x, _) => { + let Type::Function(ref ty) = *type_section + .types() + .get(x as usize) + .ok_or_else(|| Error("Type not found".into()))?; + + // Pop values for arguments of the function. + stack.pop_values(ty.params().len() as u32)?; + + // Push result of the function execution to the stack. + let callee_arity = if ty.return_type().is_some() { 1 } else { 0 }; + stack.push_values(callee_arity)?; + } + Drop => { + stack.pop_values(1)?; + } + Select => { + // Pop two values and one condition. + stack.pop_values(2)?; + stack.pop_values(1)?; + + // Push the selected value. + stack.push_values(1)?; + } + GetLocal(_) => { + stack.push_values(1)?; + } + SetLocal(_) => { + stack.pop_values(1)?; + } + TeeLocal(_) => { + // This instruction pops and pushes the value, so + // effectively it doesn't modify the stack height. + stack.pop_values(1)?; + stack.push_values(1)?; + } + GetGlobal(_) => { + stack.push_values(1)?; + } + SetGlobal(_) => { + stack.pop_values(1)?; + } + I32Load(_, _) + | I64Load(_, _) + | F32Load(_, _) + | F64Load(_, _) + | I32Load8S(_, _) + | I32Load8U(_, _) + | I32Load16S(_, _) + | I32Load16U(_, _) + | I64Load8S(_, _) + | I64Load8U(_, _) + | I64Load16S(_, _) + | I64Load16U(_, _) + | I64Load32S(_, _) + | I64Load32U(_, _) => { + // These instructions pop the address and pushes the result, + // which effictively don't modify the stack height. + stack.pop_values(1)?; + stack.push_values(1)?; + } + + I32Store(_, _) + | I64Store(_, _) + | F32Store(_, _) + | F64Store(_, _) + | I32Store8(_, _) + | I32Store16(_, _) + | I64Store8(_, _) + | I64Store16(_, _) + | I64Store32(_, _) => { + // These instructions pop the address and the value. 
+ stack.pop_values(2)?; + } + + CurrentMemory(_) => { + // Pushes current memory size + stack.push_values(1)?; + } + GrowMemory(_) => { + // Grow memory takes the value of pages to grow and pushes + stack.pop_values(1)?; + stack.push_values(1)?; + } + + I32Const(_) | I64Const(_) | F32Const(_) | F64Const(_) => { + // These instructions just push the single literal value onto the stack. + stack.push_values(1)?; + } + + I32Eqz | I64Eqz => { + // These instructions pop the value and compare it against zero, and pushes + // the result of the comparison. + stack.pop_values(1)?; + stack.push_values(1)?; + } + + I32Eq | I32Ne | I32LtS | I32LtU | I32GtS | I32GtU | I32LeS | I32LeU | I32GeS + | I32GeU | I64Eq | I64Ne | I64LtS | I64LtU | I64GtS | I64GtU | I64LeS | I64LeU + | I64GeS | I64GeU | F32Eq | F32Ne | F32Lt | F32Gt | F32Le | F32Ge | F64Eq | F64Ne + | F64Lt | F64Gt | F64Le | F64Ge => { + // Comparison operations take two operands and produce one result. + stack.pop_values(2)?; + stack.push_values(1)?; + } + + I32Clz | I32Ctz | I32Popcnt | I64Clz | I64Ctz | I64Popcnt | F32Abs | F32Neg + | F32Ceil | F32Floor | F32Trunc | F32Nearest | F32Sqrt | F64Abs | F64Neg | F64Ceil + | F64Floor | F64Trunc | F64Nearest | F64Sqrt => { + // Unary operators take one operand and produce one result. + stack.pop_values(1)?; + stack.push_values(1)?; + } + + I32Add | I32Sub | I32Mul | I32DivS | I32DivU | I32RemS | I32RemU | I32And | I32Or + | I32Xor | I32Shl | I32ShrS | I32ShrU | I32Rotl | I32Rotr | I64Add | I64Sub + | I64Mul | I64DivS | I64DivU | I64RemS | I64RemU | I64And | I64Or | I64Xor | I64Shl + | I64ShrS | I64ShrU | I64Rotl | I64Rotr | F32Add | F32Sub | F32Mul | F32Div + | F32Min | F32Max | F32Copysign | F64Add | F64Sub | F64Mul | F64Div | F64Min + | F64Max | F64Copysign => { + // Binary operators take two operands and produce one result. 
+ stack.pop_values(2)?; + stack.push_values(1)?; + } + + I32WrapI64 | I32TruncSF32 | I32TruncUF32 | I32TruncSF64 | I32TruncUF64 + | I64ExtendSI32 | I64ExtendUI32 | I64TruncSF32 | I64TruncUF32 | I64TruncSF64 + | I64TruncUF64 | F32ConvertSI32 | F32ConvertUI32 | F32ConvertSI64 | F32ConvertUI64 + | F32DemoteF64 | F64ConvertSI32 | F64ConvertUI32 | F64ConvertSI64 | F64ConvertUI64 + | F64PromoteF32 | I32ReinterpretF32 | I64ReinterpretF64 | F32ReinterpretI32 + | F64ReinterpretI64 => { + // Conversion operators take one value and produce one result. + stack.pop_values(1)?; + stack.push_values(1)?; + } + } + pc += 1; + } + + Ok(max_height) +} + +#[cfg(test)] +mod tests { + extern crate wabt; + use parity_wasm::elements; + use super::*; + + fn parse_wat(source: &str) -> elements::Module { + elements::deserialize_buffer(&wabt::wat2wasm(source).expect("Failed to wat2wasm")) + .expect("Failed to deserialize the module") + } + + #[test] + fn simple_test() { + let module = parse_wat( + r#" +(module + (func + i32.const 1 + i32.const 2 + i32.const 3 + drop + drop + drop + ) +) +"#, + ); + + let height = compute(0, &module).unwrap(); + assert_eq!(height, 3); + } + + #[test] + fn implicit_and_explicit_return() { + let module = parse_wat( + r#" +(module + (func (result i32) + i32.const 0 + return + ) +) +"#, + ); + + let height = compute(0, &module).unwrap(); + assert_eq!(height, 1); + } + + #[test] + fn dont_count_in_unreachable() { + let module = parse_wat( + r#" +(module + (memory 0) + (func (result i32) + unreachable + grow_memory + ) +) +"#, + ); + + let height = compute(0, &module).unwrap(); + assert_eq!(height, 0); + } + + #[test] + fn yet_another_test() { + const SOURCE: &'static str = r#" +(module + (memory 0) + (func + ;; Push two values and then pop them. + ;; This will make max depth to be equal to 2. + i32.const 0 + i32.const 1 + drop + drop + + ;; Code after `unreachable` shouldn't have an effect + ;; on the max depth. 
+ unreachable + i32.const 0 + i32.const 1 + i32.const 2 + ) +) +"#; + let module = elements::deserialize_buffer(&wabt::Wat2Wasm::new() + .validate(false) + .convert(SOURCE) + .expect("Failed to wat2wasm") + .as_ref()) + .expect("Failed to deserialize the module"); + + let height = compute(0, &module).unwrap(); + assert_eq!(height, 2); + } +} diff --git a/src/stack_height/mod.rs b/src/stack_height/mod.rs new file mode 100644 index 0000000..9502ff5 --- /dev/null +++ b/src/stack_height/mod.rs @@ -0,0 +1,433 @@ +//! The pass that tries to make stack overflows deterministic, by introducing +//! an upper bound of the stack size. +//! +//! This pass introduces a global mutable variable to track stack height, +//! and instruments all calls with preamble and postamble. +//! +//! Stack height is increased prior to the call. Otherwise, the check would +//! be made after the stack frame is allocated. +//! +//! The preamble is inserted before the call. It increments +//! the global stack height variable with statically determined "stack cost" +//! of the callee. If after the increment the stack height exceeds +//! the limit (specified by the `rules`) then execution traps. +//! Otherwise, the call is executed. +//! +//! The postamble is inserted after the call. The purpose of the postamble is to decrease +//! the stack height by the "stack cost" of the callee function. +//! +//! Note that we can't instrument all possible ways to return from the function. The simplest +//! example would be a trap issued by the host function. +//! That means stack height global won't be equal to zero upon the next execution after such trap. +//! +//! # Thunks +//! +//! Because stack height is increased prior to the call, a few problems arise: +//! +//! - Stack height isn't increased upon an entry to the first function, i.e. exported function. +//! - It is statically unknown what function will be invoked in an indirect call. +//! +//!
The solution to these problems is to generate intermediate functions, called 'thunks', which +//! will increase the stack height before and decrease it after the call to the original function, and +//! then make exported functions and table entries point to the corresponding thunks. +//! +//! # Stack cost +//! +//! Stack cost of the function is calculated as a sum of its locals +//! and the maximal height of the value stack. +//! +//! All values are treated equally, as they have the same size. +//! +//! The rationale for this is that it makes it possible to use a very naive wasm executor, that is: +//! +//! - values are implemented by a union, so each value takes a size equal to +//! the size of the largest possible value type this union can hold. (In MVP it is 8 bytes) +//! - each value from the value stack is placed on the native stack. +//! - each local variable and function argument is placed on the native stack. +//! - arguments pushed by the caller are copied into callee stack rather than shared +//! between the frames. +//! - upon entry into the function the entire stack frame is allocated. + +use parity_wasm::elements::{self, Type}; +use parity_wasm::builder; + +/// Macro to generate preamble and postamble. +macro_rules!
instrument_call { + ($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{ + use $crate::parity_wasm::elements::Opcode::*; + [ + // stack_height += stack_cost(F) + GetGlobal($stack_height_global_idx), + I32Const($callee_stack_cost), + I32Add, + SetGlobal($stack_height_global_idx), + // if stack_counter > LIMIT: unreachable + GetGlobal($stack_height_global_idx), + I32Const($stack_limit as i32), + I32GtU, + If(elements::BlockType::NoResult), + Unreachable, + End, + // Original call + Call($callee_idx), + // stack_height -= stack_cost(F) + GetGlobal($stack_height_global_idx), + I32Const($callee_stack_cost), + I32Sub, + SetGlobal($stack_height_global_idx), + ] + }}; +} + +mod max_height; +mod thunk; + +/// Error that occured during processing the module. +/// +/// This means that the module is invalid. +#[derive(Debug)] +pub struct Error(String); + +pub(crate) struct Context { + stack_height_global_idx: Option, + func_stack_costs: Option>, + stack_limit: u32, +} + +impl Context { + /// Returns index in a global index space of a stack_height global variable. + /// + /// Panics if it haven't generated yet. + fn stack_height_global_idx(&self) -> u32 { + self.stack_height_global_idx.expect( + "stack_height_global_idx isn't yet generated; + Did you call `inject_stack_counter_global`", + ) + } + + /// Returns `stack_cost` for `func_idx`. + /// + /// Panics if stack costs haven't computed yet or `func_idx` is greater + /// than the last function index. + fn stack_cost(&self, func_idx: u32) -> Option { + self.func_stack_costs + .as_ref() + .expect( + "func_stack_costs isn't yet computed; + Did you call `compute_stack_costs`?", + ) + .get(func_idx as usize) + .cloned() + } + + /// Returns stack limit specified by the rules. + fn stack_limit(&self) -> u32 { + self.stack_limit + } +} + +/// Instrument a module with stack height limiter. +/// +/// See module-level documentation for more details. 
+/// +/// # Errors +/// +/// Returns `Err` if module is invalid and can't be instrumented. +pub fn inject_limiter( + mut module: elements::Module, + stack_limit: u32, +) -> Result { + let mut ctx = Context { + stack_height_global_idx: None, + func_stack_costs: None, + stack_limit, + }; + + generate_stack_height_global(&mut ctx, &mut module); + compute_stack_costs(&mut ctx, &module)?; + instrument_functions(&mut ctx, &mut module)?; + let module = thunk::generate_thunks(&mut ctx, module)?; + + Ok(module) +} + +/// Generate a new global that will be used for tracking current stack height. +fn generate_stack_height_global(ctx: &mut Context, module: &mut elements::Module) { + let global_entry = builder::global() + .value_type() + .i32() + .mutable() + .init_expr(elements::Opcode::I32Const(0)) + .build(); + + // Try to find an existing global section. + for section in module.sections_mut() { + match *section { + elements::Section::Global(ref mut gs) => { + gs.entries_mut().push(global_entry); + + let stack_height_global_idx = (gs.entries().len() as u32) - 1; + ctx.stack_height_global_idx = Some(stack_height_global_idx); + return; + } + _ => {} + } + } + + // Existing section not found, create one! + module.sections_mut().push(elements::Section::Global( + elements::GlobalSection::with_entries(vec![global_entry]), + )); + ctx.stack_height_global_idx = Some(0); +} + +/// Calculate stack costs for all functions. +/// +/// Stores in `ctx` a vector with a stack cost for each function, including imports. +fn compute_stack_costs(ctx: &mut Context, module: &elements::Module) -> Result<(), Error> { + let func_imports = module.import_count(elements::ImportCountType::Function); + let mut func_stack_costs = vec![0; module.functions_space()]; + // TODO: optimize! + for (func_idx, func_stack_cost) in func_stack_costs.iter_mut().enumerate() { + // We can't calculate stack_cost of the import functions.
+ if func_idx >= func_imports { + *func_stack_cost = compute_stack_cost(func_idx as u32, &module)?; + } + } + + ctx.func_stack_costs = Some(func_stack_costs); + Ok(()) +} + +/// Stack cost of the given *defined* function is the sum of it's locals count (that is, +/// number of arguments plus number of local variables) and the maximal stack +/// height. +fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result { + // To calculate the cost of a function we need to convert index from + // function index space to defined function spaces. + let func_imports = module.import_count(elements::ImportCountType::Function) as u32; + let defined_func_idx = func_idx.checked_sub(func_imports).ok_or_else(|| { + Error("This should be a index of a defined function".into()) + })?; + + let code_section = module.code_section().ok_or_else(|| { + Error("Due to validation code section should exists".into()) + })?; + let body = &code_section + .bodies() + .get(defined_func_idx as usize) + .ok_or_else(|| Error("Function body is out of bounds".into()))?; + let locals_count = body.locals().len() as u32; + + let max_stack_height = + max_height::compute( + defined_func_idx, + module + )?; + + Ok(locals_count + max_stack_height) +} + +fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> { + for section in module.sections_mut() { + match *section { + elements::Section::Code(ref mut code_section) => { + for func_body in code_section.bodies_mut() { + let mut opcodes = func_body.code_mut(); + instrument_function(ctx, opcodes)?; + } + } + _ => {} + } + } + Ok(()) +} + +/// This function searches `call` instructions and wrap each call +/// with preamble and postamble. +/// +/// Before: +/// +/// ```text +/// get_local 0 +/// get_local 1 +/// call 228 +/// drop +/// ``` +/// +/// After: +/// +/// ```text +/// get_local 0 +/// get_local 1 +/// +/// < ... preamble ... > +/// +/// call 228 +/// +/// < .. postamble ... 
> +/// +/// drop +/// ``` +fn instrument_function( + ctx: &mut Context, + opcodes: &mut elements::Opcodes, +) -> Result<(), Error> { + use parity_wasm::elements::Opcode::*; + + let mut cursor = 0; + loop { + if cursor >= opcodes.elements().len() { + break; + } + + enum Action { + InstrumentCall { + callee_idx: u32, + callee_stack_cost: u32, + }, + Nop, + } + + let action: Action = { + let opcode = &opcodes.elements()[cursor]; + match *opcode { + Call(ref callee_idx) => { + let callee_stack_cost = ctx + .stack_cost(*callee_idx) + .ok_or_else(|| + Error( + format!("Call to function that out-of-bounds: {}", callee_idx) + ) + )?; + + // Instrument only calls to a functions which stack_cost is + // non-zero. + if callee_stack_cost > 0 { + Action::InstrumentCall { + callee_idx: *callee_idx, + callee_stack_cost, + } + } else { + Action::Nop + } + }, + _ => Action::Nop, + } + }; + + match action { + // We need to wrap a `call idx` instruction + // with a code that adjusts stack height counter + // and then restores it. + Action::InstrumentCall { callee_idx, callee_stack_cost } => { + let new_seq = instrument_call!( + callee_idx, + callee_stack_cost as i32, + ctx.stack_height_global_idx(), + ctx.stack_limit() + ); + + // Replace the original `call idx` instruction with + // a wrapped call sequence. + // + // To splice actually take a place, we need to consume iterator + // splice returns. So we just `count()` it. + let _ = opcodes + .elements_mut() + .splice(cursor..(cursor + 1), new_seq.iter().cloned()) + .count(); + + // Advance cursor to be after the inserted sequence. + cursor += new_seq.len(); + } + // Do nothing for other instructions. 
+ _ => { + cursor += 1; + } + } + } + + Ok(()) +} + +fn resolve_func_type( + func_idx: u32, + module: &elements::Module, +) -> Result<&elements::FunctionType, Error> { + let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]); + let functions = module + .function_section() + .map(|fs| fs.entries()) + .unwrap_or(&[]); + + let func_imports = module.import_count(elements::ImportCountType::Function); + let sig_idx = if func_idx < func_imports as u32 { + module + .import_section() + .expect("function import count is not zero; import section must exists; qed") + .entries() + .iter() + .filter_map(|entry| match *entry.external() { + elements::External::Function(ref idx) => Some(*idx), + _ => None, + }) + .nth(func_idx as usize) + .expect( + "func_idx is less than function imports count; + nth function import must be `Some`; + qed", + ) + } else { + functions + .get(func_idx as usize - func_imports) + .ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))? 
+ .type_ref() + }; + let Type::Function(ref ty) = *types.get(sig_idx as usize).ok_or_else(|| { + Error(format!( + "Signature {} (specified by func {}) isn't defined", + sig_idx, func_idx + )) + })?; + Ok(ty) +} + +#[cfg(test)] +mod tests { + extern crate wabt; + use parity_wasm::elements; + use super::*; + + fn parse_wat(source: &str) -> elements::Module { + elements::deserialize_buffer(&wabt::wat2wasm(source).expect("Failed to wat2wasm")) + .expect("Failed to deserialize the module") + } + + fn validate_module(module: elements::Module) { + let binary = elements::serialize(module).expect("Failed to serialize"); + wabt::Module::read_binary(&binary, &Default::default()) + .expect("Wabt failed to read final binary") + .validate() + .expect("Invalid module"); + } + + #[test] + fn test_with_params_and_result() { + let module = parse_wat( + r#" +(module + (func (export "i32.add") (param i32 i32) (result i32) + get_local 0 + get_local 1 + i32.add + ) +) +"#, + ); + + let module = inject_limiter(module, 1024) + .expect("Failed to inject stack counter"); + validate_module(module); + } +} diff --git a/src/stack_height/thunk.rs b/src/stack_height/thunk.rs new file mode 100644 index 0000000..be1d626 --- /dev/null +++ b/src/stack_height/thunk.rs @@ -0,0 +1,159 @@ +use parity_wasm::elements::{self, FunctionType, Internal}; +use parity_wasm::builder; + +use std::collections::HashMap; + +use super::{resolve_func_type, Context, Error}; + +struct Thunk { + signature: FunctionType, + // Index in function space of this thunk. + idx: Option, + original_func_idx: u32, + callee_stack_cost: u32, +} + +pub(crate) fn generate_thunks( + ctx: &mut Context, + module: elements::Module, +) -> Result { + // First, we need to collect all function indicies that should be replaced by thunks + + // Function indicies which needs to generate thunks. 
+ let mut need_thunks: Vec = Vec::new(); + + let mut replacement_map: HashMap = { + let exports = module + .export_section() + .map(|es| es.entries()) + .unwrap_or(&[]); + let elem_segments = module + .elements_section() + .map(|es| es.entries()) + .unwrap_or(&[]); + + let exported_func_indicies = exports.iter().filter_map(|entry| match *entry.internal() { + Internal::Function(ref function_idx) => Some(*function_idx), + _ => None, + }); + let table_func_indicies = elem_segments + .iter() + .flat_map(|segment| segment.members()) + .cloned(); + + // Replacement map is at least export section size. + let mut replacement_map: HashMap = HashMap::new(); + + for func_idx in exported_func_indicies.chain(table_func_indicies) { + let callee_stack_cost = ctx.stack_cost(func_idx).ok_or_else(|| { + Error(format!("function with idx {} isn't found", func_idx)) + })?; + + // Don't generate a thunk if stack_cost of a callee is zero. + if callee_stack_cost != 0 { + need_thunks.push(func_idx); + replacement_map.insert(func_idx, Thunk { + signature: resolve_func_type(func_idx, &module)?.clone(), + idx: None, + callee_stack_cost, + original_func_idx: func_idx, + }); + } + } + + replacement_map + }; + + // Then, we generate a thunk for each original function. 
+ + // Save current func_idx + let mut next_func_idx = module.functions_space() as u32; + + let mut mbuilder = builder::from_module(module); + for func_idx in need_thunks { + let mut thunk = replacement_map + .get_mut(&func_idx) + .expect( + "`func_idx` should come from `need_thunks`; + `need_thunks` is populated with the same items that in `replacement_map`; + qed" + ); + + let instrumented_call = instrument_call!( + thunk.original_func_idx as u32, + thunk.callee_stack_cost as i32, + ctx.stack_height_global_idx(), + ctx.stack_limit() + ); + // Thunk body consist of: + // - argument pushing + // - instrumented call + // - end + let mut thunk_body: Vec = Vec::with_capacity( + thunk.signature.params().len() + + instrumented_call.len() + + 1 + ); + + for (arg_idx, _) in thunk.signature.params().iter().enumerate() { + thunk_body.push(elements::Opcode::GetLocal(arg_idx as u32)); + } + thunk_body.extend(instrumented_call.iter().cloned()); + thunk_body.push(elements::Opcode::End); + + // TODO: Don't generate a signature, but find an existing one. + + mbuilder = mbuilder.function() + // Signature of the thunk should match the original function signature. + .signature() + .with_params(thunk.signature.params().to_vec()) + .with_return_type(thunk.signature.return_type().clone()) + .build() + .body() + .with_opcodes(elements::Opcodes::new( + thunk_body + )) + .build() + .build(); + + thunk.idx = Some(next_func_idx); + next_func_idx += 1; + } + let mut module = mbuilder.build(); + + // And finally, fixup thunks in export and table sections. + + // Fixup original function index to a index of a thunk generated earlier. + let fixup = |function_idx: &mut u32| { + // Check whether this function is in replacement_map, since + // we can skip thunk generation (e.g. if stack_cost of function is 0). 
+ if let Some(ref thunk) = replacement_map.get(function_idx) { + *function_idx = thunk + .idx + .expect("At this point an index must be assigned to each thunk"); + } + }; + + for section in module.sections_mut() { + match *section { + elements::Section::Export(ref mut export_section) => { + for entry in export_section.entries_mut() { + match *entry.internal_mut() { + Internal::Function(ref mut function_idx) => fixup(function_idx), + _ => {} + } + } + } + elements::Section::Element(ref mut elem_section) => { + for segment in elem_section.entries_mut() { + for function_idx in segment.members_mut() { + fixup(function_idx) + } + } + } + _ => {} + } + } + + Ok(module) +} diff --git a/stack_height/Cargo.toml b/stack_height/Cargo.toml new file mode 100644 index 0000000..ea4e7f5 --- /dev/null +++ b/stack_height/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "stack_height" +version = "0.1.0" +authors = ["Sergey Pepyakin "] + +[dependencies] +wasm-utils = { path = "../" } diff --git a/stack_height/src/main.rs b/stack_height/src/main.rs new file mode 100644 index 0000000..f65028c --- /dev/null +++ b/stack_height/src/main.rs @@ -0,0 +1,27 @@ +extern crate wasm_utils; +extern crate parity_wasm; + +use std::env; +use wasm_utils::stack_height; + +fn main() { + wasm_utils::init_log(); + + let args = env::args().collect::>(); + if args.len() != 3 { + println!("Usage: {} input_file.wasm output_file.wasm", args[0]); + return; + } + + let input_file = &args[1]; + let output_file = &args[2]; + + // Loading module + let module = parity_wasm::deserialize_file(&input_file).expect("Module deserialization to succeed"); + + let result = stack_height::inject_limiter( + module, 1024 + ).expect("Failed to inject stack height counter"); + + parity_wasm::serialize_to_file(&output_file, result).expect("Module serialization to succeed") +} diff --git a/tests/diff.rs b/tests/diff.rs new file mode 100644 index 0000000..1ad3d43 --- /dev/null +++ b/tests/diff.rs @@ -0,0 +1,93 @@ +extern crate diff; 
+extern crate wasm_utils; +extern crate wabt; +extern crate parity_wasm; + +use std::fs; +use std::io::{self, Read, Write}; +use std::path::{Path, PathBuf}; +use parity_wasm::elements; + +fn slurp>(path: P) -> io::Result> { + let mut f = fs::File::open(path)?; + let mut buf = vec![]; + f.read_to_end(&mut buf)?; + Ok(buf) +} + +fn dump>(path: P, buf: &[u8]) -> io::Result<()> { + let mut f = fs::File::create(path)?; + f.write_all(buf)?; + Ok(()) +} + +fn validate_wasm(binary: &[u8]) -> Result<(), wabt::Error> { + wabt::Module::read_binary( + &binary, + &Default::default() + )?.validate()?; + Ok(()) +} + +fn run_diff_test Vec>(test_dir: &str, name: &str, test: F) { + let mut fixture_path = PathBuf::from(concat!( + env!("CARGO_MANIFEST_DIR"), + "/tests/fixtures/", + )); + fixture_path.push(test_dir); + fixture_path.push(name); + + let mut expected_path = PathBuf::from(concat!( + env!("CARGO_MANIFEST_DIR"), + "/tests/expectations/" + )); + expected_path.push(test_dir); + expected_path.push(name); + + let fixture_wat = slurp(&fixture_path).expect("Failed to read fixture"); + let fixture_wasm = wabt::wat2wasm(fixture_wat).expect("Failed to read fixture"); + validate_wasm(&fixture_wasm).expect("Fixture is invalid"); + + let expected_wat = slurp(&expected_path).unwrap_or_default(); + let expected_wat = String::from_utf8_lossy(&expected_wat); + + let actual_wasm = test(fixture_wasm.as_ref()); + validate_wasm(&actual_wasm).expect("Result module is invalid"); + + let actual_wat = wabt::wasm2wat(&actual_wasm).expect("Failed to convert result wasm to wat"); + + if actual_wat != expected_wat { + println!("difference!"); + println!("--- {}", expected_path.display()); + println!("+++ {} test {}", test_dir, name); + for diff in diff::lines(&expected_wat, &actual_wat) { + match diff { + diff::Result::Left(l) => println!("-{}", l), + diff::Result::Both(l, _) => println!(" {}", l), + diff::Result::Right(r) => println!("+{}", r), + } + } + + dump(&expected_path, 
actual_wat.as_bytes()).expect("Failed to write to expected"); + + panic!(); + } +} + +macro_rules! def_stack_height_test { + ( $name:ident ) => { + #[test] + fn $name() { + run_diff_test("stack-height", concat!(stringify!($name), ".wat"), |input| { + let module = elements::deserialize_buffer(input).expect("Failed to deserialize"); + let instrumented = wasm_utils::stack_height::inject_limiter(module, 1024).expect("Failed to instrument with stack counter"); + elements::serialize(instrumented).expect("Failed to serialize") + }); + } + }; +} + +def_stack_height_test!(simple); +def_stack_height_test!(table); +def_stack_height_test!(global); +def_stack_height_test!(imports); diff --git a/tests/expectations/stack-height/global.wat b/tests/expectations/stack-height/global.wat new file mode 100644 index 0000000..f0f3e00 --- /dev/null +++ b/tests/expectations/stack-height/global.wat @@ -0,0 +1,56 @@ +(module + (type (;0;) (func)) + (type (;1;) (func (param i32 i32) (result i32))) + (type (;2;) (func (param i32))) + (type (;3;) (func (param i32 i32) (result i32))) + (import "env" "foo" (func (;0;) (type 0))) + (func (;1;) (type 1) (param i32 i32) (result i32) + get_local 0 + get_local 1 + i32.add) + (func (;2;) (type 2) (param i32) + (local i32) + get_global 0 + i32.const 1 + i32.add + tee_local 1 + set_global 0 + get_local 1 + get_local 0 + get_global 1 + i32.const 2 + i32.add + set_global 1 + get_global 1 + i32.const 1024 + i32.gt_u + if ;; label = @1 + unreachable + end + call 1 + get_global 1 + i32.const 2 + i32.sub + set_global 1 + drop) + (func (;3;) (type 3) (param i32 i32) (result i32) + get_local 0 + get_local 1 + get_global 1 + i32.const 2 + i32.add + set_global 1 + get_global 1 + i32.const 1024 + i32.gt_u + if ;; label = @1 + unreachable + end + call 1 + get_global 1 + i32.const 2 + i32.sub + set_global 1) + (global (;0;) (mut i32) (i32.const 1)) + (global (;1;) (mut i32) (i32.const 0)) + (export "i32.add" (func 3))) diff --git 
a/tests/expectations/stack-height/imports.wat b/tests/expectations/stack-height/imports.wat new file mode 100644 index 0000000..ee14b2a --- /dev/null +++ b/tests/expectations/stack-height/imports.wat @@ -0,0 +1,32 @@ +(module + (type (;0;) (func)) + (type (;1;) (func (param i32 i32) (result i32))) + (type (;2;) (func (param i32 i32) (result i32))) + (import "env" "foo" (func (;0;) (type 0))) + (import "env" "boo" (func (;1;) (type 0))) + (func (;2;) (type 1) (param i32 i32) (result i32) + call 0 + call 1 + get_local 0 + get_local 1 + i32.add) + (func (;3;) (type 2) (param i32 i32) (result i32) + get_local 0 + get_local 1 + get_global 0 + i32.const 2 + i32.add + set_global 0 + get_global 0 + i32.const 1024 + i32.gt_u + if ;; label = @1 + unreachable + end + call 2 + get_global 0 + i32.const 2 + i32.sub + set_global 0) + (global (;0;) (mut i32) (i32.const 0)) + (export "i32.add" (func 3))) diff --git a/tests/expectations/stack-height/simple.wat b/tests/expectations/stack-height/simple.wat new file mode 100644 index 0000000..4a6c324 --- /dev/null +++ b/tests/expectations/stack-height/simple.wat @@ -0,0 +1,24 @@ +(module + (type (;0;) (func)) + (type (;1;) (func)) + (func (;0;) (type 0) + i32.const 123 + drop) + (func (;1;) (type 1) + get_global 0 + i32.const 1 + i32.add + set_global 0 + get_global 0 + i32.const 1024 + i32.gt_u + if ;; label = @1 + unreachable + end + call 0 + get_global 0 + i32.const 1 + i32.sub + set_global 0) + (global (;0;) (mut i32) (i32.const 0)) + (export "simple" (func 1))) diff --git a/tests/expectations/stack-height/table.wat b/tests/expectations/stack-height/table.wat new file mode 100644 index 0000000..f2d4f2c --- /dev/null +++ b/tests/expectations/stack-height/table.wat @@ -0,0 +1,88 @@ +(module + (type (;0;) (func)) + (type (;1;) (func (param i32))) + (type (;2;) (func (param i32 i32) (result i32))) + (type (;3;) (func (param i32 i32) (result i32))) + (type (;4;) (func (param i32))) + (type (;5;) (func (param i32 i32) (result i32))) + 
(import "env" "foo" (func (;0;) (type 0))) + (func (;1;) (type 1) (param i32) + get_local 0 + i32.const 0 + get_global 0 + i32.const 2 + i32.add + set_global 0 + get_global 0 + i32.const 1024 + i32.gt_u + if ;; label = @1 + unreachable + end + call 2 + get_global 0 + i32.const 2 + i32.sub + set_global 0 + drop) + (func (;2;) (type 2) (param i32 i32) (result i32) + get_local 0 + get_local 1 + i32.add) + (func (;3;) (type 3) (param i32 i32) (result i32) + get_local 0 + get_local 1 + get_global 0 + i32.const 2 + i32.add + set_global 0 + get_global 0 + i32.const 1024 + i32.gt_u + if ;; label = @1 + unreachable + end + call 2 + get_global 0 + i32.const 2 + i32.sub + set_global 0) + (func (;4;) (type 4) (param i32) + get_local 0 + get_global 0 + i32.const 2 + i32.add + set_global 0 + get_global 0 + i32.const 1024 + i32.gt_u + if ;; label = @1 + unreachable + end + call 1 + get_global 0 + i32.const 2 + i32.sub + set_global 0) + (func (;5;) (type 5) (param i32 i32) (result i32) + get_local 0 + get_local 1 + get_global 0 + i32.const 2 + i32.add + set_global 0 + get_global 0 + i32.const 1024 + i32.gt_u + if ;; label = @1 + unreachable + end + call 2 + get_global 0 + i32.const 2 + i32.sub + set_global 0) + (table (;0;) 10 anyfunc) + (global (;0;) (mut i32) (i32.const 0)) + (export "i32.add" (func 5)) + (elem (i32.const 0) 0 4 5)) diff --git a/tests/fixtures/stack-height/global.wat b/tests/fixtures/stack-height/global.wat new file mode 100644 index 0000000..7251122 --- /dev/null +++ b/tests/fixtures/stack-height/global.wat @@ -0,0 +1,26 @@ +(module + (import "env" "foo" (func $foo)) + + ;; Declare a global. 
+ (global $counter (mut i32) (i32.const 1)) + + (func $i32.add (export "i32.add") (param i32 i32) (result i32) + get_local 0 + get_local 1 + i32.add + ) + (func (param $arg i32) + (local $tmp i32) + + get_global 0 + i32.const 1 + i32.add + tee_local $tmp + set_global $counter + + get_local $tmp + get_local $arg + call $i32.add + drop + ) +) diff --git a/tests/fixtures/stack-height/imports.wat b/tests/fixtures/stack-height/imports.wat new file mode 100644 index 0000000..9602769 --- /dev/null +++ b/tests/fixtures/stack-height/imports.wat @@ -0,0 +1,15 @@ +;; This test + +(module + (import "env" "foo" (func $foo)) + (import "env" "boo" (func $boo)) + + (func (export "i32.add") (param i32 i32) (result i32) + call $foo + call $boo + + get_local 0 + get_local 1 + i32.add + ) +) diff --git a/tests/fixtures/stack-height/simple.wat b/tests/fixtures/stack-height/simple.wat new file mode 100644 index 0000000..94290ec --- /dev/null +++ b/tests/fixtures/stack-height/simple.wat @@ -0,0 +1,6 @@ +(module + (func (export "simple") + i32.const 123 + drop + ) +) diff --git a/tests/fixtures/stack-height/table.wat b/tests/fixtures/stack-height/table.wat new file mode 100644 index 0000000..f17acdd --- /dev/null +++ b/tests/fixtures/stack-height/table.wat @@ -0,0 +1,18 @@ +(module + (import "env" "foo" (func $foo)) + (func (param i32) + get_local 0 + i32.const 0 + call $i32.add + drop + ) + (func $i32.add (export "i32.add") (param i32 i32) (result i32) + get_local 0 + get_local 1 + i32.add + ) + (table 10 anyfunc) + + ;; Refer all types of functions: imported, defined not exported and defined exported. + (elem (i32.const 0) 0 1 2) +)