commit c1211d2ca9a23ab005a43a26cc65454c20d2665c Author: jjy Date: Thu Sep 19 16:23:58 2019 +0800 Init merkle-mountain-range The MMR code is copy from https://github.com/nervosnetwork/ckb/pull/1598 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1e7caa9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +Cargo.lock +target/ diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..706fe8d --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "ckb-merkle-mountain-range" +version = "0.1.0" +authors = ["Nervos Core Dev "] +edition = "2018" +license = "MIT" +description = "A generalized merkle mountain range implementation" + +[dependencies] +failure = "0.1.5" + +[dev-dependencies] +faster-hex = "0.3" +criterion = "0.3" +rand = "0.6" +proptest = "0.9.4" +lazy_static = "1.3.0" +bytes = "0.4" +blake2b-rs = "0.1.4" + +[[bench]] +name = "mmr_benchmark" +harness = false + +[[bench]] +name = "helper_benchmark" +harness = false diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..19a3ffa --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +default: ci + +ci: fmt clippy test bench-test + +test: + cargo test --all --all-features + +bench-test: + cargo bench -- --test + +clippy: + cargo clippy --all --all-features --all-targets + +fmt: + cargo fmt --all -- --check diff --git a/README.md b/README.md new file mode 100644 index 0000000..8fd0ab2 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +# Merkle mountain range + +A generalized merkle mountain range implementation. + +## References + +* https://github.com/mimblewimble/grin/blob/master/doc/mmr.md#structure +* https://github.com/opentimestamps/opentimestamps-server/blob/master/doc/merkle-mountain-range.md diff --git a/benches/helper_benchmark.rs b/benches/helper_benchmark.rs new file mode 100644 index 0000000..b45b1c6 --- /dev/null +++ b/benches/helper_benchmark.rs @@ -0,0 +1,28 @@ +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +use ckb_merkle_mountain_range::{leaf_index_to_mmr_size, leaf_index_to_pos}; + +use rand::{thread_rng, Rng}; + +fn bench(c: &mut Criterion) { + c.bench_function("left_index_to_pos", |b| { + let mut rng = thread_rng(); + b.iter(|| { + let leaf_index = rng.gen_range(50_000_000_000, 70_000_000_000); + leaf_index_to_pos(leaf_index); + }); + }); + + c.bench_function("left_index_to_mmr_size", |b| { + let mut rng = thread_rng(); + b.iter(|| { + let leaf_index = rng.gen_range(50_000_000_000, 70_000_000_000); + leaf_index_to_mmr_size(leaf_index); + }); + }); +} + +criterion_group!(benches, bench); +criterion_main!(benches); diff --git a/benches/mmr_benchmark.rs b/benches/mmr_benchmark.rs new file mode 100644 index 0000000..ba4447e --- /dev/null +++ b/benches/mmr_benchmark.rs @@ -0,0 +1,96 @@ +#[macro_use] +extern crate criterion; + +use criterion::Criterion; + +use bytes::Bytes; +use ckb_merkle_mountain_range::{util::MemStore, Error, MMRStore, Merge, Result, MMR}; +use rand::{seq::SliceRandom, thread_rng}; +use std::convert::TryFrom; + +use blake2b_rs::{Blake2b, Blake2bBuilder}; + +fn new_blake2b() -> Blake2b { + Blake2bBuilder::new(32).build() +} + +#[derive(Eq, PartialEq, Clone, Debug, Default)] +struct NumberHash(pub Bytes); +impl TryFrom for NumberHash { + type Error = Error; + fn try_from(num: u32) -> Result { + let mut hasher = new_blake2b(); + let mut hash = [0u8; 32]; + hasher.update(&num.to_le_bytes()); + hasher.finalize(&mut hash); + Ok(NumberHash(hash.to_vec().into())) + } +} + +struct MergeNumberHash; + +impl Merge for MergeNumberHash { + type Item = NumberHash; + fn merge(lhs: &Self::Item, rhs: &Self::Item) -> Self::Item { + let mut hasher = new_blake2b(); + let mut hash = [0u8; 32]; + hasher.update(&lhs.0); + hasher.update(&rhs.0); + hasher.finalize(&mut hash); + NumberHash(hash.to_vec().into()) + } +} + +fn prepare_mmr(count: u32) -> (u64, MemStore, Vec) { + let store = MemStore::default(); + let mut mmr = MMR::<_, MergeNumberHash, _>::new(0, &store); + let positions: Vec = (0u32..count) + .map(|i| mmr.push(NumberHash::try_from(i).unwrap()).unwrap()) + .collect(); + let mmr_size = mmr.mmr_size(); + mmr.commit().expect("write to store"); + (mmr_size, store, positions) +} + +fn bench(c: &mut Criterion) { + c.bench_function_over_inputs( + "MMR insert", + |b, &&size| { + b.iter(|| prepare_mmr(size)); + }, + &[10_000, 100_000, 100_0000], + ); + + c.bench_function("MMR gen proof", |b| { + let (mmr_size, store, positions) = prepare_mmr(100_0000); + let mmr = MMR::<_, MergeNumberHash, _>::new(mmr_size, &store); + let mut rng = thread_rng(); + b.iter(|| mmr.gen_proof(*positions.choose(&mut rng).unwrap())); + }); + + c.bench_function("MMR verify", |b| { + let (mmr_size, store, positions) = prepare_mmr(100_0000); + let mmr = MMR::<_, MergeNumberHash, _>::new(mmr_size, &store); + let mut rng = thread_rng(); + let root: NumberHash = mmr.get_root().unwrap(); + let proofs: Vec<_> = (0..10_000) + .map(|_| { + let pos = positions.choose(&mut rng).unwrap(); + let elem = (&store).get_elem(*pos).unwrap().unwrap(); + let proof = mmr.gen_proof(*pos).unwrap(); + (pos, elem, proof) + }) + .collect(); + b.iter(|| { + let (pos, elem, proof) = proofs.choose(&mut rng).unwrap(); + proof.verify(root.clone(), **pos, elem.clone()).unwrap(); + }); + }); +} + +criterion_group!( + name = benches; + config = Criterion::default().sample_size(20); + targets = bench +); +criterion_main!(benches); diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..a2c0d1f --- /dev/null +++ b/src/error.rs @@ -0,0 +1,12 @@ +pub use failure::Fail; +pub type Result = ::std::result::Result; + +#[derive(Fail, Debug, PartialEq, Eq, Clone)] +pub enum Error { + #[fail(display = "Get root on an empty MMR")] + GetRootOnEmpty, + #[fail(display = "Inconsistent store")] + InconsistentStore, + #[fail(display = "Store error {}", _0)] + StoreError(String), +} diff --git a/src/helper.rs b/src/helper.rs new file mode 100644 index 0000000..c8dfff5 --- /dev/null +++ b/src/helper.rs @@ -0,0 +1,111 @@ +pub fn leaf_index_to_pos(index: u64) -> u64 { + if index == 0 { + return 0; + } + // leaf_count + let mut leaves = index + 1; + let mut tree_node_count = 0; + let mut height = 0u32; + while leaves > 1 { + // get heighest peak height + height = (leaves as f64).log2() as u32; + // calculate leaves in peak + let peak_leaves = 1 << height; + // heighest positon + let sub_tree_node_count = get_peak_pos_by_height(height) + 1; + tree_node_count += sub_tree_node_count; + leaves -= peak_leaves; + } + // two leaves can construct a new peak, the only valid number of leaves is 0 or 1. + debug_assert!(leaves == 0 || leaves == 1, "remain leaves incorrect"); + if leaves == 1 { + // add one pos for remain leaf + // equals to `tree_node_count - 1 + 1` + tree_node_count + } else { + let pos = tree_node_count - 1; + pos - u64::from(height) + } +} + +// TODO optimize +pub fn leaf_index_to_mmr_size(index: u64) -> u64 { + let mut pos = leaf_index_to_pos(index); + while pos_height_in_tree(pos + 1) > pos_height_in_tree(pos) { + pos += 1 + } + pos + 1 +} + +pub fn pos_height_in_tree(mut pos: u64) -> u32 { + pos += 1; + fn all_ones(num: u64) -> bool { + num != 0 && num.count_zeros() == num.leading_zeros() + } + fn jump_left(pos: u64) -> u64 { + let bit_length = 64 - pos.leading_zeros(); + let most_significant_bits = 1 << (bit_length - 1); + pos - (most_significant_bits - 1) + } + + while !all_ones(pos) { + pos = jump_left(pos) + } + + 64 - pos.leading_zeros() - 1 +} + +pub fn parent_offset(height: u32) -> u64 { + 2 << height +} + +pub fn sibling_offset(height: u32) -> u64 { + (2 << height) - 1 +} + +pub fn get_peaks(mmr_size: u64) -> Vec { + let mut pos_s = Vec::new(); + let (mut height, mut pos) = left_peak_height_pos(mmr_size); + pos_s.push(pos); + while height > 0 { + let peak = match get_right_peak(height, pos, mmr_size) { + Some(peak) => peak, + None => break, + }; + height = peak.0; + pos = peak.1; + pos_s.push(pos); + } + pos_s +} + +fn get_right_peak(mut height: u32, mut pos: u64, mmr_size: u64) -> Option<(u32, u64)> { + // move to right sibling pos + pos += sibling_offset(height); + // loop until we find a pos in mmr + while pos > mmr_size - 1 { + if height == 0 { + return None; + } + // move to left child + pos -= parent_offset(height - 1); + height -= 1; + } + Some((height, pos)) +} + +fn get_peak_pos_by_height(height: u32) -> u64 { + (1 << (height + 1)) - 2 +} + +fn left_peak_height_pos(mmr_size: u64) -> (u32, u64) { + let mut height = 1; + let mut prev_pos = 0; + let mut pos = get_peak_pos_by_height(height); + while pos < mmr_size { + height += 1; + prev_pos = pos; + pos = get_peak_pos_by_height(height); + } + (height - 1, prev_pos) +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..a36e787 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,14 @@ +mod error; +mod helper; +mod merge; +mod mmr; +mod mmr_store; +#[cfg(test)] +mod tests; +pub mod util; + +pub use error::{Error, Result}; +pub use helper::{leaf_index_to_mmr_size, leaf_index_to_pos}; +pub use merge::Merge; +pub use mmr::{MerkleProof, MMR}; +pub use mmr_store::MMRStore; diff --git a/src/merge.rs b/src/merge.rs new file mode 100644 index 0000000..8cbbfd3 --- /dev/null +++ b/src/merge.rs @@ -0,0 +1,4 @@ +pub trait Merge { + type Item; + fn merge(left: &Self::Item, right: &Self::Item) -> Self::Item; +} diff --git a/src/mmr.rs b/src/mmr.rs new file mode 100644 index 0000000..3fa80fb --- /dev/null +++ b/src/mmr.rs @@ -0,0 +1,207 @@ +//! Merkle Mountain Range +//! +//! references: +//! https://github.com/mimblewimble/grin/blob/master/doc/mmr.md#structure +//! https://github.com/mimblewimble/grin/blob/0ff6763ee64e5a14e70ddd4642b99789a1648a32/core/src/core/pmmr.rs#L606 + +use crate::helper::{get_peaks, parent_offset, pos_height_in_tree, sibling_offset}; +use crate::mmr_store::{MMRBatch, MMRStore}; +use crate::{Error, Merge, Result}; +use std::borrow::Cow; +use std::fmt::Debug; +use std::marker::PhantomData; + +pub struct MMR> { + mmr_size: u64, + batch: MMRBatch, + merge: PhantomData, +} + +impl<'a, T: Clone + PartialEq + Debug, M: Merge, S: MMRStore> MMR { + pub fn new(mmr_size: u64, store: S) -> Self { + MMR { + mmr_size, + batch: MMRBatch::new(store), + merge: PhantomData, + } + } + + // find internal MMR elem, the pos must exists, otherwise a error will return + fn find_elem<'b>(&self, pos: u64, hashes: &'b [T]) -> Result> { + let pos_offset = pos.checked_sub(self.mmr_size); + if let Some(elem) = pos_offset.and_then(|i| hashes.get(i as usize)) { + return Ok(Cow::Borrowed(elem)); + } + let elem = self.batch.get_elem(pos)?.ok_or(Error::InconsistentStore)?; + Ok(Cow::Owned(elem)) + } + + pub fn mmr_size(&self) -> u64 { + self.mmr_size + } + + pub fn is_empty(&self) -> bool { + self.mmr_size == 0 + } + + // push a element and return position + pub fn push(&mut self, elem: T) -> Result { + let mut elems: Vec = Vec::new(); + // position of new elem + let elem_pos = self.mmr_size; + elems.push(elem); + let mut height = 0u32; + let mut pos = elem_pos; + // continue to merge tree node if next pos heigher than current + while pos_height_in_tree(pos + 1) > height { + pos += 1; + let left_pos = pos - parent_offset(height); + let right_pos = left_pos + sibling_offset(height); + let left_elem = self.find_elem(left_pos, &elems)?; + let right_elem = self.find_elem(right_pos, &elems)?; + let parent_elem = M::merge(&left_elem, &right_elem); + elems.push(parent_elem); + height += 1 + } + // store hashes + self.batch.append(elem_pos, elems); + // update mmr_size + self.mmr_size = pos + 1; + Ok(elem_pos) + } + + /// get_root + pub fn get_root(&self) -> Result { + if self.mmr_size == 0 { + return Err(Error::GetRootOnEmpty); + } else if self.mmr_size == 1 { + return self.batch.get_elem(0)?.ok_or(Error::InconsistentStore); + } + let peaks = get_peaks(self.mmr_size); + self.bag_rhs_peaks(0, &peaks)? + .ok_or(Error::InconsistentStore) + } + + fn bag_rhs_peaks(&self, skip_peak_pos: u64, peaks: &[u64]) -> Result> { + let mut rhs_peak_elems: Vec = peaks + .iter() + .filter(|&&p| p > skip_peak_pos) + .map(|&p| self.batch.get_elem(p)) + .collect::>>()? + .ok_or(Error::InconsistentStore)?; + while rhs_peak_elems.len() > 1 { + let right_peak = rhs_peak_elems.pop().expect("pop"); + let left_peak = rhs_peak_elems.pop().expect("pop"); + rhs_peak_elems.push(M::merge(&right_peak, &left_peak)); + } + Ok(rhs_peak_elems.pop()) + } + + pub fn gen_proof(&self, mut pos: u64) -> Result> { + let mut proof: Vec = Vec::new(); + let mut height = 0; + while pos < self.mmr_size { + let pos_height = pos_height_in_tree(pos); + let next_height = pos_height_in_tree(pos + 1); + let (sib_pos, next_pos) = if next_height > pos_height { + // implies pos is right sibling + let sib_pos = pos - sibling_offset(height); + (sib_pos, pos + 1) + } else { + // pos is left sibling + let sib_pos = pos + sibling_offset(height); + (sib_pos, pos + parent_offset(height)) + }; + if sib_pos > self.mmr_size - 1 { + break; + } + proof.push( + self.batch + .get_elem(sib_pos)? + .ok_or(Error::InconsistentStore)?, + ); + pos = next_pos; + height += 1; + } + // now we get peak merkle proof + let peak_pos = pos; + // calculate bagging proof + let peaks = get_peaks(self.mmr_size); + if let Some(rhs_peak_hash) = self.bag_rhs_peaks(peak_pos, &peaks[..])? { + proof.push(rhs_peak_hash); + } + let lhs_peaks: Vec<_> = peaks + .iter() + .filter(|&&p| p < peak_pos) + .map(|&p| self.batch.get_elem(p)) + .rev() + .collect::>>()? + .ok_or(Error::InconsistentStore)?; + proof.extend(lhs_peaks); + Ok(MerkleProof::new(self.mmr_size, proof)) + } + + pub fn commit(self) -> Result<()> { + self.batch.commit() + } +} + +#[derive(Debug)] +pub struct MerkleProof { + mmr_size: u64, + proof: Vec, + merge: PhantomData, +} + +impl> MerkleProof { + pub fn new(mmr_size: u64, proof: Vec) -> Self { + MerkleProof { + mmr_size, + proof, + merge: PhantomData, + } + } + + pub fn verify(&self, root: T, mut pos: u64, elem: T) -> Result { + let peaks = get_peaks(self.mmr_size); + let mut sum_elem = elem; + let mut height = 0; + let mut proof_iter = self.proof.iter(); + // calculate peak's merkle root + // start bagging peaks if pos reach a peak pos + while !peaks.contains(&pos) { + let proof = match proof_iter.next() { + Some(proof) => proof, + None => break, + }; + // verify merkle path + let pos_height = pos_height_in_tree(pos); + let next_height = pos_height_in_tree(pos + 1); + sum_elem = if next_height > pos_height { + // to next pos + pos += 1; + M::merge(proof, &sum_elem) + } else { + pos += parent_offset(height); + M::merge(&sum_elem, proof) + }; + height += 1 + } + + // bagging peaks + // bagging with left peaks if pos is last peak + let mut bagging_left = Some(&pos) == peaks.last(); + for proof in &mut proof_iter { + sum_elem = if bagging_left { + M::merge(&sum_elem, &proof) + } else { + // we are not in the last peak, so bag with right peaks first + // notice the right peaks is already bagging into one hash in proof, + // so after this merge, the remain proofs are always left peaks. + bagging_left = true; + M::merge(&proof, &sum_elem) + }; + } + Ok(root == sum_elem) + } +} diff --git a/src/mmr_store.rs b/src/mmr_store.rs new file mode 100644 index 0000000..2114ad6 --- /dev/null +++ b/src/mmr_store.rs @@ -0,0 +1,58 @@ +use crate::Result; + +#[derive(Default)] +pub struct MMRBatch> { + memory_batch: Vec<(u64, Vec)>, + store: Store, +} + +impl> MMRBatch { + pub fn new(store: Store) -> Self { + MMRBatch { + memory_batch: Vec::new(), + store, + } + } + + pub fn append(&mut self, pos: u64, elems: Vec) { + self.memory_batch.push((pos, elems)); + } + + pub fn get_elem(&self, pos: u64) -> Result> { + for (start_pos, elems) in self.memory_batch.iter().rev() { + if pos < *start_pos { + continue; + } else if pos < start_pos + elems.len() as u64 { + return Ok(elems.get((pos - start_pos) as usize).cloned()); + } else { + break; + } + } + self.store.get_elem(pos) + } + + pub fn commit(self) -> Result<()> { + let Self { + mut store, + memory_batch, + } = self; + for (pos, elems) in memory_batch { + store.append(pos, elems)?; + } + Ok(()) + } +} + +impl> IntoIterator for MMRBatch { + type Item = (u64, Vec); + type IntoIter = ::std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.memory_batch.into_iter() + } +} + +pub trait MMRStore { + fn get_elem(&self, pos: u64) -> Result>; + fn append(&mut self, pos: u64, elems: Vec) -> Result<()>; +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs new file mode 100644 index 0000000..221e645 --- /dev/null +++ b/src/tests/mod.rs @@ -0,0 +1,37 @@ +mod test_accumulate_headers; +mod test_helper; +mod test_mmr; + +use crate::Merge; +use blake2b_rs::{Blake2b, Blake2bBuilder}; +use bytes::Bytes; + +fn new_blake2b() -> Blake2b { + Blake2bBuilder::new(32).build() +} + +#[derive(Eq, PartialEq, Clone, Debug, Default)] +struct NumberHash(pub Bytes); +impl From for NumberHash { + fn from(num: u32) -> Self { + let mut hasher = new_blake2b(); + let mut hash = [0u8; 32]; + hasher.update(&num.to_le_bytes()); + hasher.finalize(&mut hash); + NumberHash(hash.to_vec().into()) + } +} + +struct MergeNumberHash; + +impl Merge for MergeNumberHash { + type Item = NumberHash; + fn merge(lhs: &Self::Item, rhs: &Self::Item) -> Self::Item { + let mut hasher = new_blake2b(); + let mut hash = [0u8; 32]; + hasher.update(&lhs.0); + hasher.update(&rhs.0); + hasher.finalize(&mut hash); + NumberHash(hash.to_vec().into()) + } +} diff --git a/src/tests/test_accumulate_headers.rs b/src/tests/test_accumulate_headers.rs new file mode 100644 index 0000000..ddbc007 --- /dev/null +++ b/src/tests/test_accumulate_headers.rs @@ -0,0 +1,193 @@ +use super::new_blake2b; +use crate::{leaf_index_to_pos, util::MemStore, MMRStore, Merge, MerkleProof, Result, MMR}; +use bytes::Bytes; +use std::fmt::{self, Debug}; + +#[derive(Clone)] +struct Header { + number: u64, + parent_hash: Bytes, + difficulty: u64, + // MMR root + chain_root: Bytes, +} + +impl Header { + fn default() -> Self { + Header { + number: 0, + parent_hash: vec![0; 32].into(), + difficulty: 0, + chain_root: vec![0; 32].into(), + } + } + + fn hash(&self) -> Bytes { + let mut hasher = new_blake2b(); + let mut hash = [0u8; 32]; + hasher.update(&self.number.to_le_bytes()); + hasher.update(&self.parent_hash); + hasher.update(&self.difficulty.to_le_bytes()); + hasher.update(&self.chain_root); + hasher.finalize(&mut hash); + hash.to_vec().into() + } +} + +#[derive(Eq, PartialEq, Clone, Default)] +struct HashWithTD { + hash: Bytes, + td: u64, +} + +impl HashWithTD { + fn serialize(&self) -> Bytes { + let mut data = self.hash.clone(); + data.extend(&self.td.to_le_bytes()); + data + } + + fn deserialize(mut data: Bytes) -> Self { + assert_eq!(data.len(), 40); + let mut td_bytes = [0u8; 8]; + td_bytes.copy_from_slice(&data[32..]); + let td = u64::from_le_bytes(td_bytes); + data.truncate(32); + HashWithTD { hash: data, td } + } +} + +impl Debug for HashWithTD { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "HashWithTD {{ hash: {}, td: {} }}", + faster_hex::hex_string(&self.hash).unwrap(), + self.td + ) + } +} + +struct MergeHashWithTD; + +impl Merge for MergeHashWithTD { + type Item = HashWithTD; + fn merge(lhs: &Self::Item, rhs: &Self::Item) -> Self::Item { + let mut hasher = new_blake2b(); + let mut hash = [0u8; 32]; + hasher.update(&lhs.serialize()); + hasher.update(&rhs.serialize()); + hasher.finalize(&mut hash); + let td = lhs.td + rhs.td; + HashWithTD { + hash: hash.to_vec().into(), + td, + } + } +} + +struct Prover { + headers: Vec<(Header, u64)>, + positions: Vec, + store: MemStore, +} + +impl Prover { + fn new() -> Prover { + let store = MemStore::default(); + Prover { + headers: Vec::new(), + positions: Vec::new(), + store, + } + } + + fn gen_blocks(&mut self, count: u64) -> Result<()> { + let mut mmr = MMR::<_, MergeHashWithTD, _>::new(self.positions.len() as u64, &self.store); + // get previous element + let mut previous = if let Some(pos) = self.positions.last() { + MMRStore::<_>::get_elem(&&self.store, *pos)?.expect("exists") + } else { + let genesis = Header::default(); + + let previous = HashWithTD { + hash: genesis.hash(), + td: genesis.difficulty, + }; + self.headers.push((genesis, previous.td)); + let pos = mmr.push(previous.clone())?; + self.positions.push(pos); + previous + }; + let last_number = self.headers.last().unwrap().0.number; + for i in (last_number + 1)..=(last_number + count) { + let block = Header { + number: i, + parent_hash: previous.hash.clone(), + difficulty: i, + chain_root: mmr.get_root()?.serialize(), + }; + previous = HashWithTD { + hash: block.hash(), + td: block.difficulty, + }; + let pos = mmr.push(previous.clone())?; + self.positions.push(pos); + self.headers.push((block, previous.td)); + } + mmr.commit() + } + + fn get_header(&self, number: u64) -> (Header, u64) { + self.headers[number as usize].clone() + } + + // generate proof that headers are in same chain + fn gen_proof( + &mut self, + number: u64, + later_number: u64, + ) -> Result> { + assert!(number < later_number); + let pos = self.positions[number as usize]; + let later_pos = self.positions[later_number as usize]; + let mmr = MMR::new(later_pos, &self.store); + assert_eq!( + mmr.get_root()?.serialize(), + self.headers[later_number as usize].0.chain_root + ); + mmr.gen_proof(pos) + } + + fn get_pos(&self, number: u64) -> u64 { + self.positions[number as usize] + } +} + +#[test] +fn test_insert_header() { + let mut prover = Prover::new(); + prover.gen_blocks(30).expect("gen blocks"); + let h1 = 11; + let h2 = 19; + + // get headers from prover + let prove_elem = { + let (header, td) = prover.get_header(h1); + HashWithTD { + hash: header.hash(), + td, + } + }; + let root = { + let (later_header, _later_td) = prover.get_header(h2); + HashWithTD::deserialize(later_header.chain_root) + }; + // gen proof, blocks are in the same chain + let proof = prover.gen_proof(h1, h2).expect("gen proof"); + let pos = leaf_index_to_pos(h1); + assert_eq!(pos, prover.get_pos(h1)); + assert_eq!(prove_elem, (&prover.store).get_elem(pos).unwrap().unwrap()); + let result = proof.verify(root, pos, prove_elem).expect("verify"); + assert!(result); +} diff --git a/src/tests/test_helper.rs b/src/tests/test_helper.rs new file mode 100644 index 0000000..dbc08f0 --- /dev/null +++ b/src/tests/test_helper.rs @@ -0,0 +1,82 @@ +use super::{MergeNumberHash, NumberHash}; +use crate::{ + helper::{get_peaks, pos_height_in_tree}, + leaf_index_to_mmr_size, leaf_index_to_pos, + util::MemStore, + MMR, +}; +use lazy_static::lazy_static; +use proptest::prelude::*; + +lazy_static! { + /// Positions of 0..100_000 elem + static ref INDEX_TO_POS: Vec = { + let store = MemStore::default(); + let mut mmr = MMR::<_,MergeNumberHash,_>::new(0, &store); + (0u32..100_000) + .map(|i| mmr.push(NumberHash::from(i)).unwrap()) + .collect() + }; + /// mmr size when 0..100_000 elem + static ref INDEX_TO_MMR_SIZE: Vec = { + let store = MemStore::default(); + let mut mmr = MMR::<_,MergeNumberHash,_>::new(0, &store); + (0u32..100_000) + .map(|i| { + mmr.push(NumberHash::from(i)).unwrap(); + mmr.mmr_size() + }) + .collect() + }; +} + +#[test] +fn test_leaf_index_to_pos() { + assert_eq!(leaf_index_to_pos(0), 0); + assert_eq!(leaf_index_to_pos(1), 1); + assert_eq!(leaf_index_to_pos(2), 3); +} + +#[test] +fn test_leaf_index_to_mmr_size() { + assert_eq!(leaf_index_to_mmr_size(0), 1); + assert_eq!(leaf_index_to_mmr_size(1), 3); + assert_eq!(leaf_index_to_mmr_size(2), 4); +} + +#[test] +fn test_pos_height_in_tree() { + assert_eq!(pos_height_in_tree(0), 0); + assert_eq!(pos_height_in_tree(1), 0); + assert_eq!(pos_height_in_tree(2), 1); + assert_eq!(pos_height_in_tree(3), 0); + assert_eq!(pos_height_in_tree(4), 0); + assert_eq!(pos_height_in_tree(6), 2); + assert_eq!(pos_height_in_tree(7), 0); +} + +#[test] +fn test_get_peaks() { + assert_eq!(get_peaks(0), vec![0]); + assert_eq!(get_peaks(1), vec![0]); + assert_eq!(get_peaks(2), vec![0]); + assert_eq!(get_peaks(3), vec![2]); + assert_eq!(get_peaks(4), vec![2, 3]); + assert_eq!(get_peaks(5), vec![2, 3]); + assert_eq!(get_peaks(6), vec![2, 5]); + assert_eq!(get_peaks(7), vec![6]); + assert_eq!(get_peaks(19), vec![14, 17, 18]); +} + +proptest! { + #[test] + fn test_leaf_index_to_pos_randomly(index in 0..INDEX_TO_POS.len()) { + let pos = leaf_index_to_pos(index as u64); + assert_eq!(pos, INDEX_TO_POS[index]); + } + + #[test] + fn test_leaf_index_to_mmr_size_randomly(index in 0..INDEX_TO_MMR_SIZE.len()) { + assert_eq!(leaf_index_to_mmr_size(index as u64), INDEX_TO_MMR_SIZE[index]); + } +} diff --git a/src/tests/test_mmr.rs b/src/tests/test_mmr.rs new file mode 100644 index 0000000..d9242b2 --- /dev/null +++ b/src/tests/test_mmr.rs @@ -0,0 +1,98 @@ +use super::{MergeNumberHash, NumberHash}; +use crate::{util::MemStore, Error, MMR}; +use faster_hex::hex_string; +use proptest::prelude::*; + +fn test_mmr(count: u32, proof_elem: u32) { + let store = MemStore::default(); + let mut mmr = MMR::<_, MergeNumberHash, _>::new(0, &store); + let positions: Vec = (0u32..count) + .map(|i| mmr.push(NumberHash::from(i)).unwrap()) + .collect(); + let root = mmr.get_root().expect("get root"); + let proof = mmr + .gen_proof(positions[proof_elem as usize]) + .expect("gen proof"); + mmr.commit().expect("commit changes"); + let result = proof + .verify( + root, + positions[proof_elem as usize], + NumberHash::from(proof_elem), + ) + .unwrap(); + assert!(result); +} + +#[test] +fn test_mmr_root() { + let store = MemStore::default(); + let mut mmr = MMR::<_, MergeNumberHash, _>::new(0, &store); + (0u32..11).for_each(|i| { + mmr.push(NumberHash::from(i)).unwrap(); + }); + let root = mmr.get_root().expect("get root"); + let hex_root = hex_string(&root.0).unwrap(); + assert_eq!( + "f6794677f37a57df6a5ec36ce61036e43a36c1a009d05c81c9aa685dde1fd6e3", + hex_root + ); +} + +#[test] +fn test_empty_mmr_root() { + let store = MemStore::::default(); + let mmr = MMR::<_, MergeNumberHash, _>::new(0, &store); + assert_eq!(Err(Error::GetRootOnEmpty), mmr.get_root()); +} + +#[test] +fn test_mmr_3_peaks() { + test_mmr(11, 5); +} + +#[test] +fn test_mmr_2_peaks() { + test_mmr(10, 5); +} + +#[test] +fn test_mmr_1_peak() { + test_mmr(8, 5); +} + +#[test] +fn test_mmr_first_elem_proof() { + test_mmr(11, 0); +} + +#[test] +fn test_mmr_last_elem_proof() { + test_mmr(11, 10); +} + +#[test] +fn test_mmr_1_elem() { + test_mmr(1, 0); +} + +#[test] +fn test_mmr_2_elems() { + test_mmr(2, 0); + test_mmr(2, 1); +} + +prop_compose! { + fn count_elem(count: u32) + (elem in 0..count) + -> (u32, u32) { + (count, elem) + } +} + +proptest! { + #[test] + fn test_random_mmr((count , elem) in count_elem(500)) { + test_mmr(count, elem); + } +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..1c68af0 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,73 @@ +use crate::{MMRStore, Merge, Result, MMR}; +use std::cell::RefCell; +use std::collections::HashMap; +use std::fmt::Debug; +use std::marker::PhantomData; + +#[derive(Clone)] +pub struct MemStore(RefCell>); + +impl Default for MemStore { + fn default() -> Self { + Self::new() + } +} + +impl MemStore { + fn new() -> Self { + MemStore(RefCell::new(HashMap::new())) + } +} + +impl MMRStore for &MemStore { + fn get_elem(&self, pos: u64) -> Result> { + Ok(self.0.borrow().get(&pos).cloned()) + } + + fn append(&mut self, pos: u64, elems: Vec) -> Result<()> { + let mut store = self.0.borrow_mut(); + for (i, elem) in elems.into_iter().enumerate() { + store.insert(pos + i as u64, elem); + } + Ok(()) + } +} + +pub struct MemMMR { + store: MemStore, + mmr_size: u64, + merge: PhantomData, +} + +impl> Default for MemMMR { + fn default() -> Self { + Self::new(0, Default::default()) + } +} + +impl> MemMMR { + pub fn new(mmr_size: u64, store: MemStore) -> Self { + MemMMR { + mmr_size, + store, + merge: PhantomData, + } + } + + pub fn store(&self) -> &MemStore { + &self.store + } + + pub fn get_root(&self) -> Result { + let mmr = MMR::>::new(self.mmr_size, &self.store); + mmr.get_root() + } + + pub fn push(&mut self, elem: T) -> Result { + let mut mmr = MMR::>::new(self.mmr_size, &self.store); + let pos = mmr.push(elem)?; + self.mmr_size = mmr.mmr_size(); + mmr.commit()?; + Ok(pos) + } +}