diff --git a/src/merge.rs b/src/merge.rs index 8cbbfd3..a4d0cbb 100644 --- a/src/merge.rs +++ b/src/merge.rs @@ -1,4 +1,9 @@ pub trait Merge { type Item; + fn merge(left: &Self::Item, right: &Self::Item) -> Self::Item; + + fn merge_peaks(peak1: &Self::Item, peak2: &Self::Item) -> Self::Item { + Self::merge(peak1, peak2) + } } diff --git a/src/mmr.rs b/src/mmr.rs index fd4bfc4..8537f91 100644 --- a/src/mmr.rs +++ b/src/mmr.rs @@ -96,7 +96,7 @@ impl<'a, T: Clone + PartialEq + Debug, M: Merge, S: MMRStore> MMR 1 { let right_peak = rhs_peaks.pop().expect("pop"); let left_peak = rhs_peaks.pop().expect("pop"); - rhs_peaks.push(M::merge(&right_peak, &left_peak)); + rhs_peaks.push(M::merge_peaks(&right_peak, &left_peak)); } Ok(rhs_peaks.pop()) } @@ -394,7 +394,7 @@ fn bagging_peaks_hashes<'a, T: 'a + PartialEq + Debug + Clone, M: Merge 1 { let right_peak = peaks_hashes.pop().expect("pop"); let left_peak = peaks_hashes.pop().expect("pop"); - peaks_hashes.push(M::merge(&right_peak, &left_peak)); + peaks_hashes.push(M::merge_peaks(&right_peak, &left_peak)); } peaks_hashes.pop().ok_or(Error::CorruptedProof) } diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 221e645..276a996 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,6 +1,7 @@ mod test_accumulate_headers; mod test_helper; mod test_mmr; +mod test_sequence; use crate::Merge; use blake2b_rs::{Blake2b, Blake2bBuilder}; diff --git a/src/tests/test_sequence.rs b/src/tests/test_sequence.rs new file mode 100644 index 0000000..1d69c07 --- /dev/null +++ b/src/tests/test_sequence.rs @@ -0,0 +1,98 @@ +use std::fmt; + +use proptest::proptest; +use rand::{prelude::*, thread_rng}; + +use crate::{util::MemStore, Merge, MMR}; + +#[derive(Eq, PartialEq, Clone, Default)] +struct NumberRange { + start: u32, + end: u32, +} + +struct MergeNumberRange; + +impl fmt::Debug for NumberRange { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "NumberRange({}, {})", self.start, self.end) + } +} + +impl fmt::Debug for MergeNumberRange { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MergeNumberRange") + } +} + +impl From for NumberRange { + fn from(num: u32) -> Self { + Self { + start: num, + end: num, + } + } +} + +impl NumberRange { + fn is_normalized(&self) -> bool { + self.start <= self.end + } +} + +impl Merge for MergeNumberRange { + type Item = NumberRange; + fn merge(lhs: &Self::Item, rhs: &Self::Item) -> Self::Item { + Self::Item { + start: lhs.start, + end: rhs.end, + } + } + fn merge_peaks(lhs: &Self::Item, rhs: &Self::Item) -> Self::Item { + Self::merge(rhs, lhs) + } +} + +fn test_sequence_sub_func(count: u32, proof_elem: Vec) { + let store = MemStore::default(); + let mut mmr = MMR::<_, MergeNumberRange, _>::new(0, &store); + let positions = (0..count) + .map(|i| mmr.push(NumberRange::from(i)).expect("push")) + .collect::>(); + let root = mmr.get_root().expect("get_root"); + assert!(root.is_normalized()); + let proof = mmr + .gen_proof( + proof_elem + .iter() + .map(|elem| positions[*elem as usize]) + .collect(), + ) + .expect("gen_proof"); + for item in proof.proof_items() { + assert!(item.is_normalized()) + } + mmr.commit().expect("commit"); + let result = proof + .verify( + root, + proof_elem + .iter() + .map(|elem| (positions[*elem as usize], NumberRange::from(*elem))) + .collect(), + ) + .expect("verify"); + assert!(result); +} + +proptest! { + #[test] + fn test_sequence(count in 10u32..500u32) { + let mut leaves: Vec = (0..count).collect(); + let mut rng = thread_rng(); + leaves.shuffle(&mut rng); + let leaves_count = rng.gen_range(1, count - 1); + leaves.truncate(leaves_count as usize); + test_sequence_sub_func(count, leaves); + } +}