Prevent bug when reusing type ids in hashing (#1075)

* practice TDD

* implement a hashmap 2-phases approach

* use nicer types

* add test for cache filling

* adjust test

---------

Co-authored-by: James Wilson <james@jsdw.me>
This commit is contained in:
Tadeo Hepperle
2023-07-19 19:49:08 +02:00
committed by GitHub
parent 84df959d28
commit 475a1413e4
2 changed files with 139 additions and 61 deletions
+137 -59
View File
@@ -9,10 +9,7 @@ use crate::{
RuntimeApiMethodMetadata, StorageEntryMetadata, StorageEntryType,
};
use scale_info::{form::PortableForm, Field, PortableRegistry, TypeDef, TypeDefVariant, Variant};
use std::collections::HashSet;
/// Predefined value to be returned when we already visited a type.
const MAGIC_RECURSIVE_TYPE_VALUE: &[u8] = &[123];
use std::collections::HashMap;
// The number of bytes our `hash` function produces.
const HASH_LEN: usize = 32;
@@ -75,7 +72,7 @@ concat_and_hash_n!(concat_and_hash5(a b c d e));
fn get_field_hash(
registry: &PortableRegistry,
field: &Field<PortableForm>,
visited_ids: &mut HashSet<u32>,
cache: &mut HashMap<u32, CachedHash>,
) -> [u8; HASH_LEN] {
let field_name_bytes = match &field.name {
Some(name) => hash(name.as_bytes()),
@@ -84,7 +81,7 @@ fn get_field_hash(
concat_and_hash2(
&field_name_bytes,
&get_type_hash(registry, field.ty.id, visited_ids),
&get_type_hash(registry, field.ty.id, cache),
)
}
@@ -92,13 +89,13 @@ fn get_field_hash(
fn get_variant_hash(
registry: &PortableRegistry,
var: &Variant<PortableForm>,
visited_ids: &mut HashSet<u32>,
cache: &mut HashMap<u32, CachedHash>,
) -> [u8; HASH_LEN] {
let variant_name_bytes = hash(var.name.as_bytes());
let variant_field_bytes = var.fields.iter().fold([0u8; HASH_LEN], |bytes, field| {
// EncodeAsType and DecodeAsType don't care about variant field ordering,
// so XOR the fields to ensure that it doesn't matter.
xor(bytes, get_field_hash(registry, field, visited_ids))
xor(bytes, get_field_hash(registry, field, cache))
});
concat_and_hash2(&variant_name_bytes, &variant_field_bytes)
@@ -108,7 +105,7 @@ fn get_type_def_variant_hash(
registry: &PortableRegistry,
variant: &TypeDefVariant<PortableForm>,
only_these_variants: Option<&[&str]>,
visited_ids: &mut HashSet<u32>,
cache: &mut HashMap<u32, CachedHash>,
) -> [u8; HASH_LEN] {
let variant_id_bytes = [TypeBeingHashed::Variant as u8; HASH_LEN];
let variant_field_bytes = variant.variants.iter().fold([0u8; HASH_LEN], |bytes, var| {
@@ -120,7 +117,7 @@ fn get_type_def_variant_hash(
.unwrap_or(true);
if should_hash {
xor(bytes, get_variant_hash(registry, var, visited_ids))
xor(bytes, get_variant_hash(registry, var, cache))
} else {
bytes
}
@@ -132,7 +129,7 @@ fn get_type_def_variant_hash(
fn get_type_def_hash(
registry: &PortableRegistry,
ty_def: &TypeDef<PortableForm>,
visited_ids: &mut HashSet<u32>,
cache: &mut HashMap<u32, CachedHash>,
) -> [u8; HASH_LEN] {
match ty_def {
TypeDef::Composite(composite) => {
@@ -144,16 +141,14 @@ fn get_type_def_hash(
.fold([0u8; HASH_LEN], |bytes, field| {
// With EncodeAsType and DecodeAsType we no longer care which order the fields are in,
// as long as all of the names+types are there. XOR to not care about ordering.
xor(bytes, get_field_hash(registry, field, visited_ids))
xor(bytes, get_field_hash(registry, field, cache))
});
concat_and_hash2(&composite_id_bytes, &composite_field_bytes)
}
TypeDef::Variant(variant) => {
get_type_def_variant_hash(registry, variant, None, visited_ids)
}
TypeDef::Variant(variant) => get_type_def_variant_hash(registry, variant, None, cache),
TypeDef::Sequence(sequence) => concat_and_hash2(
&[TypeBeingHashed::Sequence as u8; HASH_LEN],
&get_type_hash(registry, sequence.type_param.id, visited_ids),
&get_type_hash(registry, sequence.type_param.id, cache),
),
TypeDef::Array(array) => {
// Take length into account too; different length must lead to different hash.
@@ -165,13 +160,13 @@ fn get_type_def_hash(
};
concat_and_hash2(
&array_id_bytes,
&get_type_hash(registry, array.type_param.id, visited_ids),
&get_type_hash(registry, array.type_param.id, cache),
)
}
TypeDef::Tuple(tuple) => {
let mut bytes = hash(&[TypeBeingHashed::Tuple as u8]);
for field in &tuple.fields {
bytes = concat_and_hash2(&bytes, &get_type_hash(registry, field.id, visited_ids));
bytes = concat_and_hash2(&bytes, &get_type_hash(registry, field.id, cache));
}
bytes
}
@@ -181,31 +176,64 @@ fn get_type_def_hash(
}
TypeDef::Compact(compact) => concat_and_hash2(
&[TypeBeingHashed::Compact as u8; HASH_LEN],
&get_type_hash(registry, compact.type_param.id, visited_ids),
&get_type_hash(registry, compact.type_param.id, cache),
),
TypeDef::BitSequence(bitseq) => concat_and_hash3(
&[TypeBeingHashed::BitSequence as u8; HASH_LEN],
&get_type_hash(registry, bitseq.bit_order_type.id, visited_ids),
&get_type_hash(registry, bitseq.bit_store_type.id, visited_ids),
&get_type_hash(registry, bitseq.bit_order_type.id, cache),
&get_type_hash(registry, bitseq.bit_store_type.id, cache),
),
}
}
/// Indicates whether a hash has been fully computed for a type or not.
///
/// Used as the cache value in `get_type_hash` to guard against infinite
/// recursion when a type (transitively) refers to itself.
#[derive(Clone, Debug)]
pub enum CachedHash {
    /// Hash not known yet, but computation has already started.
    Recursive,
    /// Hash of the type; computation was finished.
    Hash([u8; HASH_LEN]),
}
impl CachedHash {
    /// Return the finished hash for this cache entry, or a fixed placeholder
    /// value while the type's hash is still being computed (i.e. when the
    /// lookup happens from within a recursive type reference).
    fn hash(&self) -> [u8; HASH_LEN] {
        match self {
            CachedHash::Hash(h) => *h,
            // Arbitrary constant standing in for "hash in progress".
            CachedHash::Recursive => [123; HASH_LEN],
        }
    }
}
/// Obtain the hash representation of a `scale_info::Type` identified by id.
pub fn get_type_hash(
    registry: &PortableRegistry,
    id: u32,
    cache: &mut HashMap<u32, CachedHash>,
) -> [u8; HASH_LEN] {
    // Guard against recursive types, with a 2 step caching approach:
    // if the cache has an entry for the id, just return a hash derived from it.
    // if the type has not been seen yet, mark it with `CachedHash::Recursive` in the cache and proceed to `get_type_def_hash()`.
    // -> During the execution of get_type_def_hash() we might get into get_type_hash(id) again for the original id
    //    -> in this case the `CachedHash::Recursive` provokes an early return.
    // -> Once we return from `get_type_def_hash()` we need to update the cache entry:
    //    -> We set the cache value to `CachedHash::Hash(type_hash)`, where `type_hash` was returned from `get_type_def_hash()`
    //    -> It makes sure that different types end up with different cache values.
    //
    // Values in the cache can be thought of as a mapping like this:
    // type_id -> not contained           = We haven't seen the type yet.
    //         -> `CachedHash::Recursive` = We have seen the type but hash calculation for it hasn't finished yet.
    //         -> `CachedHash::Hash(hash)` = Hash calculation for the type was completed.
    if let Some(cached_hash) = cache.get(&id) {
        return cached_hash.hash();
    }
    // Mark the type as "in progress" *before* recursing, so that any
    // self-reference encountered below takes the early return above.
    cache.insert(id, CachedHash::Recursive);
    let ty = registry
        .resolve(id)
        .expect("Type ID provided by the metadata is registered; qed");
    let type_hash = get_type_def_hash(registry, &ty.type_def, cache);
    // Replace the placeholder with the real hash so that later lookups of this
    // id return the exact value instead of the recursive sentinel.
    cache.insert(id, CachedHash::Hash(type_hash));
    type_hash
}
/// Obtain the hash representation of a `frame_metadata::v15::ExtrinsicMetadata`.
@@ -213,13 +241,13 @@ fn get_extrinsic_hash(
registry: &PortableRegistry,
extrinsic: &ExtrinsicMetadata,
) -> [u8; HASH_LEN] {
let mut visited_ids = HashSet::<u32>::new();
let mut cache = HashMap::<u32, CachedHash>::new();
// Get the hashes of the extrinsic type.
let address_hash = get_type_hash(registry, extrinsic.address_ty, &mut visited_ids);
let address_hash = get_type_hash(registry, extrinsic.address_ty, &mut cache);
// The `RuntimeCall` type is intentionally omitted and hashed by the outer enums instead.
let signature_hash = get_type_hash(registry, extrinsic.signature_ty, &mut visited_ids);
let extra_hash = get_type_hash(registry, extrinsic.extra_ty, &mut visited_ids);
let signature_hash = get_type_hash(registry, extrinsic.signature_ty, &mut cache);
let extra_hash = get_type_hash(registry, extrinsic.extra_ty, &mut cache);
let mut bytes = concat_and_hash4(
&address_hash,
@@ -232,8 +260,8 @@ fn get_extrinsic_hash(
bytes = concat_and_hash4(
&bytes,
&hash(signed_extension.identifier.as_bytes()),
&get_type_hash(registry, signed_extension.extra_ty, &mut visited_ids),
&get_type_hash(registry, signed_extension.additional_ty, &mut visited_ids),
&get_type_hash(registry, signed_extension.extra_ty, &mut cache),
&get_type_hash(registry, signed_extension.additional_ty, &mut cache),
)
}
@@ -258,9 +286,9 @@ fn get_outer_enums_hash(
.expect("Metadata should contain enum type in registry");
if let TypeDef::Variant(variant) = &ty.ty.type_def {
get_type_def_variant_hash(registry, variant, only_these_variants, &mut HashSet::new())
get_type_def_variant_hash(registry, variant, only_these_variants, &mut HashMap::new())
} else {
get_type_hash(registry, id, &mut HashSet::new())
get_type_hash(registry, id, &mut HashMap::new())
}
}
@@ -277,7 +305,7 @@ fn get_outer_enums_hash(
fn get_storage_entry_hash(
registry: &PortableRegistry,
entry: &StorageEntryMetadata,
visited_ids: &mut HashSet<u32>,
cache: &mut HashMap<u32, CachedHash>,
) -> [u8; HASH_LEN] {
let mut bytes = concat_and_hash3(
&hash(entry.name.as_bytes()),
@@ -288,7 +316,7 @@ fn get_storage_entry_hash(
match &entry.entry_type {
StorageEntryType::Plain(ty) => {
concat_and_hash2(&bytes, &get_type_hash(registry, *ty, visited_ids))
concat_and_hash2(&bytes, &get_type_hash(registry, *ty, cache))
}
StorageEntryType::Map {
hashers,
@@ -301,8 +329,8 @@ fn get_storage_entry_hash(
}
concat_and_hash3(
&bytes,
&get_type_hash(registry, *key_ty, visited_ids),
&get_type_hash(registry, *value_ty, visited_ids),
&get_type_hash(registry, *key_ty, cache),
&get_type_hash(registry, *value_ty, cache),
)
}
}
@@ -313,7 +341,7 @@ fn get_runtime_method_hash(
registry: &PortableRegistry,
trait_name: &str,
method_metadata: &RuntimeApiMethodMetadata,
visited_ids: &mut HashSet<u32>,
cache: &mut HashMap<u32, CachedHash>,
) -> [u8; HASH_LEN] {
// The trait name is part of the runtime API call that is being
// generated for this method. Therefore the trait name is strongly
@@ -328,13 +356,13 @@ fn get_runtime_method_hash(
bytes = concat_and_hash3(
&bytes,
&hash(input.name.as_bytes()),
&get_type_hash(registry, input.ty, visited_ids),
&get_type_hash(registry, input.ty, cache),
);
}
bytes = concat_and_hash2(
&bytes,
&get_type_hash(registry, method_metadata.output_ty, visited_ids),
&get_type_hash(registry, method_metadata.output_ty, cache),
);
bytes
@@ -342,7 +370,7 @@ fn get_runtime_method_hash(
/// Obtain the hash of all of a runtime API trait, including all of its methods.
pub fn get_runtime_trait_hash(trait_metadata: RuntimeApiMetadata) -> [u8; HASH_LEN] {
let mut visited_ids = HashSet::new();
let mut cache = HashMap::new();
let trait_name = &*trait_metadata.inner.name;
let method_bytes = trait_metadata
.methods()
@@ -357,7 +385,7 @@ pub fn get_runtime_trait_hash(trait_metadata: RuntimeApiMetadata) -> [u8; HASH_L
trait_metadata.types,
trait_name,
method_metadata,
&mut visited_ids,
&mut cache,
),
)
});
@@ -370,7 +398,7 @@ pub fn get_storage_hash(pallet: &PalletMetadata, entry_name: &str) -> Option<[u8
let storage = pallet.storage()?;
let entry = storage.entry_by_name(entry_name)?;
let hash = get_storage_entry_hash(pallet.types, entry, &mut HashSet::new());
let hash = get_storage_entry_hash(pallet.types, entry, &mut HashMap::new());
Some(hash)
}
@@ -379,7 +407,7 @@ pub fn get_constant_hash(pallet: &PalletMetadata, constant_name: &str) -> Option
let constant = pallet.constant_by_name(constant_name)?;
// We only need to check that the type of the constant asked for matches.
let bytes = get_type_hash(pallet.types, constant.ty, &mut HashSet::new());
let bytes = get_type_hash(pallet.types, constant.ty, &mut HashMap::new());
Some(bytes)
}
@@ -388,7 +416,7 @@ pub fn get_call_hash(pallet: &PalletMetadata, call_name: &str) -> Option<[u8; HA
let call_variant = pallet.call_variant_by_name(call_name)?;
// hash the specific variant representing the call we are interested in.
let hash = get_variant_hash(pallet.types, call_variant, &mut HashSet::new());
let hash = get_variant_hash(pallet.types, call_variant, &mut HashMap::new());
Some(hash)
}
@@ -404,25 +432,25 @@ pub fn get_runtime_api_hash(
runtime_apis.types,
trait_name,
method_metadata,
&mut HashSet::new(),
&mut HashMap::new(),
))
}
/// Obtain the hash representation of a `frame_metadata::v15::PalletMetadata`.
pub fn get_pallet_hash(pallet: PalletMetadata) -> [u8; HASH_LEN] {
let mut visited_ids = HashSet::<u32>::new();
let mut cache = HashMap::<u32, CachedHash>::new();
let registry = pallet.types;
let call_bytes = match pallet.call_ty_id() {
Some(calls) => get_type_hash(registry, calls, &mut visited_ids),
Some(calls) => get_type_hash(registry, calls, &mut cache),
None => [0u8; HASH_LEN],
};
let event_bytes = match pallet.event_ty_id() {
Some(event) => get_type_hash(registry, event, &mut visited_ids),
Some(event) => get_type_hash(registry, event, &mut cache),
None => [0u8; HASH_LEN],
};
let error_bytes = match pallet.error_ty_id() {
Some(error) => get_type_hash(registry, error, &mut visited_ids),
Some(error) => get_type_hash(registry, error, &mut cache),
None => [0u8; HASH_LEN],
};
let constant_bytes = pallet.constants().fold([0u8; HASH_LEN], |bytes, constant| {
@@ -430,7 +458,7 @@ pub fn get_pallet_hash(pallet: PalletMetadata) -> [u8; HASH_LEN] {
// of (constantName, constantType) to make the order we see them irrelevant.
let constant_hash = concat_and_hash2(
&hash(constant.name.as_bytes()),
&get_type_hash(registry, constant.ty(), &mut visited_ids),
&get_type_hash(registry, constant.ty(), &mut cache),
);
xor(bytes, constant_hash)
});
@@ -443,10 +471,7 @@ pub fn get_pallet_hash(pallet: PalletMetadata) -> [u8; HASH_LEN] {
.fold([0u8; HASH_LEN], |bytes, entry| {
// We don't care what order the storage entries occur in, so XOR them together
// to make the order irrelevant.
xor(
bytes,
get_storage_entry_hash(registry, entry, &mut visited_ids),
)
xor(bytes, get_storage_entry_hash(registry, entry, &mut cache))
});
concat_and_hash2(&prefix_hash, &entries_hash)
}
@@ -537,7 +562,7 @@ impl<'a> MetadataHasher<'a> {
let extrinsic_hash = get_extrinsic_hash(&metadata.types, &metadata.extrinsic);
let runtime_hash =
get_type_hash(&metadata.types, metadata.runtime_ty(), &mut HashSet::new());
get_type_hash(&metadata.types, metadata.runtime_ty(), &mut HashMap::new());
let outer_enums_hash = get_outer_enums_hash(
&metadata.types,
&metadata.outer_enums(),
@@ -559,7 +584,7 @@ mod tests {
use super::*;
use bitvec::{order::Lsb0, vec::BitVec};
use frame_metadata::v15;
use scale_info::meta_type;
use scale_info::{meta_type, Registry};
// Define recursive types.
#[allow(dead_code)]
@@ -743,6 +768,59 @@ mod tests {
assert_eq!(hash, hash_swap);
}
/// Test fixture: shares the `(A, B)` tuple with `Abb`, but its `other`
/// field is an `A` — the pair is used to check that hashing does not
/// conflate structurally different types via reused type ids.
#[allow(dead_code)]
#[derive(scale_info::TypeInfo)]
struct Aba {
    ab: (A, B),
    other: A,
}
/// Test fixture: shares the `(A, B)` tuple with `Aba`, but its `other`
/// field is a `B` — must hash differently from `Aba`.
#[allow(dead_code)]
#[derive(scale_info::TypeInfo)]
struct Abb {
    ab: (A, B),
    other: B,
}
/// `Aba` and `Abb` contain the same `(A, B)` tuple but differ in their
/// `other` field; their metadata hashes must therefore differ. A cache that
/// wrongly reused visited type ids would make them collide.
#[test]
fn do_not_reuse_visited_type_ids() {
    let hash_of = |ty| {
        let mut pallets = build_default_pallets();
        pallets[0].calls = Some(v15::PalletCallMetadata { ty });
        MetadataHasher::new(&pallets_to_metadata(pallets)).hash()
    };
    assert_ne!(
        hash_of(meta_type::<Aba>()),
        hash_of(meta_type::<Abb>())
    );
}
/// After `get_type_hash` completes, the cache must hold `CachedHash::Hash`
/// entries that equal the returned hashes, repeated lookups must be stable,
/// and distinct types must produce distinct hashes.
#[test]
fn hash_cache_gets_filled_with_correct_hashes() {
    let mut types = Registry::new();
    let id_a = types.register_type(&meta_type::<A>()).id;
    let id_b = types.register_type(&meta_type::<B>()).id;
    let portable: PortableRegistry = types.into();

    let mut cache = HashMap::new();
    let hash_a = get_type_hash(&portable, id_a, &mut cache);
    let hash_a_again = get_type_hash(&portable, id_a, &mut cache);
    let hash_b = get_type_hash(&portable, id_b, &mut cache);

    // Finished computations must be stored as exact `Hash` values, never
    // left as the `Recursive` placeholder.
    let CachedHash::Hash(cached_a) = cache[&id_a] else { panic!() };
    let CachedHash::Hash(cached_b) = cache[&id_b] else { panic!() };
    assert_eq!(hash_a, cached_a);
    assert_eq!(hash_b, cached_b);
    assert_eq!(hash_a, hash_a_again);
    assert_ne!(hash_a, hash_b);
}
#[test]
// Redundant clone clippy warning is a lie; https://github.com/rust-lang/rust-clippy/issues/10870
#[allow(clippy::redundant_clone)]