generation of real benchmark functions for benchmarking v2 (#13224)

* function generation with _name working, need to modify signature

* WIP

* support custom BenchmarkResult<T> type

* full support for BenchmarkResult<T> on benchmark function defs

* support () return type for benchmark function defs that don't use ?

* uncomment

* fix where clause handling

* fix benchmark function call bodies

* proper parsing of return type

* add UI tests for bad return type

* fix detection of missing last_stmt with defined return type

* UI tests covering missing last_stmt

* properly detect and complain about empty benchmark function defs

* fix missing Comma in Result<T, BenchmarkError> parsing + test

* add additional UI test

* allow complex path for BenchmarkResult and BenchmarkError in fn defs

* add UI tests covering complex path for BenchmarkResult, BenchmarkError

* retain doc comments and attributes

* also add attributes to struct

* add docs for benchmark function definition support

* fix imports on benchmark example

* fix issue with unused variables in extrinsic call fn def

* fix up docs

* remove support for v2::BenchmarkResult because it was confusing

* fix typo

* remove ability to use custom T for Result<T, BenchmarkError> in v2

* use missing call error instead of empty_fn()

* remove unneeded match statement

* Add a proper QED

Co-authored-by: Keith Yeung <kungfukeith11@gmail.com>

* fix other QED

Co-authored-by: Keith Yeung <kungfukeith11@gmail.com>

* cargo fmt

* add an explicit error for non TypePath as return type

* tweak error warning and add a UI test for non TypePath return

* remove comment

* add docs about T and I generic params

* improve docs referring to section "below"

* pull out return type checking logic into its own function

* pull out params parsing into its own function

* pull out call_def parsing into its own function

* add doc comment for missing_call()

* replace spaces with tabs

* add a result-based example to the benchmarking examples

---------

Co-authored-by: Keith Yeung <kungfukeith11@gmail.com>
This commit is contained in:
Sam Johnson
2023-02-22 09:09:11 -05:00
committed by GitHub
parent af25310eb0
commit 55ff791d80
26 changed files with 604 additions and 98 deletions
@@ -25,11 +25,13 @@ use quote::{quote, quote_spanned, ToTokens};
use syn::{
parenthesized,
parse::{Nothing, ParseStream},
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
token::{Colon2, Comma, Gt, Lt, Paren},
Attribute, Error, Expr, ExprBlock, ExprCall, ExprPath, FnArg, Item, ItemFn, ItemMod, LitInt,
Pat, Path, PathArguments, PathSegment, Result, Stmt, Token, Type, WhereClause,
Pat, Path, PathArguments, PathSegment, Result, ReturnType, Signature, Stmt, Token, Type,
TypePath, Visibility, WhereClause,
};
mod keywords {
@@ -41,6 +43,8 @@ mod keywords {
custom_keyword!(extra);
custom_keyword!(extrinsic_call);
custom_keyword!(skip_meta);
custom_keyword!(BenchmarkError);
custom_keyword!(Result);
}
/// This represents the raw parsed data for a param definition such as `x: Linear<10, 20>`.
@@ -145,62 +149,121 @@ struct BenchmarkDef {
setup_stmts: Vec<Stmt>,
call_def: BenchmarkCallDef,
verify_stmts: Vec<Stmt>,
last_stmt: Option<Stmt>,
extra: bool,
skip_meta: bool,
fn_sig: Signature,
fn_vis: Visibility,
fn_attrs: Vec<Attribute>,
}
impl BenchmarkDef {
/// Constructs a [`BenchmarkDef`] by traversing an existing [`ItemFn`] node.
pub fn from(item_fn: &ItemFn, extra: bool, skip_meta: bool) -> Result<BenchmarkDef> {
let mut params: Vec<ParamDef> = Vec::new();
/// used to parse something compatible with `Result<T, E>`
#[derive(Parse)]
struct ResultDef {
_result_kw: keywords::Result,
_lt: Token![<],
unit: Type,
_comma: Comma,
e_type: TypePath,
_gt: Token![>],
}
// parse params such as "x: Linear<0, 1>"
for arg in &item_fn.sig.inputs {
let invalid_param = |span| {
return Err(Error::new(span, "Invalid benchmark function param. A valid example would be `x: Linear<5, 10>`.", ))
};
let FnArg::Typed(arg) = arg else { return invalid_param(arg.span()) };
let Pat::Ident(ident) = &*arg.pat else { return invalid_param(arg.span()) };
// check param name
let var_span = ident.span();
let invalid_param_name = || {
/// Ensures that `ReturnType` is a `Result<(), BenchmarkError>`, if specified
fn ensure_valid_return_type(item_fn: &ItemFn) -> Result<()> {
if let ReturnType::Type(_, typ) = &item_fn.sig.output {
let non_unit = |span| return Err(Error::new(span, "expected `()`"));
let Type::Path(TypePath { path, qself: _ }) = &**typ else {
return Err(Error::new(
typ.span(),
"Only `Result<(), BenchmarkError>` or a blank return type is allowed on benchmark function definitions",
))
};
let seg = path
.segments
.last()
.expect("to be parsed as a TypePath, it must have at least one segment; qed");
let res: ResultDef = syn::parse2(seg.to_token_stream())?;
// ensure T in Result<T, E> is ()
let Type::Tuple(tup) = res.unit else { return non_unit(res.unit.span()) };
if !tup.elems.is_empty() {
return non_unit(tup.span())
}
let TypePath { path, qself: _ } = res.e_type;
let seg = path
.segments
.last()
.expect("to be parsed as a TypePath, it must have at least one segment; qed");
syn::parse2::<keywords::BenchmarkError>(seg.to_token_stream())?;
}
Ok(())
}
/// Parses params such as `x: Linear<0, 1>`
fn parse_params(item_fn: &ItemFn) -> Result<Vec<ParamDef>> {
let mut params: Vec<ParamDef> = Vec::new();
for arg in &item_fn.sig.inputs {
let invalid_param = |span| {
return Err(Error::new(
span,
"Invalid benchmark function param. A valid example would be `x: Linear<5, 10>`.",
))
};
let FnArg::Typed(arg) = arg else { return invalid_param(arg.span()) };
let Pat::Ident(ident) = &*arg.pat else { return invalid_param(arg.span()) };
// check param name
let var_span = ident.span();
let invalid_param_name = || {
return Err(Error::new(
var_span,
"Benchmark parameter names must consist of a single lowercase letter (a-z) and no other characters.",
))
};
let name = ident.ident.to_token_stream().to_string();
if name.len() > 1 {
return invalid_param_name()
};
let Some(name_char) = name.chars().next() else { return invalid_param_name() };
if !name_char.is_alphabetic() || !name_char.is_lowercase() {
return invalid_param_name()
}
// parse type
let typ = &*arg.ty;
let Type::Path(tpath) = typ else { return invalid_param(typ.span()) };
let Some(segment) = tpath.path.segments.last() else { return invalid_param(typ.span()) };
let args = segment.arguments.to_token_stream().into();
let Ok(args) = syn::parse::<RangeArgs>(args) else { return invalid_param(typ.span()) };
let Ok(start) = args.start.base10_parse::<u32>() else { return invalid_param(args.start.span()) };
let Ok(end) = args.end.base10_parse::<u32>() else { return invalid_param(args.end.span()) };
if end < start {
return Err(Error::new(
args.start.span(),
"The start of a `ParamRange` must be less than or equal to the end",
))
}
params.push(ParamDef { name, typ: typ.clone(), start, end });
};
let name = ident.ident.to_token_stream().to_string();
if name.len() > 1 {
return invalid_param_name()
};
let Some(name_char) = name.chars().next() else { return invalid_param_name() };
if !name_char.is_alphabetic() || !name_char.is_lowercase() {
return invalid_param_name()
}
// #[extrinsic_call] / #[block] handling
let call_defs = item_fn.block.stmts.iter().enumerate().filter_map(|(i, child)| {
// parse type
let typ = &*arg.ty;
let Type::Path(tpath) = typ else { return invalid_param(typ.span()) };
let Some(segment) = tpath.path.segments.last() else { return invalid_param(typ.span()) };
let args = segment.arguments.to_token_stream().into();
let Ok(args) = syn::parse::<RangeArgs>(args) else { return invalid_param(typ.span()) };
let Ok(start) = args.start.base10_parse::<u32>() else { return invalid_param(args.start.span()) };
let Ok(end) = args.end.base10_parse::<u32>() else { return invalid_param(args.end.span()) };
if end < start {
return Err(Error::new(
args.start.span(),
"The start of a `ParamRange` must be less than or equal to the end",
))
}
params.push(ParamDef { name, typ: typ.clone(), start, end });
}
Ok(params)
}
/// Used in several places where the `#[extrinsic_call]` or `#[body]` annotation is missing
fn missing_call<T>(item_fn: &ItemFn) -> Result<T> {
return Err(Error::new(
item_fn.block.brace_token.span,
"No valid #[extrinsic_call] or #[block] annotation could be found in benchmark function body."
))
}
/// Finds the `BenchmarkCallDef` and its index (within the list of stmts for the fn) and
/// returns them. Also handles parsing errors for invalid / extra call defs. AKA this is
/// general handling for `#[extrinsic_call]` and `#[block]`
fn parse_call_def(item_fn: &ItemFn) -> Result<(usize, BenchmarkCallDef)> {
// #[extrinsic_call] / #[block] handling
let call_defs = item_fn.block.stmts.iter().enumerate().filter_map(|(i, child)| {
if let Stmt::Semi(Expr::Call(expr_call), _semi) = child {
// #[extrinsic_call] case
expr_call.attrs.iter().enumerate().find_map(|(k, attr)| {
@@ -234,25 +297,60 @@ impl BenchmarkDef {
None
}
}).collect::<Result<Vec<_>>>()?;
let (i, call_def) = match &call_defs[..] {
[(i, call_def)] => (*i, call_def.clone()), // = 1
[] => return Err(Error::new( // = 0
item_fn.block.brace_token.span,
"No valid #[extrinsic_call] or #[block] annotation could be found in benchmark function body."
)),
_ => return Err(Error::new( // > 1
Ok(match &call_defs[..] {
[(i, call_def)] => (*i, call_def.clone()), // = 1
[] => return missing_call(item_fn),
_ =>
return Err(Error::new(
call_defs[1].1.attr_span(),
"Only one #[extrinsic_call] or #[block] attribute is allowed per benchmark."
"Only one #[extrinsic_call] or #[block] attribute is allowed per benchmark.",
)),
})
}
impl BenchmarkDef {
/// Constructs a [`BenchmarkDef`] by traversing an existing [`ItemFn`] node.
pub fn from(item_fn: &ItemFn, extra: bool, skip_meta: bool) -> Result<BenchmarkDef> {
let params = parse_params(item_fn)?;
ensure_valid_return_type(item_fn)?;
let (i, call_def) = parse_call_def(&item_fn)?;
let (verify_stmts, last_stmt) = match item_fn.sig.output {
ReturnType::Default =>
// no return type, last_stmt should be None
(Vec::from(&item_fn.block.stmts[(i + 1)..item_fn.block.stmts.len()]), None),
ReturnType::Type(_, _) => {
// defined return type, last_stmt should be Result<(), BenchmarkError>
// compatible and should not be included in verify_stmts
if i + 1 >= item_fn.block.stmts.len() {
return Err(Error::new(
item_fn.block.span(),
"Benchmark `#[block]` or `#[extrinsic_call]` item cannot be the \
last statement of your benchmark function definition if you have \
defined a return type. You should return something compatible \
with Result<(), BenchmarkError> (i.e. `Ok(())`) as the last statement \
or change your signature to a blank return type.",
))
}
let Some(stmt) = item_fn.block.stmts.last() else { return missing_call(item_fn) };
(
Vec::from(&item_fn.block.stmts[(i + 1)..item_fn.block.stmts.len() - 1]),
Some(stmt.clone()),
)
},
};
Ok(BenchmarkDef {
params,
setup_stmts: Vec::from(&item_fn.block.stmts[0..i]),
call_def,
verify_stmts: Vec::from(&item_fn.block.stmts[(i + 1)..item_fn.block.stmts.len()]),
verify_stmts,
last_stmt,
extra,
skip_meta,
fn_sig: item_fn.sig.clone(),
fn_vis: item_fn.vis.clone(),
fn_attrs: item_fn.attrs.clone(),
})
}
}
@@ -643,6 +741,7 @@ fn expand_benchmark(
let traits = quote!(#krate::frame_support::traits);
let setup_stmts = benchmark_def.setup_stmts;
let verify_stmts = benchmark_def.verify_stmts;
let last_stmt = benchmark_def.last_stmt;
let test_ident = Ident::new(format!("test_{}", name.to_string()).as_str(), Span::call_site());
// unroll params (prepare for quoting)
@@ -661,7 +760,8 @@ fn expand_benchmark(
true => quote!(T: Config<I>, I: 'static),
};
let (pre_call, post_call) = match benchmark_def.call_def {
// used in the benchmarking impls
let (pre_call, post_call, fn_call_body) = match &benchmark_def.call_def {
BenchmarkCallDef::ExtrinsicCall { origin, expr_call, attr_span: _ } => {
let mut expr_call = expr_call.clone();
@@ -705,36 +805,97 @@ fn expand_benchmark(
qself: None,
path: Path { leading_colon: None, segments: punct },
});
let pre_call = quote! {
let __call = Call::<#type_use_generics>::#expr_call;
let __benchmarked_call_encoded = #codec::Encode::encode(&__call);
};
let post_call = quote! {
let __call_decoded = <Call<#type_use_generics> as #codec::Decode>
::decode(&mut &__benchmarked_call_encoded[..])
.expect("call is encoded above, encoding must be correct");
let __origin = #origin.into();
<Call<#type_use_generics> as #traits::UnfilteredDispatchable>::dispatch_bypass_filter(
__call_decoded,
__origin,
)
};
(
// (pre_call, post_call):
// (pre_call, post_call, fn_call_body):
pre_call.clone(),
quote!(#post_call?;),
quote! {
let __call = Call::<#type_use_generics>::#expr_call;
let __benchmarked_call_encoded = #codec::Encode::encode(&__call);
},
quote! {
let __call_decoded = <Call<#type_use_generics> as #codec::Decode>
::decode(&mut &__benchmarked_call_encoded[..])
.expect("call is encoded above, encoding must be correct");
let __origin = #origin.into();
<Call<#type_use_generics> as #traits::UnfilteredDispatchable>::dispatch_bypass_filter(
__call_decoded,
__origin,
)?;
#pre_call
#post_call.unwrap();
},
)
},
BenchmarkCallDef::Block { block, attr_span: _ } => (quote!(), quote!(#block)),
BenchmarkCallDef::Block { block, attr_span: _ } =>
(quote!(), quote!(#block), quote!(#block)),
};
let vis = benchmark_def.fn_vis;
// remove #[benchmark] attribute
let fn_attrs: Vec<&Attribute> = benchmark_def
.fn_attrs
.iter()
.filter(|attr| !syn::parse2::<keywords::benchmark>(attr.path.to_token_stream()).is_ok())
.collect();
// modify signature generics, ident, and inputs, e.g:
// before: `fn bench(u: Linear<1, 100>) -> Result<(), BenchmarkError>`
// after: `fn _bench <T: Config<I>, I: 'static>(u: u32, verify: bool) -> Result<(),
// BenchmarkError>`
let mut sig = benchmark_def.fn_sig;
sig.generics = parse_quote!(<#type_impl_generics>);
if !where_clause.is_empty() {
sig.generics.where_clause = parse_quote!(where #where_clause);
}
sig.ident =
Ident::new(format!("_{}", name.to_token_stream().to_string()).as_str(), Span::call_site());
let mut fn_param_inputs: Vec<TokenStream2> =
param_names.iter().map(|name| quote!(#name: u32)).collect();
fn_param_inputs.push(quote!(verify: bool));
sig.inputs = parse_quote!(#(#fn_param_inputs),*);
// used in instance() impl
let impl_last_stmt = match &last_stmt {
Some(stmt) => quote!(#stmt),
None => quote!(Ok(())),
};
let fn_def = quote! {
#(
#fn_attrs
)*
#vis #sig {
#(
#setup_stmts
)*
#fn_call_body
if verify {
#(
#verify_stmts
)*
}
#last_stmt
}
};
// generate final quoted tokens
let res = quote! {
// benchmark function definition
#fn_def
// compile-time assertions that each referenced param type implements ParamRange
#(
#home::assert_impl_all!(#param_types: #home::ParamRange);
)*
#[allow(non_camel_case_types)]
#(
#fn_attrs
)*
struct #name;
#[allow(unused_variables)]
@@ -773,7 +934,7 @@ fn expand_benchmark(
#verify_stmts
)*
}
Ok(())
#impl_last_stmt
}))
}
}