1 Commits

Author SHA1 Message Date
Sergei Shulepov 310db5f92c Improve performance of stack height injection
Specifically,

- avoid using `import_count` and `function_space` since those are slow.
- Avoid using the `builder` module since it requires recreation of the module essentially and is not efficient.
- That also allows us to avoid creation of redundant signatures for thunks.
2022-10-19 17:56:29 +02:00
3 changed files with 232 additions and 184 deletions
+40 -40
View File
@@ -1,10 +1,11 @@
use super::resolve_func_type;
use alloc::vec::Vec; use alloc::vec::Vec;
use parity_wasm::elements::{self, BlockType, Type}; use parity_wasm::elements::{self, BlockType, Type};
#[cfg(feature = "sign_ext")] #[cfg(feature = "sign_ext")]
use parity_wasm::elements::SignExtInstruction; use parity_wasm::elements::SignExtInstruction;
use super::Context;
// The cost in stack items that should be charged per call of a function. This is // The cost in stack items that should be charged per call of a function. This is
// is a static cost that is added to each function call. This makes sense because even // is a static cost that is added to each function call. This makes sense because even
// if a function does not use any parameters or locals some stack space on the host // if a function does not use any parameters or locals some stack space on the host
@@ -122,7 +123,11 @@ impl Stack {
} }
/// This function expects the function to be validated. /// This function expects the function to be validated.
pub fn compute(func_idx: u32, module: &elements::Module) -> Result<u32, &'static str> { pub fn compute(
func_idx: u32,
ctx: &Context,
module: &elements::Module,
) -> Result<u32, &'static str> {
use parity_wasm::elements::Instruction::*; use parity_wasm::elements::Instruction::*;
let func_section = module.function_section().ok_or("No function section")?; let func_section = module.function_section().ok_or("No function section")?;
@@ -246,8 +251,10 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<u32, &'static
stack.pop_values(func_arity)?; stack.pop_values(func_arity)?;
stack.mark_unreachable()?; stack.mark_unreachable()?;
}, },
Call(idx) => { Call(fn_idx) => {
let ty = resolve_func_type(*idx, module)?; let ty_idx = ctx.func_type(*fn_idx).ok_or("function idx is not found in the func types list")?;
let Type::Function(ty) =
type_section.types().get(ty_idx as usize).ok_or("Type not found")?;
// Pop values for arguments of the function. // Pop values for arguments of the function.
stack.pop_values(ty.params().len() as u32)?; stack.pop_values(ty.params().len() as u32)?;
@@ -411,7 +418,8 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<u32, &'static
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::ACTIVATION_FRAME_COST;
use crate::stack_limiter::prepare_context;
use parity_wasm::elements; use parity_wasm::elements;
fn parse_wat(source: &str) -> elements::Module { fn parse_wat(source: &str) -> elements::Module {
@@ -419,10 +427,15 @@ mod tests {
.expect("Failed to deserialize the module") .expect("Failed to deserialize the module")
} }
fn test_compute(func_idx: u32, source: &str) -> u32 {
let module = parse_wat(source);
let ctx = prepare_context(&module, 0).unwrap();
ctx.stack_cost(func_idx).unwrap()
}
#[test] #[test]
fn simple_test() { fn simple_test() {
let module = parse_wat( let module = r#"
r#"
(module (module
(func (func
i32.const 1 i32.const 1
@@ -433,34 +446,30 @@ mod tests {
drop drop
) )
) )
"#, "#;
);
let height = compute(0, &module).unwrap(); let height = test_compute(0, module);
assert_eq!(height, 3 + ACTIVATION_FRAME_COST); assert_eq!(height, 3 + ACTIVATION_FRAME_COST);
} }
#[test] #[test]
fn implicit_and_explicit_return() { fn implicit_and_explicit_return() {
let module = parse_wat( let module = r#"
r#"
(module (module
(func (result i32) (func (result i32)
i32.const 0 i32.const 0
return return
) )
) )
"#, "#;
);
let height = compute(0, &module).unwrap(); let height = test_compute(0, module);
assert_eq!(height, 1 + ACTIVATION_FRAME_COST); assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
} }
#[test] #[test]
fn dont_count_in_unreachable() { fn dont_count_in_unreachable() {
let module = parse_wat( let module = r#"
r#"
(module (module
(memory 0) (memory 0)
(func (result i32) (func (result i32)
@@ -468,17 +477,15 @@ mod tests {
grow_memory grow_memory
) )
) )
"#, "#;
);
let height = compute(0, &module).unwrap(); let height = test_compute(0, module);
assert_eq!(height, ACTIVATION_FRAME_COST); assert_eq!(height, ACTIVATION_FRAME_COST);
} }
#[test] #[test]
fn yet_another_test() { fn yet_another_test() {
let module = parse_wat( let module = r#"
r#"
(module (module
(memory 0) (memory 0)
(func (func
@@ -497,17 +504,15 @@ mod tests {
i32.const 2 i32.const 2
) )
) )
"#, "#;
);
let height = compute(0, &module).unwrap(); let height = test_compute(0, module);
assert_eq!(height, 2 + ACTIVATION_FRAME_COST); assert_eq!(height, 2 + ACTIVATION_FRAME_COST);
} }
#[test] #[test]
fn call_indirect() { fn call_indirect() {
let module = parse_wat( let module = r#"
r#"
(module (module
(table $ptr 1 1 funcref) (table $ptr 1 1 funcref)
(elem $ptr (i32.const 0) func 1) (elem $ptr (i32.const 0) func 1)
@@ -521,17 +526,15 @@ mod tests {
drop drop
) )
) )
"#, "#;
);
let height = compute(0, &module).unwrap(); let height = test_compute(0, module);
assert_eq!(height, 1 + ACTIVATION_FRAME_COST); assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
} }
#[test] #[test]
fn breaks() { fn breaks() {
let module = parse_wat( let module = r#"
r#"
(module (module
(func $main (func $main
block (result i32) block (result i32)
@@ -543,17 +546,15 @@ mod tests {
drop drop
) )
) )
"#, "#;
);
let height = compute(0, &module).unwrap(); let height = test_compute(0, module);
assert_eq!(height, 1 + ACTIVATION_FRAME_COST); assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
} }
#[test] #[test]
fn if_else_works() { fn if_else_works() {
let module = parse_wat( let module = r#"
r#"
(module (module
(func $main (func $main
i32.const 7 i32.const 7
@@ -569,10 +570,9 @@ mod tests {
drop drop
) )
) )
"#, "#;
);
let height = compute(0, &module).unwrap(); let height = test_compute(0, module);
assert_eq!(height, 3 + ACTIVATION_FRAME_COST); assert_eq!(height, 3 + ACTIVATION_FRAME_COST);
} }
} }
+134 -101
View File
@@ -39,7 +39,15 @@ mod max_height;
mod thunk; mod thunk;
pub struct Context { pub struct Context {
/// Number of functions that the module imports. Required to convert defined functions indicies
/// into the global function index space.
func_imports: u32,
/// For each function in the function space this vector stores the respective type index.
func_types: Vec<u32>,
/// The index of the global variable that contains the current stack height.
stack_height_global_idx: u32, stack_height_global_idx: u32,
/// Logical stack costs for each function in the function space. Imported functions have cost
/// of 0.
func_stack_costs: Vec<u32>, func_stack_costs: Vec<u32>,
stack_limit: u32, stack_limit: u32,
} }
@@ -55,6 +63,11 @@ impl Context {
self.func_stack_costs.get(func_idx as usize).cloned() self.func_stack_costs.get(func_idx as usize).cloned()
} }
/// Returns a reference to the function type index given by the index into the function space.
fn func_type(&self, func_idx: u32) -> Option<u32> {
self.func_types.get(func_idx as usize).copied()
}
/// Returns stack limit specified by the rules. /// Returns stack limit specified by the rules.
fn stack_limit(&self) -> u32 { fn stack_limit(&self) -> u32 {
self.stack_limit self.stack_limit
@@ -115,20 +128,106 @@ pub fn inject(
mut module: elements::Module, mut module: elements::Module,
stack_limit: u32, stack_limit: u32,
) -> Result<elements::Module, &'static str> { ) -> Result<elements::Module, &'static str> {
let mut ctx = Context { let mut ctx = prepare_context(&module, stack_limit)?;
stack_height_global_idx: generate_stack_height_global(&mut module),
func_stack_costs: compute_stack_costs(&module)?,
stack_limit,
};
instrument_functions(&mut ctx, &mut module)?; generate_stack_height_global(&mut ctx.stack_height_global_idx, &mut module)?;
instrument_functions(&ctx, &mut module)?;
let module = thunk::generate_thunks(&mut ctx, module)?; let module = thunk::generate_thunks(&mut ctx, module)?;
Ok(module) Ok(module)
} }
fn prepare_context(module: &elements::Module, stack_limit: u32) -> Result<Context, &'static str> {
let mut ctx = Context {
func_imports: module.import_count(elements::ImportCountType::Function) as u32,
func_types: Vec::new(),
stack_height_global_idx: 0,
func_stack_costs: Vec::new(),
stack_limit,
};
collect_func_types(&mut ctx, &module)?;
compute_stack_costs(&mut ctx, &module)?;
Ok(ctx)
}
fn collect_func_types(ctx: &mut Context, module: &elements::Module) -> Result<(), &'static str> {
let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
let imports = module.import_section().map(|is| is.entries()).unwrap_or(&[]);
let ensure_ty = |sig_idx: u32| -> Result<(), &'static str> {
let Type::Function(_) = types
.get(sig_idx as usize)
.ok_or("The signature as specified by a function isn't defined")?;
Ok(())
};
for import in imports {
if let elements::External::Function(sig_idx) = import.external() {
ensure_ty(*sig_idx)?;
ctx.func_types.push(*sig_idx);
}
}
for def_func_idx in functions {
ensure_ty(def_func_idx.type_ref())?;
ctx.func_types.push(def_func_idx.type_ref());
}
Ok(())
}
/// Calculate stack costs for all functions in the function space.
///
/// The function space consists of the imported functions followed by defined functions.
/// All imported functions assumed to have the cost of 0.
fn compute_stack_costs(ctx: &mut Context, module: &elements::Module) -> Result<(), &'static str> {
for _ in 0..ctx.func_imports {
ctx.func_stack_costs.push(0);
}
let def_func_n = module.function_section().map(|fs| fs.entries().len()).unwrap_or(0) as u32;
for def_func_idx in 0..def_func_n {
let cost = compute_stack_cost(def_func_idx, ctx, module)?;
ctx.func_stack_costs.push(cost);
}
Ok(())
}
/// Computes the stack cost of a given function. The function is specified by its index in the
/// declared function space.
///
/// Stack cost of a given function is the sum of it's locals count (that is,
/// number of arguments plus number of local variables) and the maximal stack
/// height.
fn compute_stack_cost(
def_func_idx: u32,
ctx: &Context,
module: &elements::Module,
) -> Result<u32, &'static str> {
let code_section =
module.code_section().ok_or("Due to validation code section should exists")?;
let body = &code_section
.bodies()
.get(def_func_idx as usize)
.ok_or("Function body is out of bounds")?;
let mut locals_count: u32 = 0;
for local_group in body.locals() {
locals_count =
locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?;
}
let max_stack_height = max_height::compute(def_func_idx, ctx, module)?;
locals_count
.checked_add(max_stack_height)
.ok_or("Overflow in adding locals_count and max_stack_height")
}
/// Generate a new global that will be used for tracking current stack height. /// Generate a new global that will be used for tracking current stack height.
fn generate_stack_height_global(module: &mut elements::Module) -> u32 { fn generate_stack_height_global(
stack_height_global_idx: &mut u32,
module: &mut elements::Module,
) -> Result<(), &'static str> {
let global_entry = builder::global() let global_entry = builder::global()
.value_type() .value_type()
.i32() .i32()
@@ -140,71 +239,41 @@ fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
for section in module.sections_mut() { for section in module.sections_mut() {
if let elements::Section::Global(gs) = section { if let elements::Section::Global(gs) = section {
gs.entries_mut().push(global_entry); gs.entries_mut().push(global_entry);
return (gs.entries().len() as u32) - 1 *stack_height_global_idx = (gs.entries().len() as u32) - 1;
return Ok(());
} }
} }
// Existing section not found, create one! // Existing section not found, create one!
module //
.sections_mut() // It's a bit tricky since the sections have a strict prescribed order.
.push(elements::Section::Global(elements::GlobalSection::with_entries(vec![global_entry]))); let global_section = elements::GlobalSection::with_entries(vec![global_entry]);
0 let prec_index = module
} .sections()
.iter()
/// Calculate stack costs for all functions. .rposition(|section| {
/// use elements::Section::*;
/// Returns a vector with a stack cost for each function, including imports. match section {
fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, &'static str> { Type(_) | Import(_) | Function(_) | Table(_) | Memory(_) => true,
let func_imports = module.import_count(elements::ImportCountType::Function); _ => false,
// TODO: optimize!
(0..module.functions_space())
.map(|func_idx| {
if func_idx < func_imports {
// We can't calculate stack_cost of the import functions.
Ok(0)
} else {
compute_stack_cost(func_idx as u32, module)
} }
}) })
.collect() .ok_or("generate stack height global hasn't found any preceding sections")?;
// now `prec_index` points to the last section preceding the `global_section`. It's guaranteed that at least
// one of those functions is present. Therefore, the candidate position for the global section is the following
// one. However, technically, custom sections could occupy any place between the well-known sections.
//
// Now, regarding `+1` here. `insert` panics iff `index > len`. `prec_index + 1` can only be equal to `len`.
module
.sections_mut()
.insert(prec_index + 1, elements::Section::Global(global_section));
// First entry in the brand new globals section.
*stack_height_global_idx = 0;
Ok(())
} }
/// Stack cost of the given *defined* function is the sum of it's locals count (that is, fn instrument_functions(ctx: &Context, module: &mut elements::Module) -> Result<(), &'static str> {
/// number of arguments plus number of local variables) and the maximal stack
/// height.
fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, &'static str> {
// To calculate the cost of a function we need to convert index from
// function index space to defined function spaces.
let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
let defined_func_idx = func_idx
.checked_sub(func_imports)
.ok_or("This should be a index of a defined function")?;
let code_section =
module.code_section().ok_or("Due to validation code section should exists")?;
let body = &code_section
.bodies()
.get(defined_func_idx as usize)
.ok_or("Function body is out of bounds")?;
let mut locals_count: u32 = 0;
for local_group in body.locals() {
locals_count =
locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?;
}
let max_stack_height = max_height::compute(defined_func_idx, module)?;
locals_count
.checked_add(max_stack_height)
.ok_or("Overflow in adding locals_count and max_stack_height")
}
fn instrument_functions(
ctx: &mut Context,
module: &mut elements::Module,
) -> Result<(), &'static str> {
for section in module.sections_mut() { for section in module.sections_mut() {
if let elements::Section::Code(code_section) = section { if let elements::Section::Code(code_section) = section {
for func_body in code_section.bodies_mut() { for func_body in code_section.bodies_mut() {
@@ -242,7 +311,7 @@ fn instrument_functions(
/// ///
/// drop /// drop
/// ``` /// ```
fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), &'static str> { fn instrument_function(ctx: &Context, func: &mut Instructions) -> Result<(), &'static str> {
use Instruction::*; use Instruction::*;
struct InstrumentCall { struct InstrumentCall {
@@ -309,42 +378,6 @@ fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(),
Ok(()) Ok(())
} }
fn resolve_func_type(
func_idx: u32,
module: &elements::Module,
) -> Result<&elements::FunctionType, &'static str> {
let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
let func_imports = module.import_count(elements::ImportCountType::Function);
let sig_idx = if func_idx < func_imports as u32 {
module
.import_section()
.expect("function import count is not zero; import section must exists; qed")
.entries()
.iter()
.filter_map(|entry| match entry.external() {
elements::External::Function(idx) => Some(*idx),
_ => None,
})
.nth(func_idx as usize)
.expect(
"func_idx is less than function imports count;
nth function import must be `Some`;
qed",
)
} else {
functions
.get(func_idx as usize - func_imports)
.ok_or("Function at the specified index is not defined")?
.type_ref()
};
let Type::Function(ty) = types
.get(sig_idx as usize)
.ok_or("The signature as specified by a function isn't defined")?;
Ok(ty)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
+58 -43
View File
@@ -1,17 +1,17 @@
#[cfg(not(features = "std"))] #[cfg(not(features = "std"))]
use alloc::collections::BTreeMap as Map; use alloc::collections::BTreeMap as Map;
use alloc::vec::Vec; use alloc::vec::Vec;
use parity_wasm::{ use parity_wasm::elements::{self, Internal};
builder,
elements::{self, FunctionType, Internal},
};
#[cfg(features = "std")] #[cfg(features = "std")]
use std::collections::HashMap as Map; use std::collections::HashMap as Map;
use super::{resolve_func_type, Context}; use super::Context;
struct Thunk { struct Thunk {
signature: FunctionType, /// The index of the signature in the type section.
type_idx: u32,
/// The number of parameters the function has.
param_num: u32,
// Index in function space of this thunk. // Index in function space of this thunk.
idx: Option<u32>, idx: Option<u32>,
callee_stack_cost: u32, callee_stack_cost: u32,
@@ -19,10 +19,11 @@ struct Thunk {
pub fn generate_thunks( pub fn generate_thunks(
ctx: &mut Context, ctx: &mut Context,
module: elements::Module, mut module: elements::Module,
) -> Result<elements::Module, &'static str> { ) -> Result<elements::Module, &'static str> {
// First, we need to collect all function indices that should be replaced by thunks // First, we need to collect all function indices that should be replaced by thunks
let mut replacement_map: Map<u32, Thunk> = { let mut replacement_map: Map<u32, Thunk> = {
let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
let exports = module.export_section().map(|es| es.entries()).unwrap_or(&[]); let exports = module.export_section().map(|es| es.entries()).unwrap_or(&[]);
let elem_segments = module.elements_section().map(|es| es.entries()).unwrap_or(&[]); let elem_segments = module.elements_section().map(|es| es.entries()).unwrap_or(&[]);
let start_func_idx = module.start_section(); let start_func_idx = module.start_section();
@@ -45,14 +46,12 @@ pub fn generate_thunks(
// Don't generate a thunk if stack_cost of a callee is zero. // Don't generate a thunk if stack_cost of a callee is zero.
if callee_stack_cost != 0 { if callee_stack_cost != 0 {
replacement_map.insert( let type_idx = ctx.func_type(func_idx).ok_or("type idx for thunk not found")?;
func_idx, let elements::Type::Function(func_ty) =
Thunk { types.get(type_idx as usize).ok_or("sig for thunk is not found")?;
signature: resolve_func_type(func_idx, &module)?.clone(), let param_num = func_ty.params().len() as u32;
idx: None, replacement_map
callee_stack_cost, .insert(func_idx, Thunk { type_idx, param_num, idx: None, callee_stack_cost });
},
);
} }
} }
@@ -61,10 +60,6 @@ pub fn generate_thunks(
// Then, we generate a thunk for each original function. // Then, we generate a thunk for each original function.
// Save current func_idx
let mut next_func_idx = module.functions_space() as u32;
let mut mbuilder = builder::from_module(module);
for (func_idx, thunk) in replacement_map.iter_mut() { for (func_idx, thunk) in replacement_map.iter_mut() {
let instrumented_call = instrument_call!( let instrumented_call = instrument_call!(
*func_idx, *func_idx,
@@ -77,32 +72,23 @@ pub fn generate_thunks(
// - instrumented call // - instrumented call
// - end // - end
let mut thunk_body: Vec<elements::Instruction> = let mut thunk_body: Vec<elements::Instruction> =
Vec::with_capacity(thunk.signature.params().len() + instrumented_call.len() + 1); Vec::with_capacity(thunk.param_num as usize + instrumented_call.len() + 1);
for (arg_idx, _) in thunk.signature.params().iter().enumerate() { for arg_idx in 0..thunk.param_num {
thunk_body.push(elements::Instruction::GetLocal(arg_idx as u32)); thunk_body.push(elements::Instruction::GetLocal(arg_idx));
} }
thunk_body.extend_from_slice(&instrumented_call); thunk_body.extend_from_slice(&instrumented_call);
thunk_body.push(elements::Instruction::End); thunk_body.push(elements::Instruction::End);
// TODO: Don't generate a signature, but find an existing one. let func_idx = insert_function(
ctx,
mbuilder = mbuilder &mut module,
.function() thunk.type_idx,
// Signature of the thunk should match the original function signature. Vec::new(), // No declared local variables.
.signature() elements::Instructions::new(thunk_body),
.with_params(thunk.signature.params().to_vec()) )?;
.with_results(thunk.signature.results().to_vec()) thunk.idx = Some(func_idx);
.build()
.body()
.with_instructions(elements::Instructions::new(thunk_body))
.build()
.build();
thunk.idx = Some(next_func_idx);
next_func_idx += 1;
} }
let mut module = mbuilder.build();
// And finally, fixup thunks in export and table sections. // And finally, fixup thunks in export and table sections.
@@ -118,18 +104,20 @@ pub fn generate_thunks(
for section in module.sections_mut() { for section in module.sections_mut() {
match section { match section {
elements::Section::Export(export_section) => elements::Section::Export(export_section) => {
for entry in export_section.entries_mut() { for entry in export_section.entries_mut() {
if let Internal::Function(function_idx) = entry.internal_mut() { if let Internal::Function(function_idx) = entry.internal_mut() {
fixup(function_idx) fixup(function_idx)
} }
}, }
elements::Section::Element(elem_section) => },
elements::Section::Element(elem_section) => {
for segment in elem_section.entries_mut() { for segment in elem_section.entries_mut() {
for function_idx in segment.members_mut() { for function_idx in segment.members_mut() {
fixup(function_idx) fixup(function_idx)
} }
}, }
},
elements::Section::Start(start_idx) => fixup(start_idx), elements::Section::Start(start_idx) => fixup(start_idx),
_ => {}, _ => {},
} }
@@ -137,3 +125,30 @@ pub fn generate_thunks(
Ok(module) Ok(module)
} }
/// Inserts a new function into the module and returns it's index in the function space.
///
/// Specifically, inserts entires into the function section and the code section.
fn insert_function(
ctx: &Context,
module: &mut elements::Module,
type_idx: u32,
locals: Vec<elements::Local>,
insns: elements::Instructions,
) -> Result<u32, &'static str> {
let funcs = module
.function_section_mut()
.ok_or("insert function no function section")?
.entries_mut();
let new_func_idx = ctx
.func_imports
.checked_add(funcs.len() as u32)
.ok_or("insert function func idx overflow")?;
funcs.push(elements::Func::new(type_idx));
let func_bodies =
module.code_section_mut().ok_or("insert function no code section")?.bodies_mut();
func_bodies.push(elements::FuncBody::new(locals, insns));
Ok(new_func_idx)
}