Check for invalid modules when registering a pallet in construct_runtime (#4520)

This commit is contained in:
Bastian Köcher
2020-01-02 20:10:20 +01:00
committed by GitHub
parent 508f94acdf
commit 4a9697db94
5 changed files with 31 additions and 16 deletions
@@ -180,8 +180,8 @@ impl Parse for ModuleDeclaration {
let has_default = parts.into_iter().any(|m| m.is_default());
for entry in parts {
match entry {
ModuleEntry::Part(part) if has_default => {
if part.is_included_in_default() {
ModuleEntry::Part(part) => {
if has_default && part.is_included_in_default() {
let msg = format!(
"`{}` is already included in `default`. Either remove `default` or remove `{}`",
part.name,
@@ -189,8 +189,7 @@ impl Parse for ModuleDeclaration {
);
return Err(Error::new(part.name.span(), msg));
}
}
ModuleEntry::Part(part) => {
if !resolved.insert(part.name.clone()) {
let msg = format!(
"`{}` was already declared before. Please remove the duplicate declaration",
@@ -287,7 +286,18 @@ pub struct ModulePart {
impl Parse for ModulePart {
fn parse(input: ParseStream) -> Result<Self> {
let name = input.parse()?;
let name: Ident = input.parse()?;
if !ModulePart::all_allowed().iter().any(|n| name == n) {
return Err(syn::Error::new(
name.span(),
format!(
"Only the following modules are allowed: {}",
ModulePart::format_names(ModulePart::all_allowed()),
),
))
}
let generics: syn::Generics = input.parse()?;
if !generics.params.is_empty() && !Self::is_allowed_generic(&name) {
let valid_generics = ModulePart::format_names(ModulePart::allowed_generics());
@@ -313,6 +323,7 @@ impl Parse for ModulePart {
} else {
None
};
Ok(Self {
name,
generics,
@@ -330,15 +341,20 @@ impl ModulePart {
Self::allowed_args().into_iter().any(|n| ident == n)
}
pub fn allowed_generics() -> Vec<&'static str> {
vec!["Event", "Origin", "Config"]
pub fn allowed_generics() -> &'static [&'static str] {
&["Event", "Origin", "Config"]
}
pub fn allowed_args() -> Vec<&'static str> {
vec!["Inherent"]
pub fn allowed_args() -> &'static [&'static str] {
&["Inherent"]
}
pub fn format_names(names: Vec<&'static str>) -> String {
/// Returns all allowed names for module parts.
pub fn all_allowed() -> &'static [&'static str] {
&["Module", "Call", "Storage", "Event", "Config", "Origin", "Inherent", "ValidateUnsigned"]
}
pub fn format_names(names: &[&'static str]) -> String {
let res: Vec<_> = names.into_iter().map(|s| format!("`{}`", s)).collect();
res.join(", ")
}