Skip to content

Commit

Permalink
Merge pull request ProvableHQ#2415 from ljedrz/perf/merkle_tree_padding
Browse files Browse the repository at this point in the history
Reduce zero-padding and cache zero-hashes in MerkleTree
  • Loading branch information
howardwu authored Apr 12, 2024
2 parents bbdb745 + 7fc8a2b commit 933f12f
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 18 deletions.
63 changes: 55 additions & 8 deletions console/collections/src/merkle_tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,14 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
// Compute the empty hash.
let empty_hash = path_hasher.hash_empty()?;

// Calculate the size of the tree which excludes leafless nodes.
// The minimum tree size is either a single root node or the calculated number of nodes plus
// the supplied leaves; if there is more than one leaf and the number of leaves is odd, a
// single empty hash is added for padding.
let minimum_tree_size =
std::cmp::max(1, num_nodes + leaves.len() + if leaves.len() > 1 { leaves.len() % 2 } else { 0 });

// Initialize the Merkle tree.
let mut tree = vec![empty_hash; tree_size];
let mut tree = vec![empty_hash; minimum_tree_size];

// Compute and store each leaf hash.
tree[num_nodes..num_nodes + leaves.len()].copy_from_slice(&leaf_hasher.hash_leaves(leaves)?);
Expand All @@ -90,10 +96,22 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
while let Some(start) = parent(start_index) {
// Compute the end index of the current level.
let end = left_child(start);
// Construct the children for each node in the current level.
let tuples = (start..end).map(|i| (tree[left_child(i)], tree[right_child(i)])).collect::<Vec<_>>();
// Construct the children for each node in the current level; the leaves are padded, which means
// that there either are 2 children, or there are none, at which point we may stop iterating.
let tuples = (start..end)
.take_while(|&i| tree.get(left_child(i)).is_some())
.map(|i| (tree[left_child(i)], tree[right_child(i)]))
.collect::<Vec<_>>();
// Compute and store the hashes for each node in the current level.
tree[start..end].copy_from_slice(&path_hasher.hash_all_children(&tuples)?);
let num_full_nodes = tuples.len();
tree[start..][..num_full_nodes].copy_from_slice(&path_hasher.hash_all_children(&tuples)?);
// Compute the empty node hash once and use it for every empty node in this level, if there are any.
if start + num_full_nodes < end {
let empty_node_hash = path_hasher.hash_children(&empty_hash, &empty_hash)?;
for node in tree.iter_mut().take(end).skip(start + num_full_nodes) {
*node = empty_node_hash;
}
}
// Update the start index for the next level.
start_index = start;
}
Expand Down Expand Up @@ -144,8 +162,16 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
tree.extend(self.leaf_hashes()?);
// Extend the new Merkle tree with the new leaf hashes.
tree.extend(&self.leaf_hasher.hash_leaves(new_leaves)?);

// Calculate the size of the tree which excludes leafless nodes.
let new_number_of_leaves = self.number_of_leaves + new_leaves.len();
let minimum_tree_size = std::cmp::max(
1,
num_nodes + new_number_of_leaves + if new_number_of_leaves > 1 { new_number_of_leaves % 2 } else { 0 },
);

// Resize the new Merkle tree with empty hashes to pad up to `minimum_tree_size`.
tree.resize(tree_size, self.empty_hash);
tree.resize(minimum_tree_size, self.empty_hash);
lap!(timer, "Hashed {} new leaves", new_leaves.len());

// Initialize a start index to track the starting index of the current level.
Expand Down Expand Up @@ -453,12 +479,20 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
// Compute the number of padded levels.
let padding_depth = DEPTH - tree_depth;

// Calculate the size of the tree which excludes leafless nodes.
let minimum_tree_size = std::cmp::max(
1,
num_nodes
+ updated_number_of_leaves
+ if updated_number_of_leaves > 1 { updated_number_of_leaves % 2 } else { 0 },
);

// Initialize the Merkle tree.
let mut tree = vec![self.empty_hash; num_nodes];
// Extend the new Merkle tree with the existing leaf hashes, excluding the last 'n' leaves.
tree.extend(&self.leaf_hashes()?[..updated_number_of_leaves]);
// Resize the new Merkle tree with empty hashes to pad up to `minimum_tree_size`.
tree.resize(tree_size, self.empty_hash);
tree.resize(minimum_tree_size, self.empty_hash);
lap!(timer, "Resizing to {} leaves", updated_number_of_leaves);

// Initialize a start index to track the starting index of the current level.
Expand Down Expand Up @@ -627,6 +661,7 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
let timer = timer!("MerkleTree::compute_updated_tree");

// Compute and store the hashes for each level, iterating from the penultimate level to the root level.
let empty_hash = self.path_hasher.hash_empty()?;
while let (Some(start), Some(middle)) = (parent(start_index), parent(middle_index)) {
// Compute the end index of the current level.
let end = left_child(start);
Expand All @@ -651,7 +686,12 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
if let Some(middle_precompute) = parent(middle_precompute) {
// Construct the children for the new indices in the current level.
let tuples = (middle..middle_precompute)
.map(|i| (tree[left_child(i)], tree[right_child(i)]))
.map(|i| {
(
tree.get(left_child(i)).copied().unwrap_or(empty_hash),
tree.get(right_child(i)).copied().unwrap_or(empty_hash),
)
})
.collect::<Vec<_>>();
// Process the indices that need to be computed for the current level.
// If any level requires computing more than 100 nodes, borrow the tree for performance.
Expand Down Expand Up @@ -687,7 +727,14 @@ impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>
}
} else {
// Construct the children for the new indices in the current level.
let tuples = (middle..end).map(|i| (tree[left_child(i)], tree[right_child(i)])).collect::<Vec<_>>();
let tuples = (middle..end)
.map(|i| {
(
tree.get(left_child(i)).copied().unwrap_or(empty_hash),
tree.get(right_child(i)).copied().unwrap_or(empty_hash),
)
})
.collect::<Vec<_>>();
// Process the indices that need to be computed for the current level.
// If any level requires computing more than 100 nodes, borrow the tree for performance.
match tuples.len() >= 100 {
Expand Down
14 changes: 4 additions & 10 deletions console/collections/src/merkle_tree/tests/append.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ fn check_merkle_tree_depth_3_padded<E: Environment, LH: LeafHash<Hash = PH::Hash

// Rebuild the Merkle tree with the additional leaf.
merkle_tree.append(additional_leaves)?;
assert_eq!(15, merkle_tree.tree.len());
assert_eq!(13, merkle_tree.tree.len());
// assert_eq!(0, merkle_tree.padding_tree.len());
assert_eq!(5, merkle_tree.number_of_leaves);

Expand All @@ -173,8 +173,6 @@ fn check_merkle_tree_depth_3_padded<E: Environment, LH: LeafHash<Hash = PH::Hash
assert_eq!(expected_leaf3, merkle_tree.tree[10]);
assert_eq!(expected_leaf4, merkle_tree.tree[11]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[12]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[13]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[14]);

// Depth 2.
let expected_left0 = PathHash::hash_children(path_hasher, &expected_leaf0, &expected_leaf1)?;
Expand Down Expand Up @@ -258,7 +256,7 @@ fn check_merkle_tree_depth_4_padded<E: Environment, LH: LeafHash<Hash = PH::Hash

// Rebuild the Merkle tree with the additional leaf.
merkle_tree.append(&[additional_leaves[0].clone()])?;
assert_eq!(15, merkle_tree.tree.len());
assert_eq!(13, merkle_tree.tree.len());
// assert_eq!(0, merkle_tree.padding_tree.len());
assert_eq!(5, merkle_tree.number_of_leaves);

Expand All @@ -274,8 +272,6 @@ fn check_merkle_tree_depth_4_padded<E: Environment, LH: LeafHash<Hash = PH::Hash
assert_eq!(expected_leaf3, merkle_tree.tree[10]);
assert_eq!(expected_leaf4, merkle_tree.tree[11]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[12]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[13]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[14]);

// Depth 3.
let expected_left0 = PathHash::hash_children(path_hasher, &expected_leaf0, &expected_leaf1)?;
Expand Down Expand Up @@ -308,13 +304,13 @@ fn check_merkle_tree_depth_4_padded<E: Environment, LH: LeafHash<Hash = PH::Hash
// ------------------------------------------------------------------------------------------ //

// Ensure we're starting where we left off from the previous rebuild.
assert_eq!(15, merkle_tree.tree.len());
assert_eq!(13, merkle_tree.tree.len());
// assert_eq!(0, merkle_tree.padding_tree.len());
assert_eq!(5, merkle_tree.number_of_leaves);

// Rebuild the Merkle tree with the additional leaf.
merkle_tree.append(&[additional_leaves[1].clone()])?;
assert_eq!(15, merkle_tree.tree.len());
assert_eq!(13, merkle_tree.tree.len());
// assert_eq!(0, merkle_tree.padding_tree.len());
assert_eq!(6, merkle_tree.number_of_leaves);

Expand All @@ -331,8 +327,6 @@ fn check_merkle_tree_depth_4_padded<E: Environment, LH: LeafHash<Hash = PH::Hash
assert_eq!(expected_leaf3, merkle_tree.tree[10]);
assert_eq!(expected_leaf4, merkle_tree.tree[11]);
assert_eq!(expected_leaf5, merkle_tree.tree[12]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[13]);
assert_eq!(path_hasher.hash_empty()?, merkle_tree.tree[14]);

// Depth 3.
let expected_left0 = PathHash::hash_children(path_hasher, &expected_leaf0, &expected_leaf1)?;
Expand Down

0 comments on commit 933f12f

Please sign in to comment.