// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

use crate::Hasher;
use alloc::vec::Vec;
use core::slice;
use utils::{iterators::*, rayon};

// CONSTANTS
// ================================================================================================

pub const MIN_CONCURRENT_LEAVES: usize = 1024;

// PUBLIC FUNCTIONS
// ================================================================================================

/// Builds all internal nodes of the Merkle tree using all available threads and stores the
/// results in a single vector such that the root of the tree is at position 1, the nodes
/// immediately under the root are at positions 2 and 3, etc.
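///
/// For example, with 8 leaves the returned vector contains 8 digests: position 0 is unused,
/// position 1 holds the root, positions 2 and 3 hold the root's two children, and positions
/// 4..8 hold the hashes of consecutive leaf pairs.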
pub fn build_merkle_nodes<H: Hasher>(leaves: &[H::Digest]) -> Vec<H::Digest> {
    let n = leaves.len() / 2;

    // create un-initialized array to hold all intermediate nodes
    let mut nodes = unsafe { utils::uninit_vector::<H::Digest>(2 * n) };
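    // the node at index 0 is never used: the tree is 1-indexed, with the root stored at index 1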
    nodes[0] = H::Digest::default();

    // re-interpret leaves as an array of two leaves fused together and use it to
    // build first row of internal nodes (parents of leaves)
    let two_leaves = unsafe { slice::from_raw_parts(leaves.as_ptr() as *const [H::Digest; 2], n) };

    nodes[n..]
        .par_iter_mut()
        .zip(two_leaves.par_iter())
        .for_each(|(target, source)| *target = H::merge(source));

    // calculate all other tree nodes, we can't use regular iterators here because
    // access patterns are rather complicated - so, we use regular threads instead

    // number of sub-trees must always be a power of 2
    let num_subtrees = rayon::current_num_threads().next_power_of_two();
    let batch_size = n / num_subtrees;

    // re-interpret nodes as an array of two nodes fused together
    let two_nodes = unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [H::Digest; 2], n) };

    // process each subtree in a separate thread
    rayon::scope(|s| {
        for i in 0..num_subtrees {
            let nodes = unsafe { &mut *(&mut nodes[..] as *mut [H::Digest]) };
            s.spawn(move |_| {
                let mut batch_size = batch_size / 2;
                let mut start_idx = n / 2 + batch_size * i;
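                // walk up this sub-tree one level at a time: hash this thread's batch of
                // node pairs, then halve both the starting index and the batch size to move
                // to the parent level; nodes with index below num_subtrees are left for the
                // sequential pass below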
                while start_idx >= num_subtrees {
                    for k in (start_idx..(start_idx + batch_size)).rev() {
                        nodes[k] = H::merge(&two_nodes[k]);
                    }
                    start_idx /= 2;
                    batch_size /= 2;
                }
            });
        }
    });

    // finish the tip of the tree
    for i in (1..num_subtrees).rev() {
        nodes[i] = H::merge(&two_nodes[i]);
    }

    nodes
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use crate::hash::{ByteDigest, Sha3_256};
    use math::fields::f128::BaseElement;
    use proptest::collection::vec;
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn build_merkle_nodes_concurrent(ref data in vec(any::<[u8; 32]>(), 256..257).no_shrink()) {
            let leaves = ByteDigest::bytes_as_digests(data).to_vec();
            let sequential = super::super::build_merkle_nodes::<Sha3_256<BaseElement>>(&leaves);
            let concurrent = super::build_merkle_nodes::<Sha3_256<BaseElement>>(&leaves);
            assert_eq!(concurrent, sequential);
        }
    }
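
    // An illustrative layout check added alongside the upstream property test as a sketch
    // (this test is an addition, not part of the original suite). It verifies that the
    // returned vector follows the documented indexing: parents of the leaves occupy positions
    // n..2n, and the node at position 1 is the hash of its children at positions 2 and 3.
    // It uses 256 leaves, matching the input size of the property test above.
    #[test]
    fn build_merkle_nodes_layout() {
        use crate::Hasher;

        // 256 distinct leaves
        let mut data = [[0u8; 32]; 256];
        for (i, bytes) in data.iter_mut().enumerate() {
            bytes[0] = i as u8;
        }
        let leaves = ByteDigest::bytes_as_digests(&data);
        let n = leaves.len() / 2;

        let nodes = super::build_merkle_nodes::<Sha3_256<BaseElement>>(leaves);
        assert_eq!(2 * n, nodes.len());

        // parents of the leaves are stored in the upper half of the vector
        for i in 0..n {
            let expected = Sha3_256::<BaseElement>::merge(&[leaves[2 * i], leaves[2 * i + 1]]);
            assert_eq!(expected, nodes[n + i]);
        }

        // the root (position 1) is the hash of its two children (positions 2 and 3)
        assert_eq!(Sha3_256::<BaseElement>::merge(&[nodes[2], nodes[3]]), nodes[1]);
    }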
}