Skip to content

Commit

Permalink
zal: now compiles with halo2curves and can call Nim
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Dec 1, 2023
1 parent d9336da commit c074f90
Show file tree
Hide file tree
Showing 13 changed files with 125 additions and 98 deletions.
2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
resolver = "2"
members = [
"constantine-rust/constantine-sys",
"constantine-rust/constantine-curves",
"constantine-rust/constantine-zal-halo2kzg",
"constantine-rust/constantine-proto-ethereum-bls-signatures",
]

# The Nim static library is compiled with ThinLTO, always enable it
Expand Down
4 changes: 2 additions & 2 deletions bindings/c_curve_decls_parallel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ template genParallelBindings_EC_ShortW_NonAffine*(ECP, ECP_Aff: untyped) =
# --------------------------------------------------------------------------------------
proc `ctt _ ECP _ multi_scalar_mul_vartime_parallel`(
tp: Threadpool,
r: ptr ECP,
r: var ECP,
coefs: ptr UncheckedArray[BigInt[ECP.F.C.getCurveOrderBitwidth()]],
points: ptr UncheckedArray[ECP_Aff],
len: csize_t) {.libExport.} =
tp.multiScalarMul_vartime_parallel(r, coefs, points, cast[int](len))
tp.multiScalarMul_vartime_parallel(r.addr, coefs, points, cast[int](len))
9 changes: 0 additions & 9 deletions constantine-rust/constantine-curves/Cargo.toml

This file was deleted.

40 changes: 0 additions & 40 deletions constantine-rust/constantine-curves/src/lib.rs

This file was deleted.

This file was deleted.

This file was deleted.

16 changes: 8 additions & 8 deletions constantine-rust/constantine-sys/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3536,7 +3536,7 @@ fn bindgen_test_layout_big254() {
extern "C" {
pub fn ctt_bls12_381_g1_jac_multi_scalar_mul_vartime_parallel(
tp: *const ctt_threadpool,
r: *const bls12_381_g1_jac,
r: *mut bls12_381_g1_jac,
coefs: *const big255,
points: *const bls12_381_g1_aff,
len: usize,
Expand All @@ -3545,7 +3545,7 @@ extern "C" {
extern "C" {
pub fn ctt_bls12_381_g1_prj_multi_scalar_mul_vartime_parallel(
tp: *const ctt_threadpool,
r: *const bls12_381_g1_prj,
r: *mut bls12_381_g1_prj,
coefs: *const big255,
points: *const bls12_381_g1_aff,
len: usize,
Expand All @@ -3554,7 +3554,7 @@ extern "C" {
extern "C" {
pub fn ctt_bn254_snarks_g1_jac_multi_scalar_mul_vartime_parallel(
tp: *const ctt_threadpool,
r: *const bn254_snarks_g1_jac,
r: *mut bn254_snarks_g1_jac,
coefs: *const big254,
points: *const bn254_snarks_g1_aff,
len: usize,
Expand All @@ -3563,7 +3563,7 @@ extern "C" {
extern "C" {
pub fn ctt_bn254_snarks_g1_prj_multi_scalar_mul_vartime_parallel(
tp: *const ctt_threadpool,
r: *const bn254_snarks_g1_prj,
r: *mut bn254_snarks_g1_prj,
coefs: *const big254,
points: *const bn254_snarks_g1_aff,
len: usize,
Expand All @@ -3572,7 +3572,7 @@ extern "C" {
extern "C" {
pub fn ctt_pallas_ec_jac_multi_scalar_mul_vartime_parallel(
tp: *const ctt_threadpool,
r: *const pallas_ec_jac,
r: *mut pallas_ec_jac,
coefs: *const big255,
points: *const pallas_ec_aff,
len: usize,
Expand All @@ -3581,7 +3581,7 @@ extern "C" {
extern "C" {
pub fn ctt_pallas_ec_prj_multi_scalar_mul_vartime_parallel(
tp: *const ctt_threadpool,
r: *const pallas_ec_prj,
r: *mut pallas_ec_prj,
coefs: *const big255,
points: *const pallas_ec_aff,
len: usize,
Expand All @@ -3590,7 +3590,7 @@ extern "C" {
extern "C" {
pub fn ctt_vesta_ec_jac_multi_scalar_mul_vartime_parallel(
tp: *const ctt_threadpool,
r: *const vesta_ec_jac,
r: *mut vesta_ec_jac,
coefs: *const big255,
points: *const vesta_ec_aff,
len: usize,
Expand All @@ -3599,7 +3599,7 @@ extern "C" {
extern "C" {
pub fn ctt_vesta_ec_prj_multi_scalar_mul_vartime_parallel(
tp: *const ctt_threadpool,
r: *const vesta_ec_prj,
r: *mut vesta_ec_prj,
coefs: *const big255,
points: *const vesta_ec_aff,
len: usize,
Expand Down
8 changes: 7 additions & 1 deletion constantine-rust/constantine-zal-halo2kzg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,10 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
constantine-curves = { path = "../constantine-curves" }
constantine-sys = { path = "../constantine-sys" }
halo2curves = { git = 'https://github.com/taikoxyz/halo2curves', branch = "pr-pse-exec-engine" }

[dev-dependencies]
ark-std = "0.3"
rand_core = { version = "0.6", default-features = false }
num_cpus = "1.16.0"
105 changes: 100 additions & 5 deletions constantine-rust/constantine-zal-halo2kzg/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,109 @@
pub fn add(left: usize, right: usize) -> usize {
left + right
//! Constantine
//! Copyright (c) 2018-2019 Status Research & Development GmbH
//! Copyright (c) 2020-Present Mamy André-Ratsimbazafy
//! Licensed and distributed under either of
//! * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
//! * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
//! at your option. This file may not be copied, modified, or distributed except according to those terms.
//! See https://github.com/privacy-scaling-explorations/halo2/issues/216
use std::mem;
use ::core::mem::MaybeUninit;
use constantine_sys::*;
use halo2curves::bn256;
use halo2curves::zal::{ZalEngine, MsmAccel};

pub struct CttEngine {
ctx: *mut ctt_threadpool,
}

impl CttEngine {
#[inline(always)]
pub fn new(num_threads: usize) -> CttEngine {
let ctx = unsafe{ ctt_threadpool_new(num_threads) };
CttEngine{ctx}
}
}

impl Drop for CttEngine {
fn drop(&mut self) {
unsafe { ctt_threadpool_shutdown(self.ctx) }
}
}

impl ZalEngine for CttEngine{}

impl MsmAccel<bn256::G1Affine> for CttEngine {
fn msm(&self, coeffs: &[bn256::Fr], bases: &[bn256::G1Affine]) -> bn256::G1 {

assert_eq!(coeffs.len(), bases.len());
let mut result: MaybeUninit<bn254_snarks_g1_prj> = MaybeUninit::uninit();
unsafe {
ctt_bn254_snarks_g1_prj_multi_scalar_mul_vartime_parallel(
self.ctx,
result.as_mut_ptr(),
coeffs.as_ptr() as *const big254,
bases.as_ptr() as *const bn254_snarks_g1_aff,
bases.len()
);
mem::transmute::<MaybeUninit<bn254_snarks_g1_prj>, bn256::G1>(result)
}
}
}

#[cfg(test)]
mod tests {
use super::*;

use ark_std::{end_timer, start_timer};
use rand_core::OsRng;

use halo2curves::bn256;
use halo2curves::ff::Field;
use halo2curves::group::{Curve, Group};
use halo2curves::group::prime::PrimeCurveAffine;
use halo2curves::zal::MsmAccel;
use halo2curves::msm::best_multiexp;

#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
fn t_threadpool() {
let tp = CttEngine::new(4);
drop(tp);
}

fn run_msm_zal(min_k: usize, max_k: usize) {
let points = (0..1 << max_k)
.map(|_| bn256::G1::random(OsRng))
.collect::<Vec<_>>();
let mut affine_points = vec![bn256::G1Affine::identity(); 1 << max_k];
bn256::G1::batch_normalize(&points[..], &mut affine_points[..]);
let points = affine_points;

let scalars = (0..1 << max_k)
.map(|_| bn256::Fr::random(OsRng))
.collect::<Vec<_>>();

for k in min_k..=max_k {
let points = &points[..1 << k];
let scalars = &scalars[..1 << k];

let t0 = start_timer!(|| format!("freestanding msm k={}", k));
let e0 = best_multiexp(scalars, points);
end_timer!(t0);

let engine = CttEngine::new(num_cpus::get());
let t1 = start_timer!(|| format!("CttEngine msm k={}", k));
let e1 = engine.msm(scalars, points);
end_timer!(t1);

assert_eq!(e0, e1);
}
}

#[test]
fn t_msm_zal() {
run_msm_zal(3, 14);
}

}
4 changes: 2 additions & 2 deletions include/constantine/curves/bls12_381_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
extern "C" {
#endif

void ctt_bls12_381_g1_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const bls12_381_g1_jac* r, const big255 coefs[], const bls12_381_g1_aff points[], size_t len);
void ctt_bls12_381_g1_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const bls12_381_g1_prj* r, const big255 coefs[], const bls12_381_g1_aff points[], size_t len);
void ctt_bls12_381_g1_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, bls12_381_g1_jac* r, const big255 coefs[], const bls12_381_g1_aff points[], size_t len);
void ctt_bls12_381_g1_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, bls12_381_g1_prj* r, const big255 coefs[], const bls12_381_g1_aff points[], size_t len);

#ifdef __cplusplus
}
Expand Down
4 changes: 2 additions & 2 deletions include/constantine/curves/bn254_snarks_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
extern "C" {
#endif

void ctt_bn254_snarks_g1_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const bn254_snarks_g1_jac* r, const big254 coefs[], const bn254_snarks_g1_aff points[], size_t len);
void ctt_bn254_snarks_g1_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const bn254_snarks_g1_prj* r, const big254 coefs[], const bn254_snarks_g1_aff points[], size_t len);
void ctt_bn254_snarks_g1_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, bn254_snarks_g1_jac* r, const big254 coefs[], const bn254_snarks_g1_aff points[], size_t len);
void ctt_bn254_snarks_g1_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, bn254_snarks_g1_prj* r, const big254 coefs[], const bn254_snarks_g1_aff points[], size_t len);

#ifdef __cplusplus
}
Expand Down
4 changes: 2 additions & 2 deletions include/constantine/curves/pallas_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
extern "C" {
#endif

void ctt_pallas_ec_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const pallas_ec_jac* r, const big255 coefs[], const pallas_ec_aff points[], size_t len);
void ctt_pallas_ec_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const pallas_ec_prj* r, const big255 coefs[], const pallas_ec_aff points[], size_t len);
void ctt_pallas_ec_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, pallas_ec_jac* r, const big255 coefs[], const pallas_ec_aff points[], size_t len);
void ctt_pallas_ec_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, pallas_ec_prj* r, const big255 coefs[], const pallas_ec_aff points[], size_t len);

#ifdef __cplusplus
}
Expand Down
4 changes: 2 additions & 2 deletions include/constantine/curves/vesta_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
extern "C" {
#endif

void ctt_vesta_ec_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const vesta_ec_jac* r, const big255 coefs[], const vesta_ec_aff points[], size_t len);
void ctt_vesta_ec_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, const vesta_ec_prj* r, const big255 coefs[], const vesta_ec_aff points[], size_t len);
void ctt_vesta_ec_jac_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, vesta_ec_jac* r, const big255 coefs[], const vesta_ec_aff points[], size_t len);
void ctt_vesta_ec_prj_multi_scalar_mul_vartime_parallel(const ctt_threadpool* tp, vesta_ec_prj* r, const big255 coefs[], const vesta_ec_aff points[], size_t len);

#ifdef __cplusplus
}
Expand Down

0 comments on commit c074f90

Please sign in to comment.