Skip to content

Commit

Permalink
chore: add compile time flags to disable rayon (#240)
Browse files Browse the repository at this point in the history
* maybe-rayon package

* refactor code

* remove rayon from dep

* fix features

* remove default

* conditionally compile the DasContext constructor

* use multithread and singlethreaded

* turn on multithreading for c_kzg

* add feature flags
  • Loading branch information
kevaundray authored Aug 27, 2024
1 parent c00bbe0 commit d472d60
Show file tree
Hide file tree
Showing 15 changed files with 202 additions and 23 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ members = [
"bindings/nim/rust_code",
"bindings/csharp/rust_code",
"eip7594",
"maybe_rayon",
"cryptography/bls12_381",
"cryptography/kzg_multi_open",
"cryptography/polynomial",
Expand All @@ -31,6 +32,7 @@ bls12_381 = { package = "crate_crypto_internal_eth_kzg_bls12_381", version = "0.
polynomial = { package = "crate_crypto_internal_eth_kzg_polynomial", version = "0.4.1", path = "cryptography/polynomial" }
erasure_codes = { package = "crate_crypto_internal_eth_kzg_erasure_codes", version = "0.4.1", path = "cryptography/erasure_codes" }
rust_eth_kzg = { version = "0.4.1", path = "eip7594" }
maybe_rayon = { package = "crate_crypto_internal_eth_kzg_maybe_rayon", version = "0.4.1", path = "maybe_rayon" }
kzg_multi_open = { package = "crate_crypto_kzg_multi_open_fk20", version = "0.4.1", path = "cryptography/kzg_multi_open" }
c_eth_kzg = { version = "0.4.1", path = "bindings/c" }
hex = "0.4.3"
Expand Down
2 changes: 1 addition & 1 deletion bindings/c/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ crate-type = ["staticlib", "cdylib", "rlib"]

[dependencies]
libc = "0.2.2"
rust_eth_kzg = { workspace = true }
rust_eth_kzg = { workspace = true, features = ["multithreaded"] }

[build-dependencies]
cbindgen = "0.26.0"
2 changes: 1 addition & 1 deletion bindings/node/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ napi = { version = "2.12.2", default-features = false, features = [
"async",
] }
napi-derive = "2.12.2"
rust_eth_kzg = { workspace = true }
rust_eth_kzg = { workspace = true, features = ["multithreaded"] }

[build-dependencies]
napi-build = "2.0.1"
3 changes: 3 additions & 0 deletions cryptography/bls12_381/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ subtle = { version = ">=2.5.0, <3.0" }
criterion = "0.5.1"
rand = "0.8.4"

[features]
blst-no-threads = ["blst/no-threads"]

[[bench]]
name = "benchmark"
harness = false
6 changes: 5 additions & 1 deletion cryptography/kzg_multi_open/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@ repository = { workspace = true }
[dependencies]
bls12_381 = { workspace = true }
polynomial = { workspace = true }
maybe_rayon = { workspace = true }
hex = { workspace = true }
rayon = { workspace = true }
sha2 = "0.10.8"

[dev-dependencies]
criterion = "0.5.1"
rand = "0.8.4"

[features]
singlethreaded = ["bls12_381/blst-no-threads"]
multithreaded = ["maybe_rayon/multithreaded"]

[[bench]]
name = "benchmark"
harness = false
17 changes: 9 additions & 8 deletions cryptography/kzg_multi_open/src/fk20/batch_toeplitz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use bls12_381::{
fixed_base_msm::{FixedBaseMSM, UsePrecomp},
g1_batch_normalize, G1Point, G1Projective,
};
use maybe_rayon::prelude::*;
use polynomial::domain::Domain;
use rayon::prelude::*;

/// BatchToeplitzMatrixVecMul allows one to compute multiple matrix vector multiplications
/// and sum them together.
Expand Down Expand Up @@ -43,7 +43,7 @@ impl BatchToeplitzMatrixVecMul {

// Precompute the FFT of the vectors, since they do not change per matrix-vector multiplication
let vectors: Vec<Vec<G1Point>> = vectors
.into_par_iter()
.maybe_par_iter()
.map(|vector| {
let vector_projective = vector
.iter()
Expand All @@ -61,7 +61,7 @@ impl BatchToeplitzMatrixVecMul {
//
// This is a trade-off between storage and computation, where storage grows exponentially.
let precomputed_table: Vec<_> = transposed_msm_vectors
.into_par_iter()
.maybe_into_par_iter()
.map(|v| FixedBaseMSM::new(v, use_precomp))
.collect();

Expand All @@ -87,22 +87,23 @@ impl BatchToeplitzMatrixVecMul {
);

// Embed Toeplitz matrices into circulant matrices
let circulant_matrices = matrices.into_iter().map(CirculantMatrix::from_toeplitz);
let circulant_matrices = matrices
.maybe_into_par_iter()
.map(CirculantMatrix::from_toeplitz);

// Perform circulant matrix-vector multiplication between all of the matrices and vectors
// and sum them together.
//
// Transpose the circulant matrices so that we convert a group of hadamard products into a group of
// inner products.
let col_ffts: Vec<_> = circulant_matrices
.into_iter()
.maybe_into_par_iter()
.map(|matrix| self.circulant_domain.fft_scalars(matrix.row))
.collect();
let msm_scalars = transpose(col_ffts);

let result: Vec<_> = self
.precomputed_fft_vectors
.iter()
let result: Vec<_> = (&self.precomputed_fft_vectors)
.maybe_par_iter()
.zip(msm_scalars)
.map(|(points, scalars)| points.msm(scalars))
.collect();
Expand Down
7 changes: 6 additions & 1 deletion eip7594/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ kzg_multi_open = { workspace = true }
bls12_381 = { workspace = true }
hex = { workspace = true }
erasure_codes = { workspace = true }
rayon = { workspace = true }
rayon = { workspace = true, optional = true }
serde = { version = "1", features = ["derive"] }
serde_json = "1"

[features]
singlethreaded = ["rayon", "kzg_multi_open/singlethreaded"]
multithreaded = ["rayon", "kzg_multi_open/multithreaded"]

[dev-dependencies]
criterion = "0.5.1"
rand = "0.8.4"
Expand All @@ -28,3 +32,4 @@ serde_yaml = "0.9.34"
[[bench]]
name = "benchmark"
harness = false
# required-features = ["multithreaded"]
46 changes: 41 additions & 5 deletions eip7594/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
#[cfg(all(feature = "singlethreaded", feature = "multithreaded"))]
compile_error!("feature_a and feature_b cannot be enabled simultaneously");

pub mod constants;
mod errors;
mod prover;
mod serialization;
mod trusted_setup;
mod verifier;
#[macro_use]
pub(crate) mod macros;

pub use bls12_381::fixed_base_msm::UsePrecomp;
// Exported types
Expand Down Expand Up @@ -54,9 +59,12 @@ pub type CellIndex = kzg_multi_open::CosetIndex;

use constants::{BYTES_PER_BLOB, BYTES_PER_CELL, BYTES_PER_COMMITMENT};
use prover::ProverContext;
use verifier::VerifierContext;

#[cfg(feature = "multithreaded")]
use rayon::ThreadPool;
#[cfg(feature = "multithreaded")]
use std::sync::Arc;
use verifier::VerifierContext;

/// ThreadCount indicates whether we want to use a single thread or multiple threads
#[derive(Debug, Copy, Clone)]
Expand All @@ -65,19 +73,23 @@ pub enum ThreadCount {
Single,
/// Initializes the threadpool with the number of threads
/// denoted by this enum variant.
#[cfg(feature = "multithreaded")]
Multi(usize),
/// Initializes the threadpool with a sensible default number of
/// threads. This is currently set to `RAYON_NUM_THREADS`.
#[cfg(feature = "multithreaded")]
SensibleDefault,
}

impl From<ThreadCount> for usize {
fn from(value: ThreadCount) -> Self {
match value {
ThreadCount::Single => 1,
#[cfg(feature = "multithreaded")]
ThreadCount::Multi(num_threads) => num_threads,
// Setting this to `0` will tell ThreadPool to use
// `RAYON_NUM_THREADS`.
#[cfg(feature = "multithreaded")]
ThreadCount::SensibleDefault => 0,
}
}
Expand All @@ -86,29 +98,37 @@ impl From<ThreadCount> for usize {
/// The context that will be used to create and verify opening proofs.
#[derive(Debug)]
pub struct DASContext {
#[cfg(feature = "multithreaded")]
thread_pool: Arc<ThreadPool>,
pub prover_ctx: ProverContext,
pub verifier_ctx: VerifierContext,
}

#[cfg(feature = "multithreaded")]
impl Default for DASContext {
fn default() -> Self {
let trusted_setup = TrustedSetup::default();
const DEFAULT_NUM_THREADS: ThreadCount = ThreadCount::Single;
DASContext::with_threads(&trusted_setup, DEFAULT_NUM_THREADS, UsePrecomp::No)
}
}
#[cfg(not(feature = "multithreaded"))]
impl Default for DASContext {
fn default() -> Self {
let trusted_setup = TrustedSetup::default();

DASContext::new(&trusted_setup, UsePrecomp::No)
}
}

impl DASContext {
#[cfg(feature = "multithreaded")]
pub fn with_threads(
trusted_setup: &TrustedSetup,
num_threads: ThreadCount,
// This parameter indicates whether we should allocate memory
// in order to speed up proof creation. Heuristics show that
// if pre-computations are desired, one should set the
// width value to `8` for optimal storage and performance tradeoffs.
use_precomp: UsePrecomp,
) -> Self {
#[cfg(feature = "multithreaded")]
let thread_pool = std::sync::Arc::new(
rayon::ThreadPoolBuilder::new()
.num_threads(num_threads.into())
Expand All @@ -117,12 +137,28 @@ impl DASContext {
);

DASContext {
#[cfg(feature = "multithreaded")]
thread_pool,
prover_ctx: ProverContext::new(trusted_setup, use_precomp),
verifier_ctx: VerifierContext::new(trusted_setup),
}
}

#[cfg(not(feature = "multithreaded"))]
pub fn new(
trusted_setup: &TrustedSetup,
// This parameter indicates whether we should allocate memory
// in order to speed up proof creation. Heuristics show that
// if pre-computations are desired, one should set the
// width value to `8` for optimal storage and performance tradeoffs.
use_precomp: UsePrecomp,
) -> Self {
DASContext {
prover_ctx: ProverContext::new(trusted_setup, use_precomp),
verifier_ctx: VerifierContext::new(trusted_setup),
}
}

pub fn prover_ctx(&self) -> &ProverContext {
&self.prover_ctx
}
Expand Down
13 changes: 13 additions & 0 deletions eip7594/src/macros.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#[macro_export]
macro_rules! with_optional_threadpool {
($self:expr, $body:expr) => {{
#[cfg(feature = "multithreaded")]
{
$self.thread_pool.install(|| $body)
}
#[cfg(not(feature = "multithreaded"))]
{
$body
}
}};
}
9 changes: 5 additions & 4 deletions eip7594/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use crate::{
deserialize_blob_to_scalars, serialize_cells_and_proofs, serialize_g1_compressed,
},
trusted_setup::TrustedSetup,
BlobRef, Cell, CellIndex, CellRef, DASContext, KZGCommitment, KZGProof,
with_optional_threadpool, BlobRef, Cell, CellIndex, CellRef, DASContext, KZGCommitment,
KZGProof,
};

/// Context object that is used to call functions in the prover API.
Expand Down Expand Up @@ -64,7 +65,7 @@ impl DASContext {
///
/// The matching function in the specs is: https://github.com/ethereum/consensus-specs/blob/13ac373a2c284dc66b48ddd2ef0a10537e4e0de6/specs/deneb/polynomial-commitments.md#blob_to_kzg_commitment
pub fn blob_to_kzg_commitment(&self, blob: BlobRef) -> Result<KZGCommitment, Error> {
self.thread_pool.install(|| {
with_optional_threadpool!(self, {
// Deserialize the blob into scalars.
let scalars = deserialize_blob_to_scalars(blob)?;

Expand All @@ -86,7 +87,7 @@ impl DASContext {
&self,
blob: BlobRef,
) -> Result<([Cell; CELLS_PER_EXT_BLOB], [KZGProof; CELLS_PER_EXT_BLOB]), Error> {
self.thread_pool.install(|| {
with_optional_threadpool!(self, {
// Deserialization
//
let scalars = deserialize_blob_to_scalars(blob)?;
Expand Down Expand Up @@ -116,7 +117,7 @@ impl DASContext {
cell_indices: Vec<CellIndex>,
cells: Vec<CellRef>,
) -> Result<([Cell; CELLS_PER_EXT_BLOB], [KZGProof; CELLS_PER_EXT_BLOB]), Error> {
self.thread_pool.install(|| {
with_optional_threadpool!(self, {
// Recover polynomial
//
let poly_coeff = self.recover_polynomial_coeff(cell_indices, cells)?;
Expand Down
4 changes: 2 additions & 2 deletions eip7594/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
errors::Error,
serialization::{deserialize_cells, deserialize_compressed_g1_points},
trusted_setup::TrustedSetup,
Bytes48Ref, CellIndex, CellRef, DASContext,
with_optional_threadpool, Bytes48Ref, CellIndex, CellRef, DASContext,
};
use bls12_381::Scalar;
use erasure_codes::{BlockErasureIndices, ReedSolomon};
Expand Down Expand Up @@ -103,7 +103,7 @@ impl DASContext {
cells: Vec<CellRef>,
proofs_bytes: Vec<Bytes48Ref>,
) -> Result<(), Error> {
self.thread_pool.install(|| {
with_optional_threadpool!(self, {
let (deduplicated_commitments, row_indices) = deduplicate_with_indices(commitments);
// Validation
//
Expand Down
15 changes: 15 additions & 0 deletions maybe_rayon/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[package]
name = "crate_crypto_internal_eth_kzg_maybe_rayon"
description = "This crate provides an implementation of a wrapper around the rayon crate"
version = { workspace = true }
authors = { workspace = true }
edition = { workspace = true }
license = { workspace = true }
rust-version = { workspace = true }
repository = { workspace = true }

[dependencies]
rayon = { workspace = true, optional = true }

[features]
multithreaded = ["rayon"]
17 changes: 17 additions & 0 deletions maybe_rayon/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#[cfg(feature = "multithreaded")]
mod multi_threaded;
#[cfg(not(feature = "multithreaded"))]
mod single_threaded;

#[cfg(feature = "multithreaded")]
pub use multi_threaded::*;
#[cfg(not(feature = "multithreaded"))]
pub use single_threaded::*;

pub mod prelude {
pub use crate::MaybeParallelRefExt;
pub use crate::MaybeParallelRefMutExt;
pub use crate::*;
#[cfg(feature = "multithreaded")]
pub use rayon::prelude::*;
}
29 changes: 29 additions & 0 deletions maybe_rayon/src/multi_threaded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
pub use rayon::iter::IntoParallelIterator;
pub use rayon::iter::IntoParallelRefIterator;
pub use rayon::iter::IntoParallelRefMutIterator;
pub use rayon::iter::ParallelIterator;

pub trait MaybeParallelExt: IntoParallelIterator {
fn maybe_into_par_iter(self) -> <Self as IntoParallelIterator>::Iter
where
Self: Sized,
{
self.into_par_iter()
}
}

pub trait MaybeParallelRefExt: for<'a> IntoParallelRefIterator<'a> {
fn maybe_par_iter(&self) -> <Self as IntoParallelRefIterator>::Iter {
self.par_iter()
}
}

pub trait MaybeParallelRefMutExt: for<'a> IntoParallelRefMutIterator<'a> {
fn maybe_par_iter_mut(&mut self) -> <Self as IntoParallelRefMutIterator>::Iter {
self.par_iter_mut()
}
}

impl<T: IntoParallelIterator> MaybeParallelExt for T {}
impl<T: for<'a> IntoParallelRefIterator<'a>> MaybeParallelRefExt for T {}
impl<T: for<'a> IntoParallelRefMutIterator<'a>> MaybeParallelRefMutExt for T {}
Loading

0 comments on commit d472d60

Please sign in to comment.