diff --git a/substrate/primitives/api/proc-macro/src/mock_impl_runtime_apis.rs b/substrate/primitives/api/proc-macro/src/mock_impl_runtime_apis.rs index 0e8f18e3e6..3e2fd42951 100644 --- a/substrate/primitives/api/proc-macro/src/mock_impl_runtime_apis.rs +++ b/substrate/primitives/api/proc-macro/src/mock_impl_runtime_apis.rs @@ -24,16 +24,22 @@ use crate::utils::{ use proc_macro2::{Span, TokenStream}; -use quote::quote; +use quote::{quote, quote_spanned}; use syn::{ spanned::Spanned, parse_macro_input, Ident, Type, ItemImpl, ImplItem, TypePath, parse_quote, - parse::{Parse, ParseStream, Result, Error}, fold::{self, Fold}, + parse::{Parse, ParseStream, Result, Error}, fold::{self, Fold}, Attribute, Pat, }; /// Unique identifier used to make the hidden includes unique for this macro. const HIDDEN_INCLUDES_ID: &str = "MOCK_IMPL_RUNTIME_APIS"; +/// The `advanced` attribute. +/// +/// If this attribute is given to a function, the function gets access to the `BlockId` as first +/// parameter and needs to return a `Result` with the appropiate error type. +const ADVANCED_ATTRIBUTE: &str = "advanced"; + /// The structure used for parsing the runtime api implementations. struct RuntimeApiImpls { impls: Vec, @@ -65,10 +71,16 @@ fn implement_common_api_traits( let error_type = error_type.map(|e| quote!(#e)).unwrap_or_else(|| quote!(String)); - Ok(quote!( + // Quote using the span from `error_type` to generate nice error messages when the type is + // not implementing a trait or similar. + let api_error_ext = quote_spanned! { error_type.span() => impl #crate_::ApiErrorExt for #self_ty { type Error = #error_type; } + }; + + Ok(quote!( + #api_error_ext impl #crate_::ApiExt<#block_type> for #self_ty { type StateBackend = #crate_::InMemoryBackend<#crate_::HashFor<#block_type>>; @@ -89,7 +101,7 @@ fn implement_common_api_traits( fn has_api_with bool>( &self, - at: &#crate_::BlockId<#block_type>, + _: &#crate_::BlockId<#block_type>, pred: P, ) -> std::result::Result where Self: Sized { Ok(pred(A::VERSION)) @@ -153,6 +165,61 @@ fn implement_common_api_traits( )) } +/// Returns if the advanced attribute is present in the given `attributes`. +/// +/// If the attribute was found, it will be automatically removed from the vec. +fn has_advanced_attribute(attributes: &mut Vec) -> bool { + let mut found = false; + attributes.retain(|attr| if attr.path.is_ident(ADVANCED_ATTRIBUTE) { + found = true; + false + } else { + true + }); + + found +} + +/// Get the name and type of the `at` parameter that is passed to a runtime api function. +/// +/// If `is_advanced` is `false`, the name is `_`. +fn get_at_param_name( + is_advanced: bool, + param_names: &mut Vec, + param_types_and_borrows: &mut Vec<(TokenStream, bool)>, + function_span: Span, + default_block_id_type: &TokenStream, +) -> Result<(TokenStream, TokenStream)> { + if is_advanced { + if param_names.is_empty() { + return Err(Error::new( + function_span, + format!( + "If using the `{}` attribute, it is required that the function \ + takes at least one argument, the `BlockId`.", + ADVANCED_ATTRIBUTE, + ), + )) + } + + // `param_names` and `param_types` have the same length, so if `param_names` is not empty + // `param_types` can not be empty as well. + let ptype_and_borrows = param_types_and_borrows.remove(0); + let span = ptype_and_borrows.1.span(); + if !ptype_and_borrows.1 { + return Err(Error::new( + span, + "`BlockId` needs to be taken by reference and not by value!", + )) + } + + let name = param_names.remove(0); + Ok((quote!( #name ), ptype_and_borrows.0)) + } else { + Ok((quote!( _ ), default_block_id_type.clone())) + } +} + /// Auxialiry structure to fold a runtime api trait implementation into the expected format. /// /// This renames the methods, changes the method parameters and extracts the error type. @@ -170,8 +237,10 @@ impl<'a> Fold for FoldRuntimeApiImpl<'a> { fn fold_impl_item_method(&mut self, mut input: syn::ImplItemMethod) -> syn::ImplItemMethod { let block = { let crate_ = generate_crate_access(HIDDEN_INCLUDES_ID); + let is_advanced = has_advanced_attribute(&mut input.attrs); + let mut errors = Vec::new(); - let (param_names, param_types, error) = match extract_parameter_names_types_and_borrows( + let (mut param_names, mut param_types_and_borrows) = match extract_parameter_names_types_and_borrows( &input.sig, AllowSelfRefInParameters::YesButIgnore, ) { @@ -180,21 +249,40 @@ impl<'a> Fold for FoldRuntimeApiImpl<'a> { res.iter().map(|v| { let ty = &v.1; let borrow = &v.2; - quote!( #borrow #ty ) + (quote_spanned!(ty.span() => #borrow #ty ), v.2.is_some()) }).collect::>(), - None ), - Err(e) => (Vec::new(), Vec::new(), Some(e.to_compile_error())), + Err(e) => { + errors.push(e.to_compile_error()); + + (Default::default(), Default::default()) + } }; let block_type = &self.block_type; + let block_id_type = quote!( &#crate_::BlockId<#block_type> ); + let (at_param_name, block_id_type) = match get_at_param_name( + is_advanced, + &mut param_names, + &mut param_types_and_borrows, + input.span(), + &block_id_type, + ) { + Ok(res) => res, + Err(e) => { + errors.push(e.to_compile_error()); + (quote!( _ ), block_id_type) + } + }; + + let param_types = param_types_and_borrows.iter().map(|v| &v.0); // Rewrite the input parameters. input.sig.inputs = parse_quote! { &self, - _: &#crate_::BlockId<#block_type>, + #at_param_name: #block_id_type, _: #crate_::ExecutionContext, - params: Option<( #( #param_types ),* )>, + ___params___sp___api___: Option<( #( #param_types ),* )>, _: Vec, }; @@ -202,27 +290,40 @@ impl<'a> Fold for FoldRuntimeApiImpl<'a> { &self.impl_trait, &input.sig.ident, ); - let ret_type = return_type_extract_type(&input.sig.output); - // Generate the correct return type. - input.sig.output = parse_quote!( - -> std::result::Result<#crate_::NativeOrEncoded<#ret_type>, Self::Error> - ); + // When using advanced, the user needs to declare the correct return type on its own, + // otherwise do it for the user. + if !is_advanced { + let ret_type = return_type_extract_type(&input.sig.output); + + // Generate the correct return type. + input.sig.output = parse_quote!( + -> std::result::Result<#crate_::NativeOrEncoded<#ret_type>, Self::Error> + ); + } let orig_block = input.block.clone(); + let construct_return_value = if is_advanced { + quote!( (move || #orig_block)() ) + } else { + quote! { + let __fn_implementation__ = move || #orig_block; + + Ok(#crate_::NativeOrEncoded::Native(__fn_implementation__())) + } + }; + // Generate the new method implementation that calls into the runtime. parse_quote!( { // Get the error to the user (if we have one). - #error + #( #errors )* - let (#( #param_names ),*) = params + let (#( #param_names ),*) = ___params___sp___api___ .expect("Mocked runtime apis don't support calling deprecated api versions"); - let __fn_implementation__ = move || #orig_block; - - Ok(#crate_::NativeOrEncoded::Native(__fn_implementation__())) + #construct_return_value } ) }; @@ -240,10 +341,17 @@ impl<'a> Fold for FoldRuntimeApiImpl<'a> { if ty.ident == "Error" { if let Some(error_type) = self.error_type { if *error_type != ty.ty { - let error = Error::new( + let mut error = Error::new( ty.span(), "Error type can not change between runtime apis", ); + let error_first = Error::new( + error_type.span(), + "First error type was declared here." + ); + + error.combine(error_first); + ImplItem::Verbatim(error.to_compile_error()) } else { ImplItem::Verbatim(Default::default()) diff --git a/substrate/primitives/api/src/lib.rs b/substrate/primitives/api/src/lib.rs index bad6c03058..9dadce3b55 100644 --- a/substrate/primitives/api/src/lib.rs +++ b/substrate/primitives/api/src/lib.rs @@ -241,20 +241,18 @@ pub use sp_api_proc_macro::impl_runtime_apis; /// Mocks given trait implementations as runtime apis. /// -/// Accepts similar syntax as [`impl_runtime_apis!`](macro.impl_runtime_apis.html) and generates +/// Accepts similar syntax as [`impl_runtime_apis!`] and generates /// simplified mock implementations of the given runtime apis. The difference in syntax is that the /// trait does not need to be referenced by a qualified path, methods accept the `&self` parameter -/// and the error type can be specified as associated type. If no error type is specified `String` +/// and the error type can be specified as associated type. If no error type is specified [`String`] /// is used as error type. /// -/// Besides implementing the given traits, the [`Core`], [`ApiExt`] and [`ApiErrorExt`] are -/// implemented automatically. +/// Besides implementing the given traits, the [`Core`](sp_api::Core), [`ApiExt`](sp_api::ApiExt) +/// and [`ApiErrorExt`](sp_api::ApiErrorExt) are implemented automatically. /// /// # Example /// /// ```rust -/// use sp_version::create_runtime_str; -/// # /// # use sp_runtime::traits::Block as BlockT; /// # use sp_test_primitives::Block; /// # @@ -270,7 +268,6 @@ pub use sp_api_proc_macro::impl_runtime_apis; /// # fn build_block() -> Block; /// # } /// # } -/// /// struct MockApi { /// balance: u64, /// } @@ -301,6 +298,59 @@ pub use sp_api_proc_macro::impl_runtime_apis; /// /// # fn main() {} /// ``` +/// +/// # `advanced` attribute +/// +/// This attribute can be placed above individual function in the mock implementation to request +/// more control over the function declaration. From the client side each runtime api function is +/// called with the `at` parameter that is a [`BlockId`](sp_api::BlockId). When using the `advanced` +/// attribute, the macro expects that the first parameter of the function is this `at` parameter. +/// Besides that the macro also doesn't do the automatic return value rewrite, which means that full +/// return value must be specified. The full return value is constructed like +/// [`Result`]`<`[`NativeOrEncoded`](sp_api::NativeOrEncoded)`, Error>` while +/// `ReturnValue` being the return value that is specified in the trait declaration. +/// +/// ## Example +/// ```rust +/// # use sp_runtime::{traits::Block as BlockT, generic::BlockId}; +/// # use sp_test_primitives::Block; +/// # use sp_core::NativeOrEncoded; +/// # +/// # sp_api::decl_runtime_apis! { +/// # /// Declare the api trait. +/// # pub trait Balance { +/// # /// Get the balance. +/// # fn get_balance() -> u64; +/// # /// Set the balance. +/// # fn set_balance(val: u64); +/// # } +/// # } +/// struct MockApi { +/// balance: u64, +/// } +/// +/// sp_api::mock_impl_runtime_apis! { +/// impl Balance for MockApi { +/// type Error = String; +/// #[advanced] +/// fn get_balance(&self, at: &BlockId) -> Result, String> { +/// println!("Being called at: {}", at); +/// +/// Ok(self.balance.into()) +/// } +/// #[advanced] +/// fn set_balance(at: &BlockId, val: u64) -> Result, String> { +/// if let BlockId::Number(1) = at { +/// println!("Being called to set balance to: {}", val); +/// } +/// +/// Ok(().into()) +/// } +/// } +/// } +/// +/// # fn main() {} +/// ``` pub use sp_api_proc_macro::mock_impl_runtime_apis; /// A type that records all accessed trie nodes and generates a proof out of it. diff --git a/substrate/primitives/api/test/tests/decl_and_impl.rs b/substrate/primitives/api/test/tests/decl_and_impl.rs index f16f0bbe71..594882baf1 100644 --- a/substrate/primitives/api/test/tests/decl_and_impl.rs +++ b/substrate/primitives/api/test/tests/decl_and_impl.rs @@ -19,9 +19,8 @@ use sp_api::{ RuntimeApiInfo, decl_runtime_apis, impl_runtime_apis, mock_impl_runtime_apis, ApiExt, }; - use sp_runtime::{traits::{GetNodeBlockType, Block as BlockT}, generic::BlockId}; - +use sp_core::NativeOrEncoded; use substrate_test_runtime_client::runtime::Block; use sp_blockchain::Result; @@ -103,9 +102,20 @@ mock_impl_runtime_apis! { unimplemented!() } - fn same_name() {} + #[advanced] + fn same_name(_: &BlockId) -> std::result::Result, String> { + Ok(().into()) + } - fn wild_card(_: u32) {} + #[advanced] + fn wild_card(at: &BlockId, _: u32) -> std::result::Result, String> { + if let BlockId::Number(1337) = at { + // yeah + Ok(().into()) + } else { + Err("Ohh noooo".into()) + } + } } impl ApiWithCustomVersion for MockApi { @@ -180,3 +190,12 @@ fn mock_runtime_api_panics_on_calling_old_version() { #[allow(deprecated)] let _ = mock.same_name_before_version_2(&BlockId::Number(0)); } + +#[test] +fn mock_runtime_api_works_with_advanced() { + let mock = MockApi { block: None }; + + Api::::same_name(&mock, &BlockId::Number(0)).unwrap(); + mock.wild_card(&BlockId::Number(1337), 1).unwrap(); + assert_eq!(String::from("Ohh noooo"), mock.wild_card(&BlockId::Number(1336), 1).unwrap_err()); +} diff --git a/substrate/primitives/api/test/tests/ui/mock_advanced_block_id_by_value.rs b/substrate/primitives/api/test/tests/ui/mock_advanced_block_id_by_value.rs new file mode 100644 index 0000000000..1e71730cd0 --- /dev/null +++ b/substrate/primitives/api/test/tests/ui/mock_advanced_block_id_by_value.rs @@ -0,0 +1,20 @@ +use substrate_test_runtime_client::runtime::Block; + +sp_api::decl_runtime_apis! { + pub trait Api { + fn test(); + } +} + +struct MockApi; + +sp_api::mock_impl_runtime_apis! { + impl Api for MockApi { + #[advanced] + fn test(&self, _: BlockId) -> Result, String> { + Ok(().into()) + } + } +} + +fn main() {} diff --git a/substrate/primitives/api/test/tests/ui/mock_advanced_block_id_by_value.stderr b/substrate/primitives/api/test/tests/ui/mock_advanced_block_id_by_value.stderr new file mode 100644 index 0000000000..efddce05f5 --- /dev/null +++ b/substrate/primitives/api/test/tests/ui/mock_advanced_block_id_by_value.stderr @@ -0,0 +1,13 @@ +error: `BlockId` needs to be taken by reference and not by value! + --> $DIR/mock_advanced_block_id_by_value.rs:11:1 + | +11 | / sp_api::mock_impl_runtime_apis! { +12 | | impl Api for MockApi { +13 | | #[advanced] +14 | | fn test(&self, _: BlockId) -> Result, String> { +... | +17 | | } +18 | | } + | |_^ + | + = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/substrate/primitives/api/test/tests/ui/mock_advanced_missing_blockid.rs b/substrate/primitives/api/test/tests/ui/mock_advanced_missing_blockid.rs new file mode 100644 index 0000000000..407ea90ee8 --- /dev/null +++ b/substrate/primitives/api/test/tests/ui/mock_advanced_missing_blockid.rs @@ -0,0 +1,20 @@ +use substrate_test_runtime_client::runtime::Block; + +sp_api::decl_runtime_apis! { + pub trait Api { + fn test(); + } +} + +struct MockApi; + +sp_api::mock_impl_runtime_apis! { + impl Api for MockApi { + #[advanced] + fn test(&self) -> Result, String> { + Ok(().into()) + } + } +} + +fn main() {} diff --git a/substrate/primitives/api/test/tests/ui/mock_advanced_missing_blockid.stderr b/substrate/primitives/api/test/tests/ui/mock_advanced_missing_blockid.stderr new file mode 100644 index 0000000000..e7a66ebc5d --- /dev/null +++ b/substrate/primitives/api/test/tests/ui/mock_advanced_missing_blockid.stderr @@ -0,0 +1,5 @@ +error: If using the `advanced` attribute, it is required that the function takes at least one argument, the `BlockId`. + --> $DIR/mock_advanced_missing_blockid.rs:14:3 + | +14 | fn test(&self) -> Result, String> { + | ^^ diff --git a/substrate/primitives/api/test/tests/ui/mock_only_one_error_type.stderr b/substrate/primitives/api/test/tests/ui/mock_only_one_error_type.stderr index 65d05e83a7..daac5674d6 100644 --- a/substrate/primitives/api/test/tests/ui/mock_only_one_error_type.stderr +++ b/substrate/primitives/api/test/tests/ui/mock_only_one_error_type.stderr @@ -4,22 +4,22 @@ error: Error type can not change between runtime apis 23 | type Error = u64; | ^^^^ +error: First error type was declared here. + --> $DIR/mock_only_one_error_type.rs:17:16 + | +17 | type Error = u32; + | ^^^ + error[E0277]: the trait bound `u32: std::convert::From` is not satisfied - --> $DIR/mock_only_one_error_type.rs:15:1 + --> $DIR/mock_only_one_error_type.rs:17:16 | -15 | / sp_api::mock_impl_runtime_apis! { -16 | | impl Api for MockApi { -17 | | type Error = u32; -18 | | -... | -26 | | } -27 | | } - | |_^ the trait `std::convert::From` is not implemented for `u32` - | - ::: $WORKSPACE/primitives/api/src/lib.rs:350:35 +17 | type Error = u32; + | ^^^ the trait `std::convert::From` is not implemented for `u32` | -350 | type Error: std::fmt::Debug + From; - | ------------ required by this bound in `sp_api_hidden_includes_DECL_RUNTIME_APIS::sp_api::ApiErrorExt` + ::: $WORKSPACE/primitives/api/src/lib.rs + | + | type Error: std::fmt::Debug + From; + | ------------ required by this bound in `sp_api_hidden_includes_DECL_RUNTIME_APIS::sp_api::ApiErrorExt` | = help: the following implementations were found: > @@ -27,4 +27,3 @@ error[E0277]: the trait bound `u32: std::convert::From` is > > and 18 others - = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/substrate/primitives/core/src/lib.rs b/substrate/primitives/core/src/lib.rs index bef033df6c..7857937aeb 100644 --- a/substrate/primitives/core/src/lib.rs +++ b/substrate/primitives/core/src/lib.rs @@ -196,6 +196,13 @@ pub enum NativeOrEncoded { Encoded(Vec) } +#[cfg(feature = "std")] +impl From for NativeOrEncoded { + fn from(val: R) -> Self { + Self::Native(val) + } +} + #[cfg(feature = "std")] impl sp_std::fmt::Debug for NativeOrEncoded { fn fmt(&self, f: &mut sp_std::fmt::Formatter) -> sp_std::fmt::Result {