Improve mock_impl_runtime_apis! (#7370)

* Improve `mock_impl_runtime_apis!`

This adds a new attribute for functions being implemented in the
`mock_impl_runtime_apis!` macro, the `advanced` attribute. When this
attribute is given the user gets access to the `at` parameter and is
able to return a `Result`, instead of letting the macro generate this
stuff.

* Use the `at_param_name` directly

* Prevent clashing of `params`
This commit is contained in:
Bastian Köcher
2020-10-22 21:09:13 +02:00
committed by GitHub
parent d847c9b019
commit 79be077774
9 changed files with 287 additions and 46 deletions
@@ -24,16 +24,22 @@ use crate::utils::{
use proc_macro2::{Span, TokenStream};
use quote::quote;
use quote::{quote, quote_spanned};
use syn::{
spanned::Spanned, parse_macro_input, Ident, Type, ItemImpl, ImplItem, TypePath, parse_quote,
parse::{Parse, ParseStream, Result, Error}, fold::{self, Fold},
parse::{Parse, ParseStream, Result, Error}, fold::{self, Fold}, Attribute, Pat,
};
/// Unique identifier used to make the hidden includes unique for this macro.
const HIDDEN_INCLUDES_ID: &str = "MOCK_IMPL_RUNTIME_APIS";
/// The `advanced` attribute.
///
/// If this attribute is given to a function, the function gets access to the `BlockId` as first
/// parameter and needs to return a `Result` with the appropiate error type.
const ADVANCED_ATTRIBUTE: &str = "advanced";
/// The structure used for parsing the runtime api implementations.
struct RuntimeApiImpls {
impls: Vec<ItemImpl>,
@@ -65,10 +71,16 @@ fn implement_common_api_traits(
let error_type = error_type.map(|e| quote!(#e)).unwrap_or_else(|| quote!(String));
Ok(quote!(
// Quote using the span from `error_type` to generate nice error messages when the type is
// not implementing a trait or similar.
let api_error_ext = quote_spanned! { error_type.span() =>
impl #crate_::ApiErrorExt for #self_ty {
type Error = #error_type;
}
};
Ok(quote!(
#api_error_ext
impl #crate_::ApiExt<#block_type> for #self_ty {
type StateBackend = #crate_::InMemoryBackend<#crate_::HashFor<#block_type>>;
@@ -89,7 +101,7 @@ fn implement_common_api_traits(
fn has_api_with<A: #crate_::RuntimeApiInfo + ?Sized, P: Fn(u32) -> bool>(
&self,
at: &#crate_::BlockId<#block_type>,
_: &#crate_::BlockId<#block_type>,
pred: P,
) -> std::result::Result<bool, #error_type> where Self: Sized {
Ok(pred(A::VERSION))
@@ -153,6 +165,61 @@ fn implement_common_api_traits(
))
}
/// Returns if the advanced attribute is present in the given `attributes`.
///
/// If the attribute was found, it will be automatically removed from the vec.
fn has_advanced_attribute(attributes: &mut Vec<Attribute>) -> bool {
let mut found = false;
attributes.retain(|attr| if attr.path.is_ident(ADVANCED_ATTRIBUTE) {
found = true;
false
} else {
true
});
found
}
/// Get the name and type of the `at` parameter that is passed to a runtime api function.
///
/// If `is_advanced` is `false`, the name is `_`.
fn get_at_param_name(
is_advanced: bool,
param_names: &mut Vec<Pat>,
param_types_and_borrows: &mut Vec<(TokenStream, bool)>,
function_span: Span,
default_block_id_type: &TokenStream,
) -> Result<(TokenStream, TokenStream)> {
if is_advanced {
if param_names.is_empty() {
return Err(Error::new(
function_span,
format!(
"If using the `{}` attribute, it is required that the function \
takes at least one argument, the `BlockId`.",
ADVANCED_ATTRIBUTE,
),
))
}
// `param_names` and `param_types` have the same length, so if `param_names` is not empty
// `param_types` can not be empty as well.
let ptype_and_borrows = param_types_and_borrows.remove(0);
let span = ptype_and_borrows.1.span();
if !ptype_and_borrows.1 {
return Err(Error::new(
span,
"`BlockId` needs to be taken by reference and not by value!",
))
}
let name = param_names.remove(0);
Ok((quote!( #name ), ptype_and_borrows.0))
} else {
Ok((quote!( _ ), default_block_id_type.clone()))
}
}
/// Auxialiry structure to fold a runtime api trait implementation into the expected format.
///
/// This renames the methods, changes the method parameters and extracts the error type.
@@ -170,8 +237,10 @@ impl<'a> Fold for FoldRuntimeApiImpl<'a> {
fn fold_impl_item_method(&mut self, mut input: syn::ImplItemMethod) -> syn::ImplItemMethod {
let block = {
let crate_ = generate_crate_access(HIDDEN_INCLUDES_ID);
let is_advanced = has_advanced_attribute(&mut input.attrs);
let mut errors = Vec::new();
let (param_names, param_types, error) = match extract_parameter_names_types_and_borrows(
let (mut param_names, mut param_types_and_borrows) = match extract_parameter_names_types_and_borrows(
&input.sig,
AllowSelfRefInParameters::YesButIgnore,
) {
@@ -180,21 +249,40 @@ impl<'a> Fold for FoldRuntimeApiImpl<'a> {
res.iter().map(|v| {
let ty = &v.1;
let borrow = &v.2;
quote!( #borrow #ty )
(quote_spanned!(ty.span() => #borrow #ty ), v.2.is_some())
}).collect::<Vec<_>>(),
None
),
Err(e) => (Vec::new(), Vec::new(), Some(e.to_compile_error())),
Err(e) => {
errors.push(e.to_compile_error());
(Default::default(), Default::default())
}
};
let block_type = &self.block_type;
let block_id_type = quote!( &#crate_::BlockId<#block_type> );
let (at_param_name, block_id_type) = match get_at_param_name(
is_advanced,
&mut param_names,
&mut param_types_and_borrows,
input.span(),
&block_id_type,
) {
Ok(res) => res,
Err(e) => {
errors.push(e.to_compile_error());
(quote!( _ ), block_id_type)
}
};
let param_types = param_types_and_borrows.iter().map(|v| &v.0);
// Rewrite the input parameters.
input.sig.inputs = parse_quote! {
&self,
_: &#crate_::BlockId<#block_type>,
#at_param_name: #block_id_type,
_: #crate_::ExecutionContext,
params: Option<( #( #param_types ),* )>,
___params___sp___api___: Option<( #( #param_types ),* )>,
_: Vec<u8>,
};
@@ -202,27 +290,40 @@ impl<'a> Fold for FoldRuntimeApiImpl<'a> {
&self.impl_trait,
&input.sig.ident,
);
let ret_type = return_type_extract_type(&input.sig.output);
// Generate the correct return type.
input.sig.output = parse_quote!(
-> std::result::Result<#crate_::NativeOrEncoded<#ret_type>, Self::Error>
);
// When using advanced, the user needs to declare the correct return type on its own,
// otherwise do it for the user.
if !is_advanced {
let ret_type = return_type_extract_type(&input.sig.output);
// Generate the correct return type.
input.sig.output = parse_quote!(
-> std::result::Result<#crate_::NativeOrEncoded<#ret_type>, Self::Error>
);
}
let orig_block = input.block.clone();
let construct_return_value = if is_advanced {
quote!( (move || #orig_block)() )
} else {
quote! {
let __fn_implementation__ = move || #orig_block;
Ok(#crate_::NativeOrEncoded::Native(__fn_implementation__()))
}
};
// Generate the new method implementation that calls into the runtime.
parse_quote!(
{
// Get the error to the user (if we have one).
#error
#( #errors )*
let (#( #param_names ),*) = params
let (#( #param_names ),*) = ___params___sp___api___
.expect("Mocked runtime apis don't support calling deprecated api versions");
let __fn_implementation__ = move || #orig_block;
Ok(#crate_::NativeOrEncoded::Native(__fn_implementation__()))
#construct_return_value
}
)
};
@@ -240,10 +341,17 @@ impl<'a> Fold for FoldRuntimeApiImpl<'a> {
if ty.ident == "Error" {
if let Some(error_type) = self.error_type {
if *error_type != ty.ty {
let error = Error::new(
let mut error = Error::new(
ty.span(),
"Error type can not change between runtime apis",
);
let error_first = Error::new(
error_type.span(),
"First error type was declared here."
);
error.combine(error_first);
ImplItem::Verbatim(error.to_compile_error())
} else {
ImplItem::Verbatim(Default::default())