diff --git a/src/lib.rs b/src/lib.rs index aaf64e0..917a0dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,3 +9,5 @@ mod stack_limiter; pub use export_globals::export_mutable_globals; pub use parity_wasm; pub use stack_limiter::inject as inject_stack_limiter; +pub use stack_limiter::compute_stack_cost; +pub use stack_limiter::compute_stack_costs; diff --git a/src/stack_limiter/max_height.rs b/src/stack_limiter/max_height.rs index 4be939b..5f145a4 100644 --- a/src/stack_limiter/max_height.rs +++ b/src/stack_limiter/max_height.rs @@ -1,16 +1,10 @@ use super::resolve_func_type; use alloc::vec::Vec; -use parity_wasm::elements::{self, BlockType, Type}; +use parity_wasm::elements::{self, BlockType, Type, ValueType}; #[cfg(feature = "sign_ext")] use parity_wasm::elements::SignExtInstruction; -// The cost in stack items that should be charged per call of a function. This is -// is a static cost that is added to each function call. This makes sense because even -// if a function does not use any parameters or locals some stack space on the host -// machine might be consumed to hold some context. -const ACTIVATION_FRAME_COST: u32 = 2; - /// Control stack frame. #[derive(Debug)] struct Frame { @@ -18,36 +12,41 @@ struct Frame { /// never passes control further was executed. is_polymorphic: bool, - /// Count of values which will be pushed after the exit - /// from the current block. - end_arity: u32, + /// Type of value which will be pushed after exiting + /// the current block or `None` if nothing is pushed. + result_type: Option, - /// Count of values which should be poped upon a branch to - /// this frame. + /// Type of value which should be poped upon a branch to + /// this frame or `None` if nothing is popped. /// - /// This might be diffirent from `end_arity` since branch + /// This might be diffirent from `result_type` since branch /// to the loop header can't take any values. - branch_arity: u32, + branch_type: Option, /// Stack height before entering in the block. - start_height: u32, + start_height: usize, } -/// This is a compound stack that abstracts tracking height of the value stack +/// This is a compound stack that abstracts tracking height and weight of the value stack /// and manipulation of the control stack. struct Stack { - height: u32, + values: Vec, control_stack: Vec, } impl Stack { fn new() -> Stack { - Stack { height: ACTIVATION_FRAME_COST, control_stack: Vec::new() } + Stack { values: Vec::new(), control_stack: Vec::new() } + } + + /// Returns current weight of the value stack. + fn weight(&self) -> u32 { + self.values.iter().map(|v| value_cost(*v)).sum() } /// Returns current height of the value stack. - fn height(&self) -> u32 { - self.height + fn height(&self) -> usize { + self.values.len() } /// Returns a reference to a frame by specified depth relative to the top of @@ -81,43 +80,42 @@ impl Stack { } /// Truncate the height of value stack to the specified height. - fn trunc(&mut self, new_height: u32) { - self.height = new_height; + fn trunc(&mut self, new_height: usize) { + self.values.truncate(new_height); } - /// Push specified number of values into the value stack. - /// - /// Returns `Err` if the height overflow usize value. - fn push_values(&mut self, value_count: u32) -> Result<(), &'static str> { - self.height = self.height.checked_add(value_count).ok_or("stack overflow")?; + /// Push a value into the value stack. + fn push_value(&mut self, value: ValueType) -> Result<(), &'static str> { + self.values.push(value); Ok(()) } - /// Pop specified number of values from the value stack. + /// Pop a value from the value stack. /// /// Returns `Err` if the stack happen to be negative value after - /// values popped. - fn pop_values(&mut self, value_count: u32) -> Result<(), &'static str> { - if value_count == 0 { - return Ok(()) - } - { - let top_frame = self.frame(0)?; - if self.height == top_frame.start_height { - // It is an error to pop more values than was pushed in the current frame - // (ie pop values pushed in the parent frame), unless the frame became - // polymorphic. - return if top_frame.is_polymorphic { - Ok(()) - } else { - return Err("trying to pop more values than pushed") - } + /// value popped. + fn pop_value(&mut self) -> Result, &'static str> { + let top_frame = self.frame(0)?; + if self.height() == top_frame.start_height { + return if top_frame.is_polymorphic { + Ok(None) + } else { + Err("trying to pop more values than pushed") } } - self.height = self.height.checked_sub(value_count).ok_or("stack underflow")?; + if self.height() > 0 { + Ok(self.values.pop()) + } else { + Err("trying to pop more values than pushed") + } + } +} - Ok(()) +fn value_cost(val: ValueType) -> u32 { + match val { + ValueType::I32 | ValueType::F32 => 1, + ValueType::I64 | ValueType::F64 => 2, } } @@ -145,17 +143,44 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result = if let Some(global_section) = module.global_section() { + global_section.entries() + .into_iter() + .map(|g| g.global_type().content_type()) + .collect() + } else { + vec![] + }; + + let locals: Vec = func_signature + .params() + .into_iter() + .cloned() + .chain( + body.locals() + .iter() + .flat_map(|l| vec![l.value_type(); l.count() as usize]) + ).collect(); + let mut stack = Stack::new(); - let mut max_height: u32 = 0; + let mut max_weight: u32 = 0; let mut pc = 0; // Add implicit frame for the function. Breaks to this frame and execution of // the last end should deal with this frame. - let func_arity = func_signature.results().len() as u32; + let func_results = func_signature.results(); + let param_weight: u32 = func_signature + .params().iter() + .map(|v| value_cost(*v)) + .sum(); + + let func_result_type = if func_results.len() == 0 { None } else { Some(func_results[0]) }; + stack.push_frame(Frame { is_polymorphic: false, - end_arity: func_arity, - branch_arity: func_arity, + result_type: func_result_type, + branch_type: func_result_type, start_height: 0, }); @@ -164,28 +189,20 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result max_height && !stack.frame(0)?.is_polymorphic { - max_height = stack.height(); - } - let opcode = &instructions.elements()[pc]; match opcode { Nop => {}, Block(ty) | Loop(ty) | If(ty) => { - let end_arity = if *ty == BlockType::NoResult { 0 } else { 1 }; - let branch_arity = if let Loop(_) = *opcode { 0 } else { end_arity }; if let If(_) = *opcode { - stack.pop_values(1)?; + stack.pop_value()?; } let height = stack.height(); + let end_result = if let BlockType::Value(vt) = *ty { Some(vt) } else { None }; stack.push_frame(Frame { is_polymorphic: false, - end_arity, - branch_arity, + result_type: end_result, + branch_type: if let Loop(_) = *opcode { None } else { end_result }, start_height: height, }); }, @@ -196,45 +213,53 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result { let frame = stack.pop_frame()?; stack.trunc(frame.start_height); - stack.push_values(frame.end_arity)?; + if let Some(vt) = frame.result_type { + stack.push_value(vt)?; + } }, Unreachable => { stack.mark_unreachable()?; }, Br(target) => { // Pop values for the destination block result. - let target_arity = stack.frame(*target)?.branch_arity; - stack.pop_values(target_arity)?; + if let Some(_) = stack.frame(*target)?.branch_type { + stack.pop_value()?; + } // This instruction unconditionally transfers control to the specified block, // thus all instruction until the end of the current block is deemed unreachable stack.mark_unreachable()?; }, BrIf(target) => { + let target_type = stack.frame(*target)?.branch_type; // Pop values for the destination block result. - let target_arity = stack.frame(*target)?.branch_arity; - stack.pop_values(target_arity)?; + if target_type.is_some() { + stack.pop_value()?; + } // Pop condition value. - stack.pop_values(1)?; + stack.pop_value()?; // Push values back. - stack.push_values(target_arity)?; + if let Some(vt) = target_type { + stack.push_value(vt)?; + } }, BrTable(br_table_data) => { - let arity_of_default = stack.frame(br_table_data.default)?.branch_arity; + let default_type = stack.frame(br_table_data.default)?.branch_type; // Check that all jump targets have an equal arities. for target in &*br_table_data.table { - let arity = stack.frame(*target)?.branch_arity; - if arity != arity_of_default { - return Err("Arity of all jump-targets must be equal") + if stack.frame(*target)?.branch_type != default_type { + return Err("Types of all jump-targets must be equal") } } - // Because all jump targets have an equal arities, we can just take arity of + // Because all jump targets have equal types, we can just take type of // the default branch. - stack.pop_values(arity_of_default)?; + if default_type.is_some() { + stack.pop_value()?; + } // This instruction doesn't let control flow to go further, since the control flow // should take either one of branches depending on the value or the default branch. @@ -243,80 +268,114 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result { // Pop return values of the function. Mark successive instructions as unreachable // since this instruction doesn't let control flow to go further. - stack.pop_values(func_arity)?; + if func_result_type.is_some() { + stack.pop_value()?; + } stack.mark_unreachable()?; }, Call(idx) => { let ty = resolve_func_type(*idx, module)?; // Pop values for arguments of the function. - stack.pop_values(ty.params().len() as u32)?; + for _ in ty.params() { + stack.pop_value()?; + } // Push result of the function execution to the stack. - let callee_arity = ty.results().len() as u32; - stack.push_values(callee_arity)?; + let callee_results = ty.results(); + if callee_results.len() > 0 { + stack.push_value(callee_results[0])?; + } }, CallIndirect(x, _) => { let Type::Function(ty) = type_section.types().get(*x as usize).ok_or("Type not found")?; // Pop the offset into the function table. - stack.pop_values(1)?; + stack.pop_value()?; // Pop values for arguments of the function. - stack.pop_values(ty.params().len() as u32)?; + for _ in ty.params() { + stack.pop_value()?; + } // Push result of the function execution to the stack. - let callee_arity = ty.results().len() as u32; - stack.push_values(callee_arity)?; + let callee_results = ty.results(); + if callee_results.len() > 0 { + stack.push_value(callee_results[0])?; + } }, Drop => { - stack.pop_values(1)?; + stack.pop_value()?; }, Select => { // Pop two values and one condition. - stack.pop_values(2)?; - stack.pop_values(1)?; + let val = stack.pop_value()?; + stack.pop_value()?; + stack.pop_value()?; // Push the selected value. - stack.push_values(1)?; + if let Some(vt) = val { + stack.push_value(vt)?; + } }, - GetLocal(_) => { - stack.push_values(1)?; + GetLocal(idx) => { + let idx = *idx as usize; + if idx >= locals.len() { + return Err("Reference to a global is out of bounds") + } + stack.push_value(locals[idx])?; }, SetLocal(_) => { - stack.pop_values(1)?; + stack.pop_value()?; }, - TeeLocal(_) => { + TeeLocal(idx) => { // This instruction pops and pushes the value, so // effectively it doesn't modify the stack height. - stack.pop_values(1)?; - stack.push_values(1)?; + let idx = *idx as usize; + if idx >= locals.len() { + return Err("Reference to a local is out of bounds") + } + stack.pop_value()?; + stack.push_value(locals[idx])?; }, - GetGlobal(_) => { - stack.push_values(1)?; + GetGlobal(idx) => { + let idx = *idx as usize; + if idx >= globals.len() { + return Err("Reference to a global is out of bounds") + } + stack.push_value(globals[idx])?; }, SetGlobal(_) => { - stack.pop_values(1)?; + stack.pop_value()?; }, + + // These instructions pop the address and pushes the result I32Load(_, _) | - I64Load(_, _) | - F32Load(_, _) | - F64Load(_, _) | I32Load8S(_, _) | I32Load8U(_, _) | I32Load16S(_, _) | - I32Load16U(_, _) | + I32Load16U(_, _) => { + stack.pop_value()?; + stack.push_value(ValueType::I32)?; + }, + I64Load(_, _) | I64Load8S(_, _) | I64Load8U(_, _) | I64Load16S(_, _) | I64Load16U(_, _) | I64Load32S(_, _) | I64Load32U(_, _) => { - // These instructions pop the address and pushes the result, - // which effictively don't modify the stack height. - stack.pop_values(1)?; - stack.push_values(1)?; + stack.pop_value()?; + stack.push_value(ValueType::I64)?; + }, + F32Load(_, _) => { + stack.pop_value()?; + stack.push_value(ValueType::F32)?; + }, + F64Load(_, _) => { + stack.pop_value()?; + stack.push_value(ValueType::F64)?; }, I32Store(_, _) | @@ -329,29 +388,30 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result { // These instructions pop the address and the value. - stack.pop_values(2)?; + stack.pop_value()?; + stack.pop_value()?; }, CurrentMemory(_) => { // Pushes current memory size - stack.push_values(1)?; + stack.push_value(ValueType::I32)?; }, GrowMemory(_) => { // Grow memory takes the value of pages to grow and pushes - stack.pop_values(1)?; - stack.push_values(1)?; + stack.pop_value()?; + stack.push_value(ValueType::I32)?; }, - I32Const(_) | I64Const(_) | F32Const(_) | F64Const(_) => { - // These instructions just push the single literal value onto the stack. - stack.push_values(1)?; - }, + I32Const(_) => { stack.push_value(ValueType::I32)?; }, + I64Const(_) => { stack.push_value(ValueType::I64)?; }, + F32Const(_) => { stack.push_value(ValueType::F32)?; }, + F64Const(_) => { stack.push_value(ValueType::F64)?; }, I32Eqz | I64Eqz => { // These instructions pop the value and compare it against zero, and pushes // the result of the comparison. - stack.pop_values(1)?; - stack.push_values(1)?; + stack.pop_value()?; + stack.push_value(ValueType::I32)?; }, I32Eq | I32Ne | I32LtS | I32LtU | I32GtS | I32GtU | I32LeS | I32LeU | I32GeS | @@ -359,16 +419,18 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result { // Comparison operations take two operands and produce one result. - stack.pop_values(2)?; - stack.push_values(1)?; + stack.pop_value()?; + stack.pop_value()?; + stack.push_value(ValueType::I32)?; }, I32Clz | I32Ctz | I32Popcnt | I64Clz | I64Ctz | I64Popcnt | F32Abs | F32Neg | F32Ceil | F32Floor | F32Trunc | F32Nearest | F32Sqrt | F64Abs | F64Neg | F64Ceil | F64Floor | F64Trunc | F64Nearest | F64Sqrt => { // Unary operators take one operand and produce one result. - stack.pop_values(1)?; - stack.push_values(1)?; + if let Some(vt) = stack.pop_value()? { + stack.push_value(vt)?; + } }, I32Add | I32Sub | I32Mul | I32DivS | I32DivU | I32RemS | I32RemU | I32And | I32Or | @@ -378,19 +440,34 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result { // Binary operators take two operands and produce one result. - stack.pop_values(2)?; - stack.push_values(1)?; + let val = stack.pop_value()?; + stack.pop_value()?; + if let Some(vt) = val { + stack.push_value(vt)?; + } }, + // Conversion operators take one value and produce one result. I32WrapI64 | I32TruncSF32 | I32TruncUF32 | I32TruncSF64 | I32TruncUF64 | + I32ReinterpretF32 => { + stack.pop_value()?; + stack.push_value(ValueType::I32)?; + }, I64ExtendSI32 | I64ExtendUI32 | I64TruncSF32 | I64TruncUF32 | I64TruncSF64 | - I64TruncUF64 | F32ConvertSI32 | F32ConvertUI32 | F32ConvertSI64 | F32ConvertUI64 | - F32DemoteF64 | F64ConvertSI32 | F64ConvertUI32 | F64ConvertSI64 | F64ConvertUI64 | - F64PromoteF32 | I32ReinterpretF32 | I64ReinterpretF64 | F32ReinterpretI32 | - F64ReinterpretI64 => { - // Conversion operators take one value and produce one result. - stack.pop_values(1)?; - stack.push_values(1)?; + I64TruncUF64 | I64ReinterpretF64 => { + stack.pop_value()?; + stack.push_value(ValueType::I64)?; + }, + F32ConvertSI32 | F32ConvertUI32 | F32ConvertSI64 | F32ConvertUI64 | + F32DemoteF64 | F32ReinterpretI32 => { + stack.pop_value()?; + stack.push_value(ValueType::F32)?; + }, + + F64ConvertSI32 | F64ConvertUI32 | F64ConvertSI64 | F64ConvertUI64 | + F64PromoteF32 | F64ReinterpretI64 => { + stack.pop_value()?; + stack.push_value(ValueType::F64)?; }, #[cfg(feature = "sign_ext")] @@ -399,14 +476,23 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result { - stack.pop_values(1)?; - stack.push_values(1)?; + if let Some(vt) = stack.pop_value()? { + stack.push_value(vt)?; + } }, } + + // If current value stack is heavier than maximal weight observed so far, + // save the new weight. + // However, we don't increase maximal value in unreachable code. + if stack.weight() > max_weight && !stack.frame(0)?.is_polymorphic { + max_weight = stack.weight(); + } + pc += 1; } - Ok(max_height) + Ok(max_weight + param_weight) } #[cfg(test)] @@ -436,8 +522,8 @@ mod tests { "#, ); - let height = compute(0, &module).unwrap(); - assert_eq!(height, 3 + ACTIVATION_FRAME_COST); + let weight = compute(0, &module).unwrap(); + assert_eq!(weight, 3); } #[test] @@ -446,15 +532,15 @@ mod tests { r#" (module (func (result i32) - i32.const 0 + i64.const 0 return ) ) "#, ); - let height = compute(0, &module).unwrap(); - assert_eq!(height, 1 + ACTIVATION_FRAME_COST); + let weight = compute(0, &module).unwrap(); + assert_eq!(weight, 2); } #[test] @@ -471,8 +557,8 @@ mod tests { "#, ); - let height = compute(0, &module).unwrap(); - assert_eq!(height, ACTIVATION_FRAME_COST); + let weight = compute(0, &module).unwrap(); + assert_eq!(weight, 0); } #[test] @@ -500,8 +586,8 @@ mod tests { "#, ); - let height = compute(0, &module).unwrap(); - assert_eq!(height, 2 + ACTIVATION_FRAME_COST); + let weight = compute(0, &module).unwrap(); + assert_eq!(weight, 2); } #[test] @@ -524,8 +610,8 @@ mod tests { "#, ); - let height = compute(0, &module).unwrap(); - assert_eq!(height, 1 + ACTIVATION_FRAME_COST); + let weight = compute(0, &module).unwrap(); + assert_eq!(weight, 1); } #[test] @@ -546,8 +632,8 @@ mod tests { "#, ); - let height = compute(0, &module).unwrap(); - assert_eq!(height, 1 + ACTIVATION_FRAME_COST); + let weight = compute(0, &module).unwrap(); + assert_eq!(weight, 1); } #[test] @@ -572,7 +658,7 @@ mod tests { "#, ); - let height = compute(0, &module).unwrap(); - assert_eq!(height, 3 + ACTIVATION_FRAME_COST); + let weight = compute(0, &module).unwrap(); + assert_eq!(weight, 3); } } diff --git a/src/stack_limiter/mod.rs b/src/stack_limiter/mod.rs index 99f13ae..be4a377 100644 --- a/src/stack_limiter/mod.rs +++ b/src/stack_limiter/mod.rs @@ -154,7 +154,7 @@ fn generate_stack_height_global(module: &mut elements::Module) -> u32 { /// Calculate stack costs for all functions. /// /// Returns a vector with a stack cost for each function, including imports. -fn compute_stack_costs(module: &elements::Module) -> Result, &'static str> { +pub fn compute_stack_costs(module: &elements::Module) -> Result, &'static str> { let func_imports = module.import_count(elements::ImportCountType::Function); // TODO: optimize! @@ -173,7 +173,7 @@ fn compute_stack_costs(module: &elements::Module) -> Result, &'static s /// Stack cost of the given *defined* function is the sum of it's locals count (that is, /// number of arguments plus number of local variables) and the maximal stack /// height. -fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result { +pub fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result { // To calculate the cost of a function we need to convert index from // function index space to defined function spaces. let func_imports = module.import_count(elements::ImportCountType::Function) as u32;