From 310db5f92c0574aa4539098b204d5cf918a294ed Mon Sep 17 00:00:00 2001 From: Sergei Shulepov Date: Wed, 19 Oct 2022 17:56:02 +0200 Subject: [PATCH] Improve performance of stack height injection Specifically, - avoid using `import_count` and `function_space` since those are slow. - Avoid using the `builder` module since it requires recreation of the module essentially and is not efficient. - That also allows us to avoid creation of redundant signatures for thunks. --- src/stack_limiter/max_height.rs | 80 +++++------ src/stack_limiter/mod.rs | 235 ++++++++++++++++++-------------- src/stack_limiter/thunk.rs | 101 ++++++++------ 3 files changed, 232 insertions(+), 184 deletions(-) diff --git a/src/stack_limiter/max_height.rs b/src/stack_limiter/max_height.rs index 4be939b..089bf5e 100644 --- a/src/stack_limiter/max_height.rs +++ b/src/stack_limiter/max_height.rs @@ -1,10 +1,11 @@ -use super::resolve_func_type; use alloc::vec::Vec; use parity_wasm::elements::{self, BlockType, Type}; #[cfg(feature = "sign_ext")] use parity_wasm::elements::SignExtInstruction; +use super::Context; + // The cost in stack items that should be charged per call of a function. This is // is a static cost that is added to each function call. This makes sense because even // if a function does not use any parameters or locals some stack space on the host @@ -122,7 +123,11 @@ impl Stack { } /// This function expects the function to be validated. 
-pub fn compute(func_idx: u32, module: &elements::Module) -> Result { +pub fn compute( + func_idx: u32, + ctx: &Context, + module: &elements::Module, +) -> Result { use parity_wasm::elements::Instruction::*; let func_section = module.function_section().ok_or("No function section")?; @@ -246,8 +251,10 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result { - let ty = resolve_func_type(*idx, module)?; + Call(fn_idx) => { + let ty_idx = ctx.func_type(*fn_idx).ok_or("function idx is not found in the func types list")?; + let Type::Function(ty) = + type_section.types().get(ty_idx as usize).ok_or("Type not found")?; // Pop values for arguments of the function. stack.pop_values(ty.params().len() as u32)?; @@ -411,7 +418,8 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result elements::Module { @@ -419,10 +427,15 @@ mod tests { .expect("Failed to deserialize the module") } + fn test_compute(func_idx: u32, source: &str) -> u32 { + let module = parse_wat(source); + let ctx = prepare_context(&module, 0).unwrap(); + ctx.stack_cost(func_idx).unwrap() + } + #[test] fn simple_test() { - let module = parse_wat( - r#" + let module = r#" (module (func i32.const 1 @@ -433,34 +446,30 @@ mod tests { drop ) ) -"#, - ); +"#; - let height = compute(0, &module).unwrap(); + let height = test_compute(0, module); assert_eq!(height, 3 + ACTIVATION_FRAME_COST); } #[test] fn implicit_and_explicit_return() { - let module = parse_wat( - r#" + let module = r#" (module (func (result i32) i32.const 0 return ) ) -"#, - ); +"#; - let height = compute(0, &module).unwrap(); + let height = test_compute(0, module); assert_eq!(height, 1 + ACTIVATION_FRAME_COST); } #[test] fn dont_count_in_unreachable() { - let module = parse_wat( - r#" + let module = r#" (module (memory 0) (func (result i32) @@ -468,17 +477,15 @@ mod tests { grow_memory ) ) -"#, - ); +"#; - let height = compute(0, &module).unwrap(); + let height = test_compute(0, module); assert_eq!(height, 
ACTIVATION_FRAME_COST); } #[test] fn yet_another_test() { - let module = parse_wat( - r#" + let module = r#" (module (memory 0) (func @@ -497,17 +504,15 @@ mod tests { i32.const 2 ) ) -"#, - ); +"#; - let height = compute(0, &module).unwrap(); + let height = test_compute(0, module); assert_eq!(height, 2 + ACTIVATION_FRAME_COST); } #[test] fn call_indirect() { - let module = parse_wat( - r#" + let module = r#" (module (table $ptr 1 1 funcref) (elem $ptr (i32.const 0) func 1) @@ -521,17 +526,15 @@ mod tests { drop ) ) -"#, - ); +"#; - let height = compute(0, &module).unwrap(); + let height = test_compute(0, module); assert_eq!(height, 1 + ACTIVATION_FRAME_COST); } #[test] fn breaks() { - let module = parse_wat( - r#" + let module = r#" (module (func $main block (result i32) @@ -543,17 +546,15 @@ mod tests { drop ) ) -"#, - ); +"#; - let height = compute(0, &module).unwrap(); + let height = test_compute(0, module); assert_eq!(height, 1 + ACTIVATION_FRAME_COST); } #[test] fn if_else_works() { - let module = parse_wat( - r#" + let module = r#" (module (func $main i32.const 7 @@ -569,10 +570,9 @@ mod tests { drop ) ) -"#, - ); +"#; - let height = compute(0, &module).unwrap(); + let height = test_compute(0, module); assert_eq!(height, 3 + ACTIVATION_FRAME_COST); } } diff --git a/src/stack_limiter/mod.rs b/src/stack_limiter/mod.rs index 99f13ae..a4934a2 100644 --- a/src/stack_limiter/mod.rs +++ b/src/stack_limiter/mod.rs @@ -39,7 +39,15 @@ mod max_height; mod thunk; pub struct Context { + /// Number of functions that the module imports. Required to convert defined functions indicies + /// into the global function index space. + func_imports: u32, + /// For each function in the function space this vector stores the respective type index. + func_types: Vec, + /// The index of the global variable that contains the current stack height. stack_height_global_idx: u32, + /// Logical stack costs for each function in the function space. Imported functions have cost + /// of 0. 
func_stack_costs: Vec, stack_limit: u32, } @@ -55,6 +63,11 @@ impl Context { self.func_stack_costs.get(func_idx as usize).cloned() } + /// Returns a reference to the function type index given by the index into the function space. + fn func_type(&self, func_idx: u32) -> Option { + self.func_types.get(func_idx as usize).copied() + } + /// Returns stack limit specified by the rules. fn stack_limit(&self) -> u32 { self.stack_limit @@ -115,20 +128,106 @@ pub fn inject( mut module: elements::Module, stack_limit: u32, ) -> Result { - let mut ctx = Context { - stack_height_global_idx: generate_stack_height_global(&mut module), - func_stack_costs: compute_stack_costs(&module)?, - stack_limit, - }; + let mut ctx = prepare_context(&module, stack_limit)?; - instrument_functions(&mut ctx, &mut module)?; + generate_stack_height_global(&mut ctx.stack_height_global_idx, &mut module)?; + instrument_functions(&ctx, &mut module)?; let module = thunk::generate_thunks(&mut ctx, module)?; Ok(module) } +fn prepare_context(module: &elements::Module, stack_limit: u32) -> Result { + let mut ctx = Context { + func_imports: module.import_count(elements::ImportCountType::Function) as u32, + func_types: Vec::new(), + stack_height_global_idx: 0, + func_stack_costs: Vec::new(), + stack_limit, + }; + collect_func_types(&mut ctx, &module)?; + compute_stack_costs(&mut ctx, &module)?; + Ok(ctx) +} + +fn collect_func_types(ctx: &mut Context, module: &elements::Module) -> Result<(), &'static str> { + let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]); + let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]); + let imports = module.import_section().map(|is| is.entries()).unwrap_or(&[]); + + let ensure_ty = |sig_idx: u32| -> Result<(), &'static str> { + let Type::Function(_) = types + .get(sig_idx as usize) + .ok_or("The signature as specified by a function isn't defined")?; + Ok(()) + }; + + for import in imports { + if let 
elements::External::Function(sig_idx) = import.external() { + ensure_ty(*sig_idx)?; + ctx.func_types.push(*sig_idx); + } + } + for def_func_idx in functions { + ensure_ty(def_func_idx.type_ref())?; + ctx.func_types.push(def_func_idx.type_ref()); + } + + Ok(()) +} + +/// Calculate stack costs for all functions in the function space. +/// +/// The function space consists of the imported functions followed by defined functions. +/// All imported functions are assumed to have the cost of 0. +fn compute_stack_costs(ctx: &mut Context, module: &elements::Module) -> Result<(), &'static str> { + for _ in 0..ctx.func_imports { + ctx.func_stack_costs.push(0); + } + let def_func_n = module.function_section().map(|fs| fs.entries().len()).unwrap_or(0) as u32; + for def_func_idx in 0..def_func_n { + let cost = compute_stack_cost(def_func_idx, ctx, module)?; + ctx.func_stack_costs.push(cost); + } + Ok(()) +} + +/// Computes the stack cost of a given function. The function is specified by its index in the +/// declared function space. +/// +/// Stack cost of a given function is the sum of its locals count (that is, +/// number of arguments plus number of local variables) and the maximal stack +/// height. +fn compute_stack_cost( + def_func_idx: u32, + ctx: &Context, + module: &elements::Module, +) -> Result { + let code_section = + module.code_section().ok_or("Due to validation code section should exists")?; + let body = &code_section + .bodies() + .get(def_func_idx as usize) + .ok_or("Function body is out of bounds")?; + + let mut locals_count: u32 = 0; + for local_group in body.locals() { + locals_count = + locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?; + } + + let max_stack_height = max_height::compute(def_func_idx, ctx, module)?; + + locals_count + .checked_add(max_stack_height) + .ok_or("Overflow in adding locals_count and max_stack_height") +} + /// Generate a new global that will be used for tracking current stack height. 
-fn generate_stack_height_global(module: &mut elements::Module) -> u32 { +fn generate_stack_height_global( + stack_height_global_idx: &mut u32, + module: &mut elements::Module, +) -> Result<(), &'static str> { let global_entry = builder::global() .value_type() .i32() @@ -140,71 +239,41 @@ fn generate_stack_height_global(module: &mut elements::Module) -> u32 { for section in module.sections_mut() { if let elements::Section::Global(gs) = section { gs.entries_mut().push(global_entry); - return (gs.entries().len() as u32) - 1 + *stack_height_global_idx = (gs.entries().len() as u32) - 1; + return Ok(()); } } // Existing section not found, create one! - module - .sections_mut() - .push(elements::Section::Global(elements::GlobalSection::with_entries(vec![global_entry]))); - 0 -} - -/// Calculate stack costs for all functions. -/// -/// Returns a vector with a stack cost for each function, including imports. -fn compute_stack_costs(module: &elements::Module) -> Result, &'static str> { - let func_imports = module.import_count(elements::ImportCountType::Function); - - // TODO: optimize! - (0..module.functions_space()) - .map(|func_idx| { - if func_idx < func_imports { - // We can't calculate stack_cost of the import functions. - Ok(0) - } else { - compute_stack_cost(func_idx as u32, module) + // + // It's a bit tricky since the sections have a strict prescribed order. + let global_section = elements::GlobalSection::with_entries(vec![global_entry]); + let prec_index = module + .sections() + .iter() + .rposition(|section| { + use elements::Section::*; + match section { + Type(_) | Import(_) | Function(_) | Table(_) | Memory(_) => true, + _ => false, } }) - .collect() + .ok_or("generate stack height global hasn't found any preceding sections")?; + // now `prec_index` points to the last section preceding the `global_section`. It's guaranteed that at least + // one of those functions is present. Therefore, the candidate position for the global section is the following + // one. 
However, technically, custom sections could occupy any place between the well-known sections. + // + // Now, regarding `+1` here. `insert` panics iff `index > len`. `prec_index + 1` can only be equal to `len`. + module + .sections_mut() + .insert(prec_index + 1, elements::Section::Global(global_section)); + // First entry in the brand new globals section. + *stack_height_global_idx = 0; + + Ok(()) } -/// Stack cost of the given *defined* function is the sum of it's locals count (that is, -/// number of arguments plus number of local variables) and the maximal stack -/// height. -fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result { - // To calculate the cost of a function we need to convert index from - // function index space to defined function spaces. - let func_imports = module.import_count(elements::ImportCountType::Function) as u32; - let defined_func_idx = func_idx - .checked_sub(func_imports) - .ok_or("This should be a index of a defined function")?; - - let code_section = - module.code_section().ok_or("Due to validation code section should exists")?; - let body = &code_section - .bodies() - .get(defined_func_idx as usize) - .ok_or("Function body is out of bounds")?; - - let mut locals_count: u32 = 0; - for local_group in body.locals() { - locals_count = - locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?; - } - - let max_stack_height = max_height::compute(defined_func_idx, module)?; - - locals_count - .checked_add(max_stack_height) - .ok_or("Overflow in adding locals_count and max_stack_height") -} - -fn instrument_functions( - ctx: &mut Context, - module: &mut elements::Module, -) -> Result<(), &'static str> { +fn instrument_functions(ctx: &Context, module: &mut elements::Module) -> Result<(), &'static str> { for section in module.sections_mut() { if let elements::Section::Code(code_section) = section { for func_body in code_section.bodies_mut() { @@ -242,7 +311,7 @@ fn instrument_functions( /// /// drop 
/// ``` -fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), &'static str> { +fn instrument_function(ctx: &Context, func: &mut Instructions) -> Result<(), &'static str> { use Instruction::*; struct InstrumentCall { @@ -309,42 +378,6 @@ fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), Ok(()) } -fn resolve_func_type( - func_idx: u32, - module: &elements::Module, -) -> Result<&elements::FunctionType, &'static str> { - let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]); - let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]); - - let func_imports = module.import_count(elements::ImportCountType::Function); - let sig_idx = if func_idx < func_imports as u32 { - module - .import_section() - .expect("function import count is not zero; import section must exists; qed") - .entries() - .iter() - .filter_map(|entry| match entry.external() { - elements::External::Function(idx) => Some(*idx), - _ => None, - }) - .nth(func_idx as usize) - .expect( - "func_idx is less than function imports count; - nth function import must be `Some`; - qed", - ) - } else { - functions - .get(func_idx as usize - func_imports) - .ok_or("Function at the specified index is not defined")? 
- .type_ref() - }; - let Type::Function(ty) = types - .get(sig_idx as usize) - .ok_or("The signature as specified by a function isn't defined")?; - Ok(ty) -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/stack_limiter/thunk.rs b/src/stack_limiter/thunk.rs index d28cd35..b365814 100644 --- a/src/stack_limiter/thunk.rs +++ b/src/stack_limiter/thunk.rs @@ -1,17 +1,17 @@ #[cfg(not(features = "std"))] use alloc::collections::BTreeMap as Map; use alloc::vec::Vec; -use parity_wasm::{ - builder, - elements::{self, FunctionType, Internal}, -}; +use parity_wasm::elements::{self, Internal}; #[cfg(features = "std")] use std::collections::HashMap as Map; -use super::{resolve_func_type, Context}; +use super::Context; struct Thunk { - signature: FunctionType, + /// The index of the signature in the type section. + type_idx: u32, + /// The number of parameters the function has. + param_num: u32, // Index in function space of this thunk. idx: Option, callee_stack_cost: u32, @@ -19,10 +19,11 @@ struct Thunk { pub fn generate_thunks( ctx: &mut Context, - module: elements::Module, + mut module: elements::Module, ) -> Result { // First, we need to collect all function indices that should be replaced by thunks let mut replacement_map: Map = { + let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]); let exports = module.export_section().map(|es| es.entries()).unwrap_or(&[]); let elem_segments = module.elements_section().map(|es| es.entries()).unwrap_or(&[]); let start_func_idx = module.start_section(); @@ -45,14 +46,12 @@ pub fn generate_thunks( // Don't generate a thunk if stack_cost of a callee is zero. 
if callee_stack_cost != 0 { - replacement_map.insert( - func_idx, - Thunk { - signature: resolve_func_type(func_idx, &module)?.clone(), - idx: None, - callee_stack_cost, - }, - ); + let type_idx = ctx.func_type(func_idx).ok_or("type idx for thunk not found")?; + let elements::Type::Function(func_ty) = + types.get(type_idx as usize).ok_or("sig for thunk is not found")?; + let param_num = func_ty.params().len() as u32; + replacement_map + .insert(func_idx, Thunk { type_idx, param_num, idx: None, callee_stack_cost }); } } @@ -61,10 +60,6 @@ pub fn generate_thunks( // Then, we generate a thunk for each original function. - // Save current func_idx - let mut next_func_idx = module.functions_space() as u32; - - let mut mbuilder = builder::from_module(module); for (func_idx, thunk) in replacement_map.iter_mut() { let instrumented_call = instrument_call!( *func_idx, @@ -77,32 +72,23 @@ pub fn generate_thunks( // - instrumented call // - end let mut thunk_body: Vec = - Vec::with_capacity(thunk.signature.params().len() + instrumented_call.len() + 1); + Vec::with_capacity(thunk.param_num as usize + instrumented_call.len() + 1); - for (arg_idx, _) in thunk.signature.params().iter().enumerate() { - thunk_body.push(elements::Instruction::GetLocal(arg_idx as u32)); + for arg_idx in 0..thunk.param_num { + thunk_body.push(elements::Instruction::GetLocal(arg_idx)); } thunk_body.extend_from_slice(&instrumented_call); thunk_body.push(elements::Instruction::End); - // TODO: Don't generate a signature, but find an existing one. - - mbuilder = mbuilder - .function() - // Signature of the thunk should match the original function signature. 
- .signature() - .with_params(thunk.signature.params().to_vec()) - .with_results(thunk.signature.results().to_vec()) - .build() - .body() - .with_instructions(elements::Instructions::new(thunk_body)) - .build() - .build(); - - thunk.idx = Some(next_func_idx); - next_func_idx += 1; + let func_idx = insert_function( + ctx, + &mut module, + thunk.type_idx, + Vec::new(), // No declared local variables. + elements::Instructions::new(thunk_body), + )?; + thunk.idx = Some(func_idx); } - let mut module = mbuilder.build(); // And finally, fixup thunks in export and table sections. @@ -118,18 +104,20 @@ pub fn generate_thunks( for section in module.sections_mut() { match section { - elements::Section::Export(export_section) => + elements::Section::Export(export_section) => { for entry in export_section.entries_mut() { if let Internal::Function(function_idx) = entry.internal_mut() { fixup(function_idx) } - }, - elements::Section::Element(elem_section) => + } + }, + elements::Section::Element(elem_section) => { for segment in elem_section.entries_mut() { for function_idx in segment.members_mut() { fixup(function_idx) } - }, + } + }, elements::Section::Start(start_idx) => fixup(start_idx), _ => {}, } @@ -137,3 +125,30 @@ pub fn generate_thunks( Ok(module) } + +/// Inserts a new function into the module and returns its index in the function space. +/// +/// Specifically, inserts entries into the function section and the code section. +fn insert_function( + ctx: &Context, + module: &mut elements::Module, + type_idx: u32, + locals: Vec, + insns: elements::Instructions, +) -> Result { + let funcs = module + .function_section_mut() + .ok_or("insert function no function section")? 
+ .entries_mut(); + let new_func_idx = ctx + .func_imports + .checked_add(funcs.len() as u32) + .ok_or("insert function func idx overflow")?; + funcs.push(elements::Func::new(type_idx)); + + let func_bodies = + module.code_section_mut().ok_or("insert function no code section")?.bodies_mut(); + func_bodies.push(elements::FuncBody::new(locals, insns)); + + Ok(new_func_idx) +}