diff --git a/.gitignore b/.gitignore index 1e7caa9..9182e2f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ Cargo.lock target/ +.vscode/ diff --git a/benches/mmr_benchmark.rs b/benches/mmr_benchmark.rs index ba4447e..f27e8e6 100644 --- a/benches/mmr_benchmark.rs +++ b/benches/mmr_benchmark.rs @@ -65,7 +65,7 @@ fn bench(c: &mut Criterion) { let (mmr_size, store, positions) = prepare_mmr(100_0000); let mmr = MMR::<_, MergeNumberHash, _>::new(mmr_size, &store); let mut rng = thread_rng(); - b.iter(|| mmr.gen_proof(*positions.choose(&mut rng).unwrap())); + b.iter(|| mmr.gen_proof(vec![*positions.choose(&mut rng).unwrap()])); }); c.bench_function("MMR verify", |b| { @@ -77,13 +77,15 @@ fn bench(c: &mut Criterion) { .map(|_| { let pos = positions.choose(&mut rng).unwrap(); let elem = (&store).get_elem(*pos).unwrap().unwrap(); - let proof = mmr.gen_proof(*pos).unwrap(); + let proof = mmr.gen_proof(vec![*pos]).unwrap(); (pos, elem, proof) }) .collect(); b.iter(|| { let (pos, elem, proof) = proofs.choose(&mut rng).unwrap(); - proof.verify(root.clone(), **pos, elem.clone()).unwrap(); + proof + .verify(root.clone(), vec![(**pos, elem.clone())]) + .unwrap(); }); }); } diff --git a/src/error.rs b/src/error.rs index 4e246b8..843694a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,6 +5,10 @@ pub enum Error { GetRootOnEmpty, InconsistentStore, StoreError(crate::string::String), + /// proof items is not enough to build a tree + CorruptedProof, + /// The leaves is an empty list, or beyond the mmr range + GenProofForInvalidLeaves, } impl core::fmt::Display for Error { @@ -14,6 +18,8 @@ impl core::fmt::Display for Error { GetRootOnEmpty => write!(f, "Get root on an empty MMR")?, InconsistentStore => write!(f, "Inconsistent store")?, StoreError(msg) => write!(f, "Store error {}", msg)?, + CorruptedProof => write!(f, "Corrupted proof")?, + GenProofForInvalidLeaves => write!(f, "Generate proof ofr invalid leaves")?, } Ok(()) } diff --git a/src/mmr.rs b/src/mmr.rs index 8889283..b2e1b59 100644 --- a/src/mmr.rs +++ b/src/mmr.rs @@ -5,6 +5,7 @@ //! https://github.com/mimblewimble/grin/blob/0ff6763ee64e5a14e70ddd4642b99789a1648a32/core/src/core/pmmr.rs#L606 use crate::borrow::Cow; +use crate::collections::{btree_map::Entry, BTreeMap}; use crate::helper::{get_peaks, parent_offset, pos_height_in_tree, sibling_offset}; use crate::mmr_store::{MMRBatch, MMRStore}; use crate::vec; @@ -79,42 +80,52 @@ impl<'a, T: Clone + PartialEq + Debug, M: Merge, S: MMRStore> MMR = get_peaks(self.mmr_size) + .into_iter() + .map(|peak_pos| { + self.batch + .get_elem(peak_pos) + .and_then(|elem| elem.ok_or(Error::InconsistentStore)) + }) + .collect::>>()?; + self.bag_rhs_peaks(peaks)?.ok_or(Error::InconsistentStore) } - fn bag_rhs_peaks(&self, skip_peak_pos: u64, peaks: &[u64]) -> Result> { - let mut rhs_peak_elems: Vec = peaks - .iter() - .filter(|&&p| p > skip_peak_pos) - .map(|&p| self.batch.get_elem(p)) - .collect::>>()? - .ok_or(Error::InconsistentStore)?; - while rhs_peak_elems.len() > 1 { - let right_peak = rhs_peak_elems.pop().expect("pop"); - let left_peak = rhs_peak_elems.pop().expect("pop"); - rhs_peak_elems.push(M::merge(&right_peak, &left_peak)); + fn bag_rhs_peaks(&self, mut rhs_peaks: Vec) -> Result> { + while rhs_peaks.len() > 1 { + let right_peak = rhs_peaks.pop().expect("pop"); + let left_peak = rhs_peaks.pop().expect("pop"); + rhs_peaks.push(M::merge(&right_peak, &left_peak)); } - Ok(rhs_peak_elems.pop()) + Ok(rhs_peaks.pop()) } - pub fn gen_proof(&self, mut pos: u64) -> Result> { - let mut proof: Vec = Vec::new(); - let mut height = 0; - while pos < self.mmr_size { + fn build_sub_merkle_path( + &self, + mut pos: u64, + mut height: u32, + peak_pos: u64, + stop_pos: u64, + tree_buf: &BTreeMap, + proof: &mut Vec, + ) -> Result<(u64, u32)> { + while pos < peak_pos { let pos_height = pos_height_in_tree(pos); let next_height = pos_height_in_tree(pos + 1); - let (sib_pos, next_pos) = if next_height > pos_height { + let sib_pos = if next_height > pos_height { // implies pos is right sibling let sib_pos = pos - sibling_offset(height); - (sib_pos, pos + 1) + pos += 1; + sib_pos } else { // pos is left sibling let sib_pos = pos + sibling_offset(height); - (sib_pos, pos + parent_offset(height)) + pos += parent_offset(height); + sib_pos }; - if sib_pos > self.mmr_size - 1 { + height += 1; + if pos > stop_pos || tree_buf.contains_key(&pos) { + // means that current merkle path is complete break; } proof.push( @@ -122,24 +133,98 @@ impl<'a, T: Clone + PartialEq + Debug, M: Merge, S: MMRStore> MMR, + pos_list: Vec, + peak_pos: u64, + ) -> Result<()> { + // do nothing if position itself is the peak + if pos_list.len() == 1 && pos_list == [peak_pos] { + return Ok(()); + } + // take peak root from store if no positions need to be proof + if pos_list.is_empty() { + proof.push( + self.batch + .get_elem(peak_pos)? + .ok_or(Error::InconsistentStore)?, + ); + return Ok(()); + } + + // buf, positon -> height map + let mut tree_buf: BTreeMap = + pos_list.into_iter().map(|pos| (pos, 0u32)).collect(); + // Generate sub-tree merkle proof for positions + loop { + let (&pos, &height) = tree_buf.iter().next().unwrap(); + tree_buf.remove(&pos); + debug_assert!(pos <= peak_pos); + if pos == peak_pos { + break; + } + + let next_pos = *tree_buf + .iter() + .next() + .map(|(pos, _height)| pos) + .unwrap_or(&peak_pos); + let (pos, height) = + self.build_sub_merkle_path(pos, height, peak_pos, next_pos, &tree_buf, proof)?; + // save pos to tree buf + tree_buf.entry(pos).or_insert(height); + } + Ok(()) + } + + /// Generate merkle proof for positions + /// 1. sort positions + /// 2. push merkle proof to proof by peak from left to right + /// 3. push bagged right hand side root + pub fn gen_proof(&self, mut pos_list: Vec) -> Result> { + if pos_list.is_empty() { + return Err(Error::GenProofForInvalidLeaves); + } + if self.mmr_size == 1 && pos_list == [0] { + return Ok(MerkleProof::new(self.mmr_size, Vec::new())); + } + // ensure positions is sorted + pos_list.sort_unstable(); let peaks = get_peaks(self.mmr_size); - if let Some(rhs_peak_hash) = self.bag_rhs_peaks(peak_pos, &peaks[..])? { - proof.push(rhs_peak_hash); + let mut proof: Vec = Vec::new(); + // generate merkle proof for each peaks + let mut bagging_track = 0; + for peak_pos in peaks { + let pos_list: Vec<_> = take_while_vec(&mut pos_list, |&pos| pos <= peak_pos); + if pos_list.is_empty() { + bagging_track += 1; + } else { + bagging_track = 0; + } + self.gen_proof_for_peak(&mut proof, pos_list, peak_pos)?; } - let lhs_peaks: Vec<_> = peaks - .iter() - .filter(|&&p| p < peak_pos) - .map(|&p| self.batch.get_elem(p)) - .rev() - .collect::>>()? - .ok_or(Error::InconsistentStore)?; - proof.extend(lhs_peaks); + + // ensure no remain positions + if !pos_list.is_empty() { + return Err(Error::GenProofForInvalidLeaves); + } + + if bagging_track > 1 { + let rhs_peaks = proof.split_off(proof.len() - bagging_track); + proof.push(self.bag_rhs_peaks(rhs_peaks)?.expect("bagging rhs peaks")); + } + Ok(MerkleProof::new(self.mmr_size, proof)) } @@ -155,7 +240,7 @@ pub struct MerkleProof { merge: PhantomData, } -impl> MerkleProof { +impl> MerkleProof { pub fn new(mmr_size: u64, proof: Vec) -> Self { MerkleProof { mmr_size, @@ -172,8 +257,8 @@ impl> MerkleProof { &self.proof } - pub fn calculate_root(&self, pos: u64, elem: T) -> Result { - calculate_root::<_, M, _>(pos, elem, self.mmr_size, self.proof.iter()) + pub fn calculate_root(&self, leaves: Vec<(u64, T)>) -> Result { + calculate_root::<_, M, _>(leaves, self.mmr_size, self.proof.iter()) } /// from merkle proof of leaf n to calculate merkle root of n + 1 leaves. @@ -182,100 +267,196 @@ impl> MerkleProof { /// this is kinda tricky, but it works, and useful pub fn calculate_root_with_new_leaf( &self, - pos: u64, - elem: T, + mut leaves: Vec<(u64, T)>, new_pos: u64, new_elem: T, new_mmr_size: u64, ) -> Result { - if self.mmr_size == 0 { - return Ok(elem); - } let pos_height = pos_height_in_tree(new_pos); let next_height = pos_height_in_tree(new_pos + 1); - if next_height > pos_height { - // new elem on right branch - let new_proof = vec![elem]; - let new_proof_iter = new_proof.iter().chain(self.proof.iter()); - calculate_root::<_, M, _>(new_pos, new_elem, new_mmr_size, new_proof_iter) + let mut peaks_hashes = + calculate_peaks_hashes::<_, M, _>(leaves, self.mmr_size, self.proof.iter())?; + let peaks_pos = get_peaks(new_mmr_size); + // reverse touched peaks + let mut i = 0; + while peaks_pos[i] < new_pos { + i += 1 + } + peaks_hashes[i..].reverse(); + calculate_root::<_, M, _>(vec![(new_pos, new_elem)], new_mmr_size, peaks_hashes.iter()) } else { - // new elem on left branch - debug_assert_eq!(self.mmr_size + 1, new_mmr_size); - let peaks = get_peaks(self.mmr_size); - let mut proof_iter = self.proof.iter(); - let (root_elem, _) = - calculate_peak_root::<_, M, _>(pos, &peaks, elem, &mut proof_iter)?; - let new_proof = vec![root_elem]; - let new_proof_iter = new_proof.iter().chain(proof_iter); - calculate_root::<_, M, _>(new_pos, new_elem, new_mmr_size, new_proof_iter) + leaves.push((new_pos, new_elem)); + calculate_root::<_, M, _>(leaves, new_mmr_size, self.proof.iter()) } } - pub fn verify(&self, root: T, pos: u64, elem: T) -> Result { - self.calculate_root(pos, elem) + pub fn verify(&self, root: T, leaves: Vec<(u64, T)>) -> Result { + self.calculate_root(leaves) .map(|calculated_root| calculated_root == root) } } fn calculate_peak_root< 'a, - T: 'a + PartialEq + Debug, + T: 'a + PartialEq + Debug + Clone, M: Merge, I: Iterator, >( - mut pos: u64, - peaks: &[u64], - elem: T, + leaves: Vec<(u64, T)>, + peak_pos: u64, proof_iter: &mut I, -) -> Result<(T, u64)> { - let mut root_elem = elem; - let mut height = 0; - // calculate peak's merkle root - // start bagging peaks if pos reach a peak pos - while peaks.binary_search(&pos).is_err() { - let proof = match proof_iter.next() { - Some(proof) => proof, - None => break, - }; - // verify merkle path - let pos_height = pos_height_in_tree(pos); - let next_height = pos_height_in_tree(pos + 1); - root_elem = if next_height > pos_height { - // to next pos - pos += 1; - M::merge(proof, &root_elem) - } else { - pos += parent_offset(height); - M::merge(&root_elem, proof) - }; - height += 1 +) -> Result { + debug_assert!(!leaves.is_empty(), "can't be empty"); + // tree parent_pos -> sub tree root + let mut tree_buf: BTreeMap = leaves + .into_iter() + .map(|(pos, item)| (pos, (item, 0u32))) + .collect(); + + // calculate tree root from each items + while !tree_buf.is_empty() { + let (pos, _item) = tree_buf.iter().next().unwrap(); + let mut pos = *pos; + let (item, mut height) = tree_buf.remove(&pos).unwrap(); + if pos == peak_pos { + // return root + return Ok(item); + } + let next_pos = tree_buf + .iter() + .next() + .map(|(pos, _item)| *pos) + .unwrap_or(peak_pos); + let mut item = item.clone(); + while pos < peak_pos { + // verify merkle path + let pos_height = pos_height_in_tree(pos); + let next_height = pos_height_in_tree(pos + 1); + let is_right_side = next_height > pos_height; + if is_right_side { + // to next pos + pos += 1; + } else { + pos += parent_offset(height); + } + height += 1; + if pos > next_pos || tree_buf.contains_key(&pos) { + break; + } + let proof = proof_iter.next().ok_or(Error::CorruptedProof)?; + item = if is_right_side { + M::merge(proof, &item) + } else { + M::merge(&item, proof) + }; + } + match tree_buf.entry(pos) { + Entry::Vacant(entry) => { + entry.insert((item, height)); + } + Entry::Occupied(mut entry) => { + // exists a same parent node sibling, merge then update the slot + // note, we are always on right branch since the tree is calculated from left to right + item = M::merge(&entry.get().0, &item); + entry.insert((item, height)); + } + } } - Ok((root_elem, pos)) + Err(Error::CorruptedProof) } -fn calculate_root<'a, T: 'a + PartialEq + Debug, M: Merge, I: Iterator>( - pos: u64, - elem: T, +fn calculate_peaks_hashes< + 'a, + T: 'a + PartialEq + Debug + Clone, + M: Merge, + I: Iterator, +>( + mut leaves: Vec<(u64, T)>, mmr_size: u64, mut proof_iter: I, -) -> Result { - let peaks = get_peaks(mmr_size); - let (mut root_elem, pos) = calculate_peak_root::<_, M, _>(pos, &peaks, elem, &mut proof_iter)?; - - // bagging peaks - // bagging with left peaks if pos is last peak(last pos) - let mut bagging_left = pos == mmr_size - 1; - for proof in &mut proof_iter { - root_elem = if bagging_left { - M::merge(&root_elem, &proof) - } else { - // we are not in the last peak, so bag with right peaks first - // notice the right peaks is already bagging into one hash in proof, - // so after this merge, the remain proofs are always left peaks. - bagging_left = true; - M::merge(&proof, &root_elem) - }; +) -> Result> { + // special handle the only 1 leaf MMR + if mmr_size == 1 && leaves.len() == 1 && leaves[0].0 == 0 { + return Ok(leaves.into_iter().map(|(_pos, item)| item).collect()); } - Ok(root_elem) + // sort items by position + leaves.sort_by_key(|(pos, _)| *pos); + let peaks = get_peaks(mmr_size); + + let mut peaks_hashes: Vec = Vec::with_capacity(peaks.len() + 1); + for peak_pos in peaks { + let mut leaves: Vec<_> = take_while_vec(&mut leaves, |(pos, _)| *pos <= peak_pos); + let peak_root = if leaves.len() == 1 && leaves[0].0 == peak_pos { + // leaf is the peak + leaves.remove(0).1 + } else if leaves.is_empty() { + // if empty, means the next proof is a peak root or rhs bagged root + if let Some(peak_root) = proof_iter.next() { + peak_root.clone() + } else { + // means that either all right peaks are bagged, or proof is corrupted + // so we break loop and check no items left + break; + } + } else { + calculate_peak_root::<_, M, _>(leaves, peak_pos, &mut proof_iter)? + }; + peaks_hashes.push(peak_root.clone()); + } + + // ensure nothing left in leaves + if !leaves.is_empty() { + return Err(Error::CorruptedProof); + } + + // check rhs peaks + if let Some(rhs_peaks_hashes) = proof_iter.next() { + peaks_hashes.push(rhs_peaks_hashes.clone()); + } + // ensure nothing left in proof_iter + if proof_iter.next().is_some() { + return Err(Error::CorruptedProof); + } + Ok(peaks_hashes) +} + +fn bagging_peaks_hashes<'a, T: 'a + PartialEq + Debug + Clone, M: Merge>( + mut peaks_hashes: Vec, +) -> Result { + // bagging peaks + // bagging from right to left via hash(right, left). + while peaks_hashes.len() > 1 { + let right_peak = peaks_hashes.pop().expect("pop"); + let left_peak = peaks_hashes.pop().expect("pop"); + peaks_hashes.push(M::merge(&right_peak, &left_peak)); + } + peaks_hashes.pop().ok_or(Error::CorruptedProof) +} + +/// merkle proof +/// 1. sort items by position +/// 2. calculate root of each peak +/// 3. bagging peaks +fn calculate_root< + 'a, + T: 'a + PartialEq + Debug + Clone, + M: Merge, + I: Iterator, +>( + leaves: Vec<(u64, T)>, + mmr_size: u64, + proof_iter: I, +) -> Result { + let peaks_hashes = calculate_peaks_hashes::<_, M, _>(leaves, mmr_size, proof_iter)?; + bagging_peaks_hashes::<_, M>(peaks_hashes) +} + +fn take_while_vec bool>(v: &mut Vec, p: P) -> Vec { + for i in 0..v.len() { + if !p(&v[i]) { + return v.drain(..i).collect(); + } + } + v.drain(..).collect() } diff --git a/src/tests/test_accumulate_headers.rs b/src/tests/test_accumulate_headers.rs index ddbc007..372eae7 100644 --- a/src/tests/test_accumulate_headers.rs +++ b/src/tests/test_accumulate_headers.rs @@ -156,7 +156,7 @@ impl Prover { mmr.get_root()?.serialize(), self.headers[later_number as usize].0.chain_root ); - mmr.gen_proof(pos) + mmr.gen_proof(vec![pos]) } fn get_pos(&self, number: u64) -> u64 { @@ -188,6 +188,6 @@ fn test_insert_header() { let pos = leaf_index_to_pos(h1); assert_eq!(pos, prover.get_pos(h1)); assert_eq!(prove_elem, (&prover.store).get_elem(pos).unwrap().unwrap()); - let result = proof.verify(root, pos, prove_elem).expect("verify"); + let result = proof.verify(root, vec![(pos, prove_elem)]).expect("verify"); assert!(result); } diff --git a/src/tests/test_mmr.rs b/src/tests/test_mmr.rs index 4f33b04..2763fd3 100644 --- a/src/tests/test_mmr.rs +++ b/src/tests/test_mmr.rs @@ -2,8 +2,9 @@ use super::{MergeNumberHash, NumberHash}; use crate::{leaf_index_to_mmr_size, util::MemStore, Error, MMR}; use faster_hex::hex_string; use proptest::prelude::*; +use rand::{seq::SliceRandom, thread_rng}; -fn test_mmr(count: u32, proof_elem: u32) { +fn test_mmr(count: u32, proof_elem: Vec) { let store = MemStore::default(); let mut mmr = MMR::<_, MergeNumberHash, _>::new(0, &store); let positions: Vec = (0u32..count) @@ -11,14 +12,21 @@ fn test_mmr(count: u32, proof_elem: u32) { .collect(); let root = mmr.get_root().expect("get root"); let proof = mmr - .gen_proof(positions[proof_elem as usize]) + .gen_proof( + proof_elem + .iter() + .map(|elem| positions[*elem as usize]) + .collect(), + ) .expect("gen proof"); mmr.commit().expect("commit changes"); let result = proof .verify( root, - positions[proof_elem as usize], - NumberHash::from(proof_elem), + proof_elem + .iter() + .map(|elem| (positions[*elem as usize], NumberHash::from(*elem))) + .collect(), ) .unwrap(); assert!(result); @@ -32,15 +40,14 @@ fn test_gen_new_root_from_proof(count: u32) { .collect(); let elem = count - 1; let pos = positions[elem as usize]; - let proof = mmr.gen_proof(pos).expect("gen proof"); + let proof = mmr.gen_proof(vec![pos]).expect("gen proof"); let new_elem = count; let new_pos = mmr.push(NumberHash::from(new_elem)).unwrap(); let root = mmr.get_root().expect("get root"); mmr.commit().expect("commit changes"); let calculated_root = proof .calculate_root_with_new_leaf( - pos, - NumberHash::from(elem), + vec![(pos, NumberHash::from(elem))], new_pos, NumberHash::from(new_elem), leaf_index_to_mmr_size(new_elem.into()), @@ -73,38 +80,64 @@ fn test_empty_mmr_root() { #[test] fn test_mmr_3_peaks() { - test_mmr(11, 5); + test_mmr(11, vec![5]); } #[test] fn test_mmr_2_peaks() { - test_mmr(10, 5); + test_mmr(10, vec![5]); } #[test] fn test_mmr_1_peak() { - test_mmr(8, 5); + test_mmr(8, vec![5]); } #[test] fn test_mmr_first_elem_proof() { - test_mmr(11, 0); + test_mmr(11, vec![0]); } #[test] fn test_mmr_last_elem_proof() { - test_mmr(11, 10); + test_mmr(11, vec![10]); } #[test] fn test_mmr_1_elem() { - test_mmr(1, 0); + test_mmr(1, vec![0]); } #[test] fn test_mmr_2_elems() { - test_mmr(2, 0); - test_mmr(2, 1); + test_mmr(2, vec![0]); + test_mmr(2, vec![1]); +} + +#[test] +fn test_mmr_2_leaves_merkle_proof() { + test_mmr(11, vec![3, 7]); + test_mmr(11, vec![3, 4]); +} + +#[test] +fn test_mmr_2_sibling_leaves_merkle_proof() { + test_mmr(11, vec![4, 5]); + test_mmr(11, vec![5, 6]); + test_mmr(11, vec![6, 7]); +} + +#[test] +fn test_mmr_3_leaves_merkle_proof() { + test_mmr(11, vec![4, 5, 6]); + test_mmr(11, vec![3, 5, 7]); + test_mmr(11, vec![3, 4, 5]); + test_mmr(100, vec![3, 5, 13]); +} + +#[test] +fn test_gen_root_from_proof() { + test_gen_new_root_from_proof(11); } prop_compose! { @@ -117,8 +150,13 @@ prop_compose! { proptest! { #[test] - fn test_random_mmr((count , elem) in count_elem(500)) { - test_mmr(count, elem); + fn test_random_mmr(count in 10u32..500u32) { + let mut leaves: Vec = (0..count).collect(); + let mut rng = thread_rng(); + leaves.shuffle(&mut rng); + let leaves_count = rng.gen_range(1, count - 1); + leaves.truncate(leaves_count as usize); + test_mmr(count, leaves); } #[test] diff --git a/src/util.rs b/src/util.rs index f174cb5..96442b5 100644 --- a/src/util.rs +++ b/src/util.rs @@ -71,8 +71,8 @@ impl> MemMMR { Ok(pos) } - pub fn gen_proof(&self, pos: u64) -> Result> { + pub fn gen_proof(&self, pos_list: Vec) -> Result> { let mmr = MMR::>::new(self.mmr_size, &self.store); - mmr.gen_proof(pos) + mmr.gen_proof(pos_list) } }