diff --git a/src/error.rs b/src/error.rs index 64788f4..c1c9276 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,6 +7,8 @@ pub enum Error { StoreError(crate::string::String), /// proof items is not enough to build a tree CorruptedProof, + /// tried to verify proof of a non-leaf + NodeProofsNotSupported, /// The leaves is an empty list, or beyond the mmr range GenProofForInvalidLeaves, @@ -22,6 +24,7 @@ impl core::fmt::Display for Error { InconsistentStore => write!(f, "Inconsistent store")?, StoreError(msg) => write!(f, "Store error {}", msg)?, CorruptedProof => write!(f, "Corrupted proof")?, + NodeProofsNotSupported => write!(f, "Tried to verify membership of a non-leaf")?, GenProofForInvalidLeaves => write!(f, "Generate proof ofr invalid leaves")?, MergeError(msg) => write!(f, "Merge error {}", msg)?, } diff --git a/src/mmr.rs b/src/mmr.rs index 7a65c7b..e4eb976 100644 --- a/src/mmr.rs +++ b/src/mmr.rs @@ -128,11 +128,16 @@ impl<'a, T: Clone + PartialEq + Debug, M: Merge, S: MMRStore> MMR = pos_list.into_iter().map(|pos| (pos, 0u32)).collect(); + // Generate sub-tree merkle proof for positions while let Some((pos, height)) = queue.pop_front() { debug_assert!(pos <= peak_pos); if pos == peak_pos { - break; + if queue.is_empty() { + break; + } else { + return Err(Error::NodeProofsNotSupported); + } } // calculate sibling @@ -177,6 +182,9 @@ impl<'a, T: Clone + PartialEq + Debug, M: Merge, S: MMRStore> MMR 0) { + return Err(Error::NodeProofsNotSupported); + } // ensure positions are sorted and unique pos_list.sort_unstable(); pos_list.dedup(); @@ -288,6 +296,7 @@ fn calculate_peak_root< ) -> Result { debug_assert!(!leaves.is_empty(), "can't be empty"); // (position, hash, height) + let mut queue: VecDeque<_> = leaves .into_iter() .map(|(pos, item)| (pos, item, 0u32)) @@ -296,8 +305,12 @@ fn calculate_peak_root< // calculate tree root from each items while let Some((pos, item, height)) = queue.pop_front() { if pos == peak_pos { - // return root - return Ok(item); + if queue.is_empty() { + // return root once queue is consumed + return Ok(item); + } else { + return Err(Error::CorruptedProof); + } } // calculate sibling let next_height = pos_height_in_tree(pos + 1); @@ -323,10 +336,10 @@ fn calculate_peak_root< M::merge(&item, &sibling_item) }?; - if parent_pos < peak_pos { - queue.push_back((parent_pos, parent_item, height + 1)); + if parent_pos <= peak_pos { + queue.push_back((parent_pos, parent_item, height + 1)) } else { - return Ok(parent_item); + return Err(Error::CorruptedProof); } } Err(Error::CorruptedProof) @@ -342,6 +355,10 @@ fn calculate_peaks_hashes< mmr_size: u64, mut proof_iter: I, ) -> Result> { + if leaves.iter().any(|(pos, _)| pos_height_in_tree(*pos) > 0) { + return Err(Error::NodeProofsNotSupported); + } + // special handle the only 1 leaf MMR if mmr_size == 1 && leaves.len() == 1 && leaves[0].0 == 0 { return Ok(leaves.into_iter().map(|(_pos, item)| item).collect()); diff --git a/src/tests/test_mmr.rs b/src/tests/test_mmr.rs index 1e5f3a1..ff92d99 100644 --- a/src/tests/test_mmr.rs +++ b/src/tests/test_mmr.rs @@ -1,5 +1,7 @@ use super::{MergeNumberHash, NumberHash}; -use crate::{leaf_index_to_mmr_size, util::MemStore, Error, MMR}; +use crate::{ + helper::pos_height_in_tree, leaf_index_to_mmr_size, util::MemStore, Error, MMRStore, MMR, +}; use faster_hex::hex_string; use proptest::prelude::*; use rand::{seq::SliceRandom, thread_rng}; @@ -145,6 +147,118 @@ fn test_gen_proof_with_duplicate_leaves() { test_mmr(10, vec![5, 5]); } +fn test_invalid_proof_verification( + leaf_count: u32, + positions_to_verify: Vec, + // positions of entries that should be tampered + tampered_positions: Vec, + // optionally handroll proof from these positions + handrolled_proof_positions: Option>, +) { + use crate::{util::MemMMR, Merge, MerkleProof}; + use std::fmt::{Debug, Formatter}; + + // Simple item struct to allow debugging the contents of MMR nodes/peaks + #[derive(Clone, PartialEq)] + enum MyItem { + Number(u32), + Merged(Box, Box), + } + + impl Debug for MyItem { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + MyItem::Number(x) => f.write_fmt(format_args!("{}", x)), + MyItem::Merged(a, b) => f.write_fmt(format_args!("Merged({:#?}, {:#?})", a, b)), + } + } + } + + #[derive(Debug)] + struct MyMerge; + + impl Merge for MyMerge { + type Item = MyItem; + fn merge(lhs: &Self::Item, rhs: &Self::Item) -> Result { + Ok(MyItem::Merged(Box::new(lhs.clone()), Box::new(rhs.clone()))) + } + } + + let mut mmr: MemMMR = MemMMR::default(); + let mut positions: Vec = Vec::new(); + for i in 0u32..leaf_count { + let pos = mmr.push(MyItem::Number(i)).unwrap(); + positions.push(pos); + } + let root = mmr.get_root().unwrap(); + + let entries_to_verify: Vec<(u64, MyItem)> = positions_to_verify + .iter() + .map(|pos| (*pos, mmr.store().get_elem(*pos).unwrap().unwrap())) + .collect(); + + let mut tampered_entries_to_verify = entries_to_verify.clone(); + tampered_positions.iter().for_each(|proof_pos| { + tampered_entries_to_verify[*proof_pos] = ( + tampered_entries_to_verify[*proof_pos].0, + MyItem::Number(31337), + ) + }); + + let handrolled_proof: Option> = + handrolled_proof_positions.map(|handrolled_proof_positions| { + MerkleProof::new( + mmr.mmr_size(), + handrolled_proof_positions + .iter() + .map(|pos| mmr.store().get_elem(*pos).unwrap().unwrap()) + .collect(), + ) + }); + + // verification should fail whenever trying to prove membership of a non-member + if let Some(handrolled_proof) = handrolled_proof { + let handrolled_proof_result = + handrolled_proof.verify(root.clone(), tampered_entries_to_verify.clone()); + assert!(handrolled_proof_result.is_err() || !handrolled_proof_result.unwrap()); + } + + match mmr.gen_proof(positions_to_verify.clone()) { + Ok(proof) => { + assert!(proof.verify(root.clone(), entries_to_verify).unwrap()); + assert!(!proof.verify(root, tampered_entries_to_verify).unwrap()); + } + Err(Error::NodeProofsNotSupported) => { + // if couldn't generate proof, then it contained a non-leaf + assert!(positions_to_verify + .iter() + .any(|pos| pos_height_in_tree(*pos) > 0)); + } + Err(e) => panic!("Unexpected error: {}", e), + } +} + +#[test] +fn test_generic_proofs() { + test_invalid_proof_verification(7, vec![5], vec![0], Some(vec![2, 9, 10])); + test_invalid_proof_verification(7, vec![1, 2], vec![0], Some(vec![5, 9, 10])); + test_invalid_proof_verification(7, vec![1, 5], vec![0], Some(vec![0, 9, 10])); + test_invalid_proof_verification(7, vec![1, 6], vec![0], Some(vec![0, 5, 9, 10])); + test_invalid_proof_verification(7, vec![5, 6], vec![0], Some(vec![2, 9, 10])); + test_invalid_proof_verification(7, vec![1, 5, 6], vec![0], Some(vec![0, 9, 10])); + test_invalid_proof_verification(7, vec![1, 5, 7], vec![0], Some(vec![0, 8, 10])); + test_invalid_proof_verification(7, vec![5, 6, 7], vec![0], Some(vec![2, 8, 10])); + test_invalid_proof_verification(7, vec![5, 6, 7, 8, 9, 10], vec![0], Some(vec![2])); + test_invalid_proof_verification(7, vec![1, 5, 7, 8, 9, 10], vec![0], Some(vec![0])); + test_invalid_proof_verification(7, vec![0, 1, 5, 7, 8, 9, 10], vec![0], Some(vec![])); + test_invalid_proof_verification(7, vec![0, 1, 5, 6, 7, 8, 9, 10], vec![0], Some(vec![])); + test_invalid_proof_verification(7, vec![0, 1, 2, 5, 6, 7, 8, 9, 10], vec![0], Some(vec![])); + test_invalid_proof_verification(7, vec![0, 1, 2, 3, 7, 8, 9, 10], vec![0], Some(vec![4])); + test_invalid_proof_verification(7, vec![0, 2, 3, 7, 8, 9, 10], vec![0], Some(vec![1, 4])); + test_invalid_proof_verification(7, vec![0, 3, 7, 8, 9, 10], vec![0], Some(vec![1, 4])); + test_invalid_proof_verification(7, vec![0, 2, 3, 7, 8, 9, 10], vec![0], Some(vec![1, 4])); +} + prop_compose! { fn count_elem(count: u32) (elem in 0..count) diff --git a/src/util.rs b/src/util.rs index 96442b5..0828bc4 100644 --- a/src/util.rs +++ b/src/util.rs @@ -58,6 +58,10 @@ impl> MemMMR { &self.store } + pub fn mmr_size(&self) -> u64 { + self.mmr_size + } + pub fn get_root(&self) -> Result { let mmr = MMR::>::new(self.mmr_size, &self.store); mmr.get_root()