diff --git a/Cargo.toml b/Cargo.toml index b9a21e461..7c6b2b17e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,9 +2,7 @@ resolver = "2" members = [ "constantine-rust/constantine-sys", - "constantine-rust/constantine-curves", "constantine-rust/constantine-zal-halo2kzg", - "constantine-rust/constantine-proto-ethereum-bls-signatures", ] # The Nim static library is compiled with ThinLTO, always enable it diff --git a/bindings/c_curve_decls_parallel.nim b/bindings/c_curve_decls_parallel.nim index f81bc6b68..f9c5fcb2a 100644 --- a/bindings/c_curve_decls_parallel.nim +++ b/bindings/c_curve_decls_parallel.nim @@ -28,8 +28,8 @@ template genParallelBindings_EC_ShortW_NonAffine*(ECP, ECP_Aff: untyped) = # -------------------------------------------------------------------------------------- proc `ctt _ ECP _ multi_scalar_mul_vartime_parallel`( tp: Threadpool, - r: ptr ECP, + r: var ECP, coefs: ptr UncheckedArray[BigInt[ECP.F.C.getCurveOrderBitwidth()]], points: ptr UncheckedArray[ECP_Aff], len: csize_t) {.libExport.} = - tp.multiScalarMul_vartime_parallel(r, coefs, points, cast[int](len)) + tp.multiScalarMul_vartime_parallel(r.addr, coefs, points, cast[int](len)) diff --git a/constantine-rust/constantine-curves/Cargo.toml b/constantine-rust/constantine-curves/Cargo.toml deleted file mode 100644 index 65dc12de5..000000000 --- a/constantine-rust/constantine-curves/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "constantine-curves" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -constantine-sys = { path = "../constantine-sys" } \ No newline at end of file diff --git a/constantine-rust/constantine-curves/src/lib.rs b/constantine-rust/constantine-curves/src/lib.rs deleted file mode 100644 index f4f216dc1..000000000 --- a/constantine-rust/constantine-curves/src/lib.rs +++ /dev/null @@ -1,40 +0,0 @@ -//! Constantine -//! Copyright (c) 2018-2019 Status Research & Development GmbH -//! Copyright (c) 2020-Present Mamy André-Ratsimbazafy -//! Licensed and distributed under either of -//! * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). -//! * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). -//! at your option. This file may not be copied, modified, or distributed except according to those terms. - -use constantine_sys::*; - -pub struct CttThreadpool { - ctx: *mut ctt_threadpool, -} - -impl CttThreadpool { - #[inline(always)] - pub fn new(num_threads: usize) -> CttThreadpool { - let ctx = unsafe{ ctt_threadpool_new(num_threads) }; - CttThreadpool{ctx} - } -} - -impl Drop for CttThreadpool { - fn drop(&mut self) { - unsafe { ctt_threadpool_shutdown(self.ctx) } - } -} - - - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn t_threadpool() { - let tp = CttThreadpool::new(4); - drop(tp); - } -} diff --git a/constantine-rust/constantine-proto-ethereum-bls-signatures/Cargo.toml b/constantine-rust/constantine-proto-ethereum-bls-signatures/Cargo.toml deleted file mode 100644 index 15ec5f3ad..000000000 --- a/constantine-rust/constantine-proto-ethereum-bls-signatures/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "constantine-proto-ethereum-bls-signatures" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -constantine-sys = { path = "../constantine-sys" } \ No newline at end of file diff --git a/constantine-rust/constantine-proto-ethereum-bls-signatures/src/lib.rs b/constantine-rust/constantine-proto-ethereum-bls-signatures/src/lib.rs deleted file mode 100644 index 7d12d9af8..000000000 --- a/constantine-rust/constantine-proto-ethereum-bls-signatures/src/lib.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } -} diff --git a/constantine-rust/constantine-sys/src/bindings.rs b/constantine-rust/constantine-sys/src/bindings.rs index 34161a8a2..da448b64c 100644 --- a/constantine-rust/constantine-sys/src/bindings.rs +++ b/constantine-rust/constantine-sys/src/bindings.rs @@ -3536,7 +3536,7 @@ fn bindgen_test_layout_big254() { extern "C" { pub fn ctt_bls12_381_g1_jac_multi_scalar_mul_vartime_parallel( tp: *const ctt_threadpool, - r: *const bls12_381_g1_jac, + r: *mut bls12_381_g1_jac, coefs: *const big255, points: *const bls12_381_g1_aff, len: usize, @@ -3545,7 +3545,7 @@ extern "C" { extern "C" { pub fn ctt_bls12_381_g1_prj_multi_scalar_mul_vartime_parallel( tp: *const ctt_threadpool, - r: *const bls12_381_g1_prj, + r: *mut bls12_381_g1_prj, coefs: *const big255, points: *const bls12_381_g1_aff, len: usize, @@ -3554,7 +3554,7 @@ extern "C" { extern "C" { pub fn ctt_bn254_snarks_g1_jac_multi_scalar_mul_vartime_parallel( tp: *const ctt_threadpool, - r: *const bn254_snarks_g1_jac, + r: *mut bn254_snarks_g1_jac, coefs: *const big254, points: *const bn254_snarks_g1_aff, len: usize, @@ -3563,7 +3563,7 @@ extern "C" { extern "C" { pub fn ctt_bn254_snarks_g1_prj_multi_scalar_mul_vartime_parallel( tp: *const ctt_threadpool, - r: *const bn254_snarks_g1_prj, + r: *mut bn254_snarks_g1_prj, coefs: *const big254, points: *const bn254_snarks_g1_aff, len: usize, @@ -3572,7 +3572,7 @@ extern "C" { extern "C" { pub fn ctt_pallas_ec_jac_multi_scalar_mul_vartime_parallel( tp: *const ctt_threadpool, - r: *const pallas_ec_jac, + r: *mut pallas_ec_jac, coefs: *const big255, points: *const pallas_ec_aff, len: usize, @@ -3581,7 +3581,7 @@ extern "C" { extern "C" { pub fn ctt_pallas_ec_prj_multi_scalar_mul_vartime_parallel( tp: *const ctt_threadpool, - r: *const pallas_ec_prj, + r: *mut pallas_ec_prj, coefs: *const big255, points: *const pallas_ec_aff, len: usize, @@ -3590,7 +3590,7 @@ extern "C" { extern "C" { pub fn ctt_vesta_ec_jac_multi_scalar_mul_vartime_parallel( tp: *const ctt_threadpool, - r: *const vesta_ec_jac, + r: *mut vesta_ec_jac, coefs: *const big255, points: *const vesta_ec_aff, len: usize, @@ -3599,7 +3599,7 @@ extern "C" { extern "C" { pub fn ctt_vesta_ec_prj_multi_scalar_mul_vartime_parallel( tp: *const ctt_threadpool, - r: *const vesta_ec_prj, + r: *mut vesta_ec_prj, coefs: *const big255, points: *const vesta_ec_aff, len: usize, diff --git a/constantine-rust/constantine-zal-halo2kzg/Cargo.toml b/constantine-rust/constantine-zal-halo2kzg/Cargo.toml index 9ab2f4310..dafb05313 100644 --- a/constantine-rust/constantine-zal-halo2kzg/Cargo.toml +++ b/constantine-rust/constantine-zal-halo2kzg/Cargo.toml @@ -6,4 +6,10 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -constantine-curves = { path = "../constantine-curves" } \ No newline at end of file +constantine-sys = { path = "../constantine-sys" } +halo2curves = { git = 'https://github.com/taikoxyz/halo2curves', branch = "pr-pse-exec-engine" } + +[dev-dependencies] +ark-std = "0.3" +rand_core = { version = "0.6", default-features = false } +num_cpus = "1.16.0" \ No newline at end of file diff --git a/constantine-rust/constantine-zal-halo2kzg/src/lib.rs b/constantine-rust/constantine-zal-halo2kzg/src/lib.rs index 7d12d9af8..9c746183c 100644 --- a/constantine-rust/constantine-zal-halo2kzg/src/lib.rs +++ b/constantine-rust/constantine-zal-halo2kzg/src/lib.rs @@ -1,14 +1,109 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right +//! Constantine +//! Copyright (c) 2018-2019 Status Research & Development GmbH +//! Copyright (c) 2020-Present Mamy André-Ratsimbazafy +//! Licensed and distributed under either of +//! * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +//! * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +//! at your option. This file may not be copied, modified, or distributed except according to those terms. + +//! See https://github.com/privacy-scaling-explorations/halo2/issues/216 + +use std::mem; +use ::core::mem::MaybeUninit; +use constantine_sys::*; +use halo2curves::bn256; +use halo2curves::zal::{ZalEngine, MsmAccel}; + +pub struct CttEngine { + ctx: *mut ctt_threadpool, +} + +impl CttEngine { + #[inline(always)] + pub fn new(num_threads: usize) -> CttEngine { + let ctx = unsafe{ ctt_threadpool_new(num_threads) }; + CttEngine{ctx} + } +} + +impl Drop for CttEngine { + fn drop(&mut self) { + unsafe { ctt_threadpool_shutdown(self.ctx) } + } +} + +impl ZalEngine for CttEngine{} + +impl MsmAccel for CttEngine { + fn msm(&self, coeffs: &[bn256::Fr], bases: &[bn256::G1Affine]) -> bn256::G1 { + + assert_eq!(coeffs.len(), bases.len()); + let mut result: MaybeUninit = MaybeUninit::uninit(); + unsafe { + ctt_bn254_snarks_g1_prj_multi_scalar_mul_vartime_parallel( + self.ctx, + result.as_mut_ptr(), + coeffs.as_ptr() as *const big254, + bases.as_ptr() as *const bn254_snarks_g1_aff, + bases.len() + ); + mem::transmute::, bn256::G1>(result) + } + } } #[cfg(test)] mod tests { use super::*; + use ark_std::{end_timer, start_timer}; + use rand_core::OsRng; + + use halo2curves::bn256; + use halo2curves::ff::Field; + use halo2curves::group::{Curve, Group}; + use halo2curves::group::prime::PrimeCurveAffine; + use halo2curves::zal::MsmAccel; + use halo2curves::msm::best_multiexp; + #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); + fn t_threadpool() { + let tp = CttEngine::new(4); + drop(tp); } + + fn run_msm_zal(min_k: usize, max_k: usize) { + let points = (0..1 << max_k) + .map(|_| bn256::G1::random(OsRng)) + .collect::>(); + let mut affine_points = vec![bn256::G1Affine::identity(); 1 << max_k]; + bn256::G1::batch_normalize(&points[..], &mut affine_points[..]); + let points = affine_points; + + let scalars = (0..1 << max_k) + .map(|_| bn256::Fr::random(OsRng)) + .collect::>(); + + for k in min_k..=max_k { + let points = &points[..1 << k]; + let scalars = &scalars[..1 << k]; + + let t0 = start_timer!(|| format!("freestanding msm k={}", k)); + let e0 = best_multiexp(scalars, points); + end_timer!(t0); + + let engine = CttEngine::new(num_cpus::get()); + let t1 = start_timer!(|| format!("CttEngine msm k={}", k)); + let e1 = engine.msm(scalars, points); + end_timer!(t1); + + assert_eq!(e0, e1); + } + } + + #[test] + fn t_msm_zal() { + run_msm_zal(3, 14); + } + } diff --git a/include/constantine/curves/bls12_381_parallel.h b/include/constantine/curves/bls12_381_parallel.h index eadbfcdfe..9a9a0a94b 100644 --- a/include/constantine/curves/bls12_381_parallel.h +++ b/include/constantine/curves/bls12_381_parallel.h @@ -18,8 +18,8 @@ extern "C" { #endif -void ctt_bls12_381_g1_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const bls12_381_g1_jac* r, const big255 coefs[], const bls12_381_g1_aff points[], size_t len); -void ctt_bls12_381_g1_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const bls12_381_g1_prj* r, const big255 coefs[], const bls12_381_g1_aff points[], size_t len); +void ctt_bls12_381_g1_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, bls12_381_g1_jac* r, const big255 coefs[], const bls12_381_g1_aff points[], size_t len); +void ctt_bls12_381_g1_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, bls12_381_g1_prj* r, const big255 coefs[], const bls12_381_g1_aff points[], size_t len); #ifdef __cplusplus } diff --git a/include/constantine/curves/bn254_snarks_parallel.h b/include/constantine/curves/bn254_snarks_parallel.h index 8a9a59547..565aaf71a 100644 --- a/include/constantine/curves/bn254_snarks_parallel.h +++ b/include/constantine/curves/bn254_snarks_parallel.h @@ -18,8 +18,8 @@ extern "C" { #endif -void ctt_bn254_snarks_g1_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const bn254_snarks_g1_jac* r, const big254 coefs[], const bn254_snarks_g1_aff points[], size_t len); -void ctt_bn254_snarks_g1_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const bn254_snarks_g1_prj* r, const big254 coefs[], const bn254_snarks_g1_aff points[], size_t len); +void ctt_bn254_snarks_g1_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, bn254_snarks_g1_jac* r, const big254 coefs[], const bn254_snarks_g1_aff points[], size_t len); +void ctt_bn254_snarks_g1_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, bn254_snarks_g1_prj* r, const big254 coefs[], const bn254_snarks_g1_aff points[], size_t len); #ifdef __cplusplus } diff --git a/include/constantine/curves/pallas_parallel.h b/include/constantine/curves/pallas_parallel.h index c07bbe434..d4a29e444 100644 --- a/include/constantine/curves/pallas_parallel.h +++ b/include/constantine/curves/pallas_parallel.h @@ -18,8 +18,8 @@ extern "C" { #endif -void ctt_pallas_ec_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const pallas_ec_jac* r, const big255 coefs[], const pallas_ec_aff points[], size_t len); -void ctt_pallas_ec_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const pallas_ec_prj* r, const big255 coefs[], const pallas_ec_aff points[], size_t len); +void ctt_pallas_ec_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, pallas_ec_jac* r, const big255 coefs[], const pallas_ec_aff points[], size_t len); +void ctt_pallas_ec_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, pallas_ec_prj* r, const big255 coefs[], const pallas_ec_aff points[], size_t len); #ifdef __cplusplus } diff --git a/include/constantine/curves/vesta_parallel.h b/include/constantine/curves/vesta_parallel.h index fbc08884a..e52002d94 100644 --- a/include/constantine/curves/vesta_parallel.h +++ b/include/constantine/curves/vesta_parallel.h @@ -18,8 +18,8 @@ extern "C" { #endif -void ctt_vesta_ec_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const vesta_ec_jac* r, const big255 coefs[], const vesta_ec_aff points[], size_t len); -void ctt_vesta_ec_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const vesta_ec_prj* r, const big255 coefs[], const vesta_ec_aff points[], size_t len); +void ctt_vesta_ec_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, vesta_ec_jac* r, const big255 coefs[], const vesta_ec_aff points[], size_t len); +void ctt_vesta_ec_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, vesta_ec_prj* r, const big255 coefs[], const vesta_ec_aff points[], size_t len); #ifdef __cplusplus }