diff --git a/src/mmr.rs b/src/mmr.rs index b9d5b94..53e2aab 100644 --- a/src/mmr.rs +++ b/src/mmr.rs @@ -29,6 +29,22 @@ impl MMR { merge: PhantomData, } } + + pub fn mmr_size(&self) -> u64 { + self.mmr_size + } + + pub fn is_empty(&self) -> bool { + self.mmr_size == 0 + } + + pub fn batch(&self) -> &MMRBatch { + &self.batch + } + + pub fn store(&self) -> &S { + self.batch.store() + } } impl, S: MMRStoreReadOps> MMR { @@ -42,14 +58,6 @@ impl, S: MMRStoreReadOps> MMR u64 { - self.mmr_size - } - - pub fn is_empty(&self) -> bool { - self.mmr_size == 0 - } - // push a element and return position pub fn push(&mut self, elem: T) -> Result { let mut elems: Vec = Vec::new(); @@ -219,7 +227,7 @@ impl, S: MMRStoreReadOps> MMR> MMR { - pub fn commit(self) -> Result<()> { + pub fn commit(&mut self) -> Result<()> { self.batch.commit() } } diff --git a/src/mmr_store.rs b/src/mmr_store.rs index 1e8d483..1586d70 100644 --- a/src/mmr_store.rs +++ b/src/mmr_store.rs @@ -17,6 +17,10 @@ impl MMRBatch { pub fn append(&mut self, pos: u64, elems: Vec) { self.memory_batch.push((pos, elems)); } + + pub fn store(&self) -> &Store { + &self.store + } } impl> MMRBatch { @@ -35,13 +39,9 @@ impl> MMRBatch { } impl> MMRBatch { - pub fn commit(self) -> Result<()> { - let Self { - mut store, - memory_batch, - } = self; - for (pos, elems) in memory_batch { - store.append(pos, elems)?; + pub fn commit(&mut self) -> Result<()> { + for (pos, elems) in self.memory_batch.drain(..) { + self.store.append(pos, elems)?; } Ok(()) } diff --git a/src/tests/test_accumulate_headers.rs b/src/tests/test_accumulate_headers.rs index ae9ab9e..71c26b8 100644 --- a/src/tests/test_accumulate_headers.rs +++ b/src/tests/test_accumulate_headers.rs @@ -106,7 +106,7 @@ impl Prover { let mut mmr = MMR::<_, MergeHashWithTD, _>::new(self.positions.len() as u64, &self.store); // get previous element let mut previous = if let Some(pos) = self.positions.last() { - MMRStoreReadOps::<_>::get_elem(&&self.store, *pos)?.expect("exists") + mmr.store().get_elem(*pos)?.expect("exists") } else { let genesis = Header::default(); diff --git a/src/tests/test_mmr.rs b/src/tests/test_mmr.rs index 6f0bbf0..7e4405c 100644 --- a/src/tests/test_mmr.rs +++ b/src/tests/test_mmr.rs @@ -1,6 +1,9 @@ use super::{MergeNumberHash, NumberHash}; use crate::{ - helper::pos_height_in_tree, leaf_index_to_mmr_size, util::MemStore, Error, MMRStoreReadOps, MMR, + helper::pos_height_in_tree, + leaf_index_to_mmr_size, + util::{MemMMR, MemStore}, + Error, }; use faster_hex::hex_string; use proptest::prelude::*; @@ -8,7 +11,7 @@ use rand::{seq::SliceRandom, thread_rng}; fn test_mmr(count: u32, proof_elem: Vec) { let store = MemStore::default(); - let mut mmr = MMR::<_, MergeNumberHash, _>::new(0, &store); + let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); let positions: Vec = (0u32..count) .map(|i| mmr.push(NumberHash::from(i)).unwrap()) .collect(); @@ -36,7 +39,7 @@ fn test_mmr(count: u32, proof_elem: Vec) { fn test_gen_new_root_from_proof(count: u32) { let store = MemStore::default(); - let mut mmr = MMR::<_, MergeNumberHash, _>::new(0, &store); + let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); let positions: Vec = (0u32..count) .map(|i| mmr.push(NumberHash::from(i)).unwrap()) .collect(); @@ -61,7 +64,7 @@ fn test_gen_new_root_from_proof(count: u32) { #[test] fn test_mmr_root() { let store = MemStore::default(); - let mut mmr = MMR::<_, MergeNumberHash, _>::new(0, &store); + let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); (0u32..11).for_each(|i| { mmr.push(NumberHash::from(i)).unwrap(); }); @@ -76,7 +79,7 @@ fn test_mmr_root() { #[test] fn test_empty_mmr_root() { let store = MemStore::::default(); - let mmr = MMR::<_, MergeNumberHash, _>::new(0, &store); + let mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); assert_eq!(Err(Error::GetRootOnEmpty), mmr.get_root()); } @@ -155,7 +158,7 @@ fn test_invalid_proof_verification( // optionally handroll proof from these positions handrolled_proof_positions: Option>, ) { - use crate::{util::MemMMR, Merge, MerkleProof}; + use crate::{Merge, MerkleProof}; use std::fmt::{Debug, Formatter}; // Simple item struct to allow debugging the contents of MMR nodes/peaks @@ -184,7 +187,8 @@ fn test_invalid_proof_verification( } } - let mut mmr: MemMMR = MemMMR::default(); + let store = MemStore::default(); + let mut mmr = MemMMR::<_, MyMerge>::new(0, &store); let mut positions: Vec = Vec::new(); for i in 0u32..leaf_count { let pos = mmr.push(MyItem::Number(i)).unwrap(); @@ -194,7 +198,7 @@ fn test_invalid_proof_verification( let entries_to_verify: Vec<(u64, MyItem)> = positions_to_verify .iter() - .map(|pos| (*pos, mmr.store().get_elem(*pos).unwrap().unwrap())) + .map(|pos| (*pos, mmr.batch().get_elem(*pos).unwrap().unwrap())) .collect(); let mut tampered_entries_to_verify = entries_to_verify.clone(); @@ -211,7 +215,7 @@ fn test_invalid_proof_verification( mmr.mmr_size(), handrolled_proof_positions .iter() - .map(|pos| mmr.store().get_elem(*pos).unwrap().unwrap()) + .map(|pos| mmr.batch().get_elem(*pos).unwrap().unwrap()) .collect(), ) }); diff --git a/src/util.rs b/src/util.rs index 79dec26..f021de3 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,7 +1,6 @@ use crate::collections::BTreeMap; -use crate::{vec::Vec, MMRStoreReadOps, MMRStoreWriteOps, Merge, MerkleProof, Result, MMR}; +use crate::{vec::Vec, MMRStoreReadOps, MMRStoreWriteOps, Result, MMR}; use core::cell::RefCell; -use core::marker::PhantomData; #[derive(Clone)] pub struct MemStore(RefCell>); @@ -34,50 +33,4 @@ impl MMRStoreWriteOps for &MemStore { } } -pub struct MemMMR { - store: MemStore, - mmr_size: u64, - merge: PhantomData, -} - -impl> Default for MemMMR { - fn default() -> Self { - Self::new(0, Default::default()) - } -} - -impl> MemMMR { - pub fn new(mmr_size: u64, store: MemStore) -> Self { - MemMMR { - mmr_size, - store, - merge: PhantomData, - } - } - - pub fn store(&self) -> &MemStore { - &self.store - } - - pub fn mmr_size(&self) -> u64 { - self.mmr_size - } - - pub fn get_root(&self) -> Result { - let mmr = MMR::>::new(self.mmr_size, &self.store); - mmr.get_root() - } - - pub fn push(&mut self, elem: T) -> Result { - let mut mmr = MMR::>::new(self.mmr_size, &self.store); - let pos = mmr.push(elem)?; - self.mmr_size = mmr.mmr_size(); - mmr.commit()?; - Ok(pos) - } - - pub fn gen_proof(&self, pos_list: Vec) -> Result> { - let mmr = MMR::>::new(self.mmr_size, &self.store); - mmr.gen_proof(pos_list) - } -} +pub type MemMMR<'a, T, M> = MMR>;