use std::cmp::min;
use crate::tree::MerkleTree;
use crate::util::{crypto_library_init, hash_block, hash_hashes, HASHES_PER_BLOCK};
use crate::{Hash, BLOCK_SIZE};
#[derive(Clone, Debug)]
pub struct MerkleTreeBuilder {
block: Vec<u8>,
levels: Vec<Vec<Hash>>,
}
impl Default for MerkleTreeBuilder {
fn default() -> Self {
crypto_library_init();
Self { levels: vec![Vec::new()], block: Vec::with_capacity(BLOCK_SIZE) }
}
}
impl MerkleTreeBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn write(&mut self, buf: &[u8]) {
let buf = if self.block.is_empty() {
buf
} else {
let left = BLOCK_SIZE - self.block.len();
let prefix = min(buf.len(), left);
let (buf, rest) = buf.split_at(prefix);
self.block.extend_from_slice(buf);
if self.block.len() == BLOCK_SIZE {
self.push_data_hash(self.hash_block(&self.block[..]));
}
rest
};
for block in buf.chunks(BLOCK_SIZE) {
if block.len() == BLOCK_SIZE {
self.push_data_hash(self.hash_block(block));
} else {
self.block.extend_from_slice(block);
}
}
}
fn hash_block(&self, block: &[u8]) -> Hash {
hash_block(block, self.levels[0].len() * BLOCK_SIZE)
}
pub fn push_data_hash(&mut self, hash: Hash) {
self.block.clear();
self.levels[0].push(hash);
if self.levels[0].len() % HASHES_PER_BLOCK == 0 {
self.commit_tail_block(0);
}
}
fn commit_tail_block(&mut self, level: usize) {
let len = self.levels[level].len();
let next_level = level + 1;
if next_level >= self.levels.len() {
self.levels.push(Vec::new());
}
let first_hash = if len % HASHES_PER_BLOCK == 0 {
len - HASHES_PER_BLOCK
} else {
len - (len % HASHES_PER_BLOCK)
};
let hash = hash_hashes(
&self.levels[level][first_hash..],
next_level,
self.levels[next_level].len() * BLOCK_SIZE,
);
self.levels[next_level].push(hash);
if self.levels[next_level].len() % HASHES_PER_BLOCK == 0 {
self.commit_tail_block(next_level);
}
}
pub fn finish(mut self) -> MerkleTree {
if !self.block.is_empty() || self.levels[0].is_empty() {
self.push_data_hash(self.hash_block(&self.block[..]));
}
for level in 0.. {
if level >= self.levels.len() {
break;
}
let len = self.levels[level].len();
if len > 1 && len % HASHES_PER_BLOCK != 0 {
self.commit_tail_block(level);
}
}
MerkleTree::from_levels(self.levels)
}
}
impl From<MerkleTreeBuilder> for MerkleTree {
fn from(builder: MerkleTreeBuilder) -> Self {
builder.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use test_case::test_case;
#[allow(clippy::unused_unit)]
#[test_case(vec![], "15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b" ; "test_empty")]
#[test_case(vec![0xFF; 8192], "68d131bc271f9c192d4f6dcd8fe61bef90004856da19d0f2f514a7f4098b0737"; "test_oneblock")]
#[test_case(vec![0xFF; 65536], "f75f59a944d2433bc6830ec243bfefa457704d2aed12f30539cd4f18bf1d62cf"; "test_small")]
#[test_case(vec![0xFF; 2105344], "7d75dfb18bfd48e03b5be4e8e9aeea2f89880cb81c1551df855e0d0a0cc59a67"; "test_large")]
#[test_case(vec![0xFF; 2109440], "7577266aa98ce587922fdc668c186e27f3c742fb1b732737153b70ae46973e43"; "test_unaligned")]
fn tests(input: Vec<u8>, output: &str) {
let mut tree = MerkleTreeBuilder::new();
tree.write(input.as_slice());
let actual = tree.finish().root();
let expected: Hash = output.parse().unwrap();
assert_eq!(expected, actual);
}
#[test]
fn test_unaligned_single_block() {
let data = vec![0xFF; 8192];
let mut tree = MerkleTreeBuilder::new();
let (first, second) = &data[..].split_at(1024);
tree.write(first);
tree.write(second);
let root = tree.finish().root();
let expected =
"68d131bc271f9c192d4f6dcd8fe61bef90004856da19d0f2f514a7f4098b0737".parse().unwrap();
assert_eq!(root, expected);
}
#[test]
fn test_unaligned_n_block() {
let data = vec![0xFF; 65536];
let expected =
"f75f59a944d2433bc6830ec243bfefa457704d2aed12f30539cd4f18bf1d62cf".parse().unwrap();
for chunk_size in &[1, 100, 1024, 8193] {
let mut tree = MerkleTreeBuilder::new();
for block in data.as_slice().chunks(*chunk_size) {
tree.write(block);
}
let root = tree.finish().root();
assert_eq!(root, expected);
}
}
#[test]
fn test_fuchsia() {
let fuchsia: Vec<_> =
vec![0xff, 0x00, 0x80].into_iter().cycle().take(3 * BLOCK_SIZE).collect();
let mut t = MerkleTreeBuilder::new();
let mut remaining = 0xff0080;
while remaining > 0 {
let n = min(remaining, fuchsia.len());
t.write(&fuchsia[..n]);
remaining -= n;
}
let actual = t.finish().root();
let expected: Hash =
"2feb488cffc976061998ac90ce7292241dfa86883c0edc279433b5c4370d0f30".parse().unwrap();
assert_eq!(expected, actual);
}
}