From 5de5f061d637533160de9739d2a0204772a80ae4 Mon Sep 17 00:00:00 2001 From: Kevaundray Wedderburn Date: Thu, 22 Aug 2024 16:58:36 +0100 Subject: [PATCH] refactor code --- Cargo.toml | 2 ++ bindings/node/Cargo.toml | 2 +- cryptography/kzg_multi_open/Cargo.toml | 7 ++++++- .../kzg_multi_open/src/fk20/batch_toeplitz.rs | 17 +++++++++-------- eip7594/Cargo.toml | 7 ++++++- eip7594/src/lib.rs | 14 +++++++++++++- eip7594/src/macros.rs | 13 +++++++++++++ eip7594/src/prover.rs | 9 +++++---- eip7594/src/verifier.rs | 4 ++-- 9 files changed, 57 insertions(+), 18 deletions(-) create mode 100644 eip7594/src/macros.rs diff --git a/Cargo.toml b/Cargo.toml index e1cc204b..b236e0fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "bindings/nim/rust_code", "bindings/csharp/rust_code", "eip7594", + "maybe_rayon", "cryptography/bls12_381", "cryptography/kzg_multi_open", "cryptography/polynomial", @@ -31,6 +32,7 @@ bls12_381 = { package = "crate_crypto_internal_eth_kzg_bls12_381", version = "0. polynomial = { package = "crate_crypto_internal_eth_kzg_polynomial", version = "0.4.1", path = "cryptography/polynomial" } erasure_codes = { package = "crate_crypto_internal_eth_kzg_erasure_codes", version = "0.4.1", path = "cryptography/erasure_codes" } rust_eth_kzg = { version = "0.4.1", path = "eip7594" } +maybe_rayon = { package = "crate_crypto_internal_eth_kzg_maybe_rayon", version = "0.4.1", path = "maybe_rayon" } kzg_multi_open = { package = "crate_crypto_kzg_multi_open_fk20", version = "0.4.1", path = "cryptography/kzg_multi_open" } c_eth_kzg = { version = "0.4.1", path = "bindings/c" } hex = "0.4.3" diff --git a/bindings/node/Cargo.toml b/bindings/node/Cargo.toml index 39c5cbf4..27f634b6 100644 --- a/bindings/node/Cargo.toml +++ b/bindings/node/Cargo.toml @@ -17,7 +17,7 @@ napi = { version = "2.12.2", default-features = false, features = [ "async", ] } napi-derive = "2.12.2" -rust_eth_kzg = { workspace = true } +rust_eth_kzg = { workspace = true, features = ["multithreading"] } [build-dependencies] napi-build = "2.0.1" diff --git a/cryptography/kzg_multi_open/Cargo.toml b/cryptography/kzg_multi_open/Cargo.toml index 22953e65..6d3fa615 100644 --- a/cryptography/kzg_multi_open/Cargo.toml +++ b/cryptography/kzg_multi_open/Cargo.toml @@ -13,14 +13,19 @@ repository = { workspace = true } [dependencies] bls12_381 = { workspace = true } polynomial = { workspace = true } +maybe_rayon = { workspace = true } hex = { workspace = true } -rayon = { workspace = true } +rayon = { workspace = true, optional = true } sha2 = "0.10.8" [dev-dependencies] criterion = "0.5.1" rand = "0.8.4" +[features] +default = [] +multithreading = ["rayon", "maybe_rayon/multithreading"] + [[bench]] name = "benchmark" harness = false diff --git a/cryptography/kzg_multi_open/src/fk20/batch_toeplitz.rs b/cryptography/kzg_multi_open/src/fk20/batch_toeplitz.rs index 20a567d1..610b0566 100644 --- a/cryptography/kzg_multi_open/src/fk20/batch_toeplitz.rs +++ b/cryptography/kzg_multi_open/src/fk20/batch_toeplitz.rs @@ -3,8 +3,8 @@ use bls12_381::{ fixed_base_msm::{FixedBaseMSM, UsePrecomp}, g1_batch_normalize, G1Point, G1Projective, }; +use maybe_rayon::prelude::*; use polynomial::domain::Domain; -use rayon::prelude::*; /// BatchToeplitzMatrixVecMul allows one to compute multiple matrix vector multiplications /// and sum them together. @@ -43,7 +43,7 @@ impl BatchToeplitzMatrixVecMul { // Precompute the FFT of the vectors, since they do not change per matrix-vector multiplication let vectors: Vec> = vectors - .into_par_iter() + .maybe_par_iter() .map(|vector| { let vector_projective = vector .iter() @@ -61,7 +61,7 @@ impl BatchToeplitzMatrixVecMul { // // This is a trade-off between storage and computation, where storage grows exponentially. let precomputed_table: Vec<_> = transposed_msm_vectors - .into_par_iter() + .maybe_into_par_iter() .map(|v| FixedBaseMSM::new(v, use_precomp)) .collect(); @@ -87,7 +87,9 @@ impl BatchToeplitzMatrixVecMul { ); // Embed Toeplitz matrices into circulant matrices - let circulant_matrices = matrices.into_iter().map(CirculantMatrix::from_toeplitz); + let circulant_matrices = matrices + .maybe_into_par_iter() + .map(CirculantMatrix::from_toeplitz); // Perform circulant matrix-vector multiplication between all of the matrices and vectors // and sum them together. @@ -95,14 +97,13 @@ impl BatchToeplitzMatrixVecMul { // Transpose the circulant matrices so that we convert a group of hadamard products into a group of // inner products. let col_ffts: Vec<_> = circulant_matrices - .into_iter() + .maybe_into_par_iter() .map(|matrix| self.circulant_domain.fft_scalars(matrix.row)) .collect(); let msm_scalars = transpose(col_ffts); - let result: Vec<_> = self - .precomputed_fft_vectors - .iter() + let result: Vec<_> = (&self.precomputed_fft_vectors) + .maybe_par_iter() .zip(msm_scalars) .map(|(points, scalars)| points.msm(scalars)) .collect(); diff --git a/eip7594/Cargo.toml b/eip7594/Cargo.toml index 17f1545f..ae7f6eff 100644 --- a/eip7594/Cargo.toml +++ b/eip7594/Cargo.toml @@ -13,10 +13,14 @@ kzg_multi_open = { workspace = true } bls12_381 = { workspace = true } hex = { workspace = true } erasure_codes = { workspace = true } -rayon = { workspace = true } +rayon = { workspace = true, optional = true } serde = { version = "1", features = ["derive"] } serde_json = "1" +[features] +default = ["multithreading"] +multithreading = ["rayon"] + [dev-dependencies] criterion = "0.5.1" rand = "0.8.4" @@ -28,3 +32,4 @@ serde_yaml = "0.9.34" [[bench]] name = "benchmark" harness = false +required-features = ["multithreading"] diff --git a/eip7594/src/lib.rs b/eip7594/src/lib.rs index f6c24976..42e6954c 100644 --- a/eip7594/src/lib.rs +++ b/eip7594/src/lib.rs @@ -4,6 +4,8 @@ mod prover; mod serialization; mod trusted_setup; mod verifier; +#[macro_use] +pub(crate) mod macros; pub use bls12_381::fixed_base_msm::UsePrecomp; // Exported types @@ -54,9 +56,12 @@ pub type CellIndex = kzg_multi_open::CosetIndex; use constants::{BYTES_PER_BLOB, BYTES_PER_CELL, BYTES_PER_COMMITMENT}; use prover::ProverContext; +use verifier::VerifierContext; + +#[cfg(feature = "multithreading")] use rayon::ThreadPool; +#[cfg(feature = "multithreading")] use std::sync::Arc; -use verifier::VerifierContext; /// ThreadCount indicates whether we want to use a single thread or multiple threads #[derive(Debug, Copy, Clone)] @@ -65,9 +70,11 @@ pub enum ThreadCount { Single, /// Initializes the threadpool with the number of threads /// denoted by this enum variant. + #[cfg(feature = "multithreading")] Multi(usize), /// Initializes the threadpool with a sensible default number of /// threads. This is currently set to `RAYON_NUM_THREADS`. + #[cfg(feature = "multithreading")] SensibleDefault, } @@ -75,9 +82,11 @@ impl From for usize { fn from(value: ThreadCount) -> Self { match value { ThreadCount::Single => 1, + #[cfg(feature = "multithreading")] ThreadCount::Multi(num_threads) => num_threads, // Setting this to `0` will tell ThreadPool to use // `RAYON_NUM_THREADS`. + #[cfg(feature = "multithreading")] ThreadCount::SensibleDefault => 0, } } @@ -86,6 +95,7 @@ impl From for usize { /// The context that will be used to create and verify opening proofs. #[derive(Debug)] pub struct DASContext { + #[cfg(feature = "multithreading")] thread_pool: Arc, pub prover_ctx: ProverContext, pub verifier_ctx: VerifierContext, @@ -109,6 +119,7 @@ impl DASContext { // width value to `8` for optimal storage and performance tradeoffs. use_precomp: UsePrecomp, ) -> Self { + #[cfg(feature = "multithreading")] let thread_pool = std::sync::Arc::new( rayon::ThreadPoolBuilder::new() .num_threads(num_threads.into()) @@ -117,6 +128,7 @@ impl DASContext { ); DASContext { + #[cfg(feature = "multithreading")] thread_pool, prover_ctx: ProverContext::new(trusted_setup, use_precomp), verifier_ctx: VerifierContext::new(trusted_setup), diff --git a/eip7594/src/macros.rs b/eip7594/src/macros.rs new file mode 100644 index 00000000..f604028f --- /dev/null +++ b/eip7594/src/macros.rs @@ -0,0 +1,13 @@ +#[macro_export] +macro_rules! with_optional_threadpool { + ($self:expr, $body:expr) => {{ + #[cfg(feature = "multithreading")] + { + $self.thread_pool.install(|| $body) + } + #[cfg(not(feature = "multithreading"))] + { + $body + } + }}; +} diff --git a/eip7594/src/prover.rs b/eip7594/src/prover.rs index 85e6836e..35f6cdfa 100644 --- a/eip7594/src/prover.rs +++ b/eip7594/src/prover.rs @@ -14,7 +14,8 @@ use crate::{ deserialize_blob_to_scalars, serialize_cells_and_proofs, serialize_g1_compressed, }, trusted_setup::TrustedSetup, - BlobRef, Cell, CellIndex, CellRef, DASContext, KZGCommitment, KZGProof, + with_optional_threadpool, BlobRef, Cell, CellIndex, CellRef, DASContext, KZGCommitment, + KZGProof, }; /// Context object that is used to call functions in the prover API. @@ -64,7 +65,7 @@ impl DASContext { /// /// The matching function in the specs is: https://github.com/ethereum/consensus-specs/blob/13ac373a2c284dc66b48ddd2ef0a10537e4e0de6/specs/deneb/polynomial-commitments.md#blob_to_kzg_commitment pub fn blob_to_kzg_commitment(&self, blob: BlobRef) -> Result { - self.thread_pool.install(|| { + with_optional_threadpool!(self, { // Deserialize the blob into scalars. let scalars = deserialize_blob_to_scalars(blob)?; @@ -86,7 +87,7 @@ impl DASContext { &self, blob: BlobRef, ) -> Result<([Cell; CELLS_PER_EXT_BLOB], [KZGProof; CELLS_PER_EXT_BLOB]), Error> { - self.thread_pool.install(|| { + with_optional_threadpool!(self, { // Deserialization // let scalars = deserialize_blob_to_scalars(blob)?; @@ -116,7 +117,7 @@ impl DASContext { cell_indices: Vec, cells: Vec, ) -> Result<([Cell; CELLS_PER_EXT_BLOB], [KZGProof; CELLS_PER_EXT_BLOB]), Error> { - self.thread_pool.install(|| { + with_optional_threadpool!(self, { // Recover polynomial // let poly_coeff = self.recover_polynomial_coeff(cell_indices, cells)?; diff --git a/eip7594/src/verifier.rs b/eip7594/src/verifier.rs index e755c411..071cf0dd 100644 --- a/eip7594/src/verifier.rs +++ b/eip7594/src/verifier.rs @@ -9,7 +9,7 @@ use crate::{ errors::Error, serialization::{deserialize_cells, deserialize_compressed_g1_points}, trusted_setup::TrustedSetup, - Bytes48Ref, CellIndex, CellRef, DASContext, + with_optional_threadpool, Bytes48Ref, CellIndex, CellRef, DASContext, }; use bls12_381::Scalar; use erasure_codes::{BlockErasureIndices, ReedSolomon}; @@ -103,7 +103,7 @@ impl DASContext { cells: Vec, proofs_bytes: Vec, ) -> Result<(), Error> { - self.thread_pool.install(|| { + with_optional_threadpool!(self, { let (deduplicated_commitments, row_indices) = deduplicate_with_indices(commitments); // Validation //