diff --git a/Cargo.toml b/Cargo.toml index aa40937..9e0891a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ name = "py_arkworks_bls12381" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.18.0", features = ["extension-module"] } +pyo3 = { version = "0.22.2", features = ["extension-module", "num-bigint"] } ark-bls12-381 = "0.4.0" ark-serialize = "0.4.0" ark-ec = "0.4.0" @@ -17,6 +17,7 @@ ark-ff = "0.4.0" rayon = "1.6.1" hex = "0.4.3" num-traits = "0.2.15" +num-bigint = "0.4.6" [features] default = ["parallel", "asm"] diff --git a/examples/scalar.py b/examples/scalar.py index cb6f43a..fcd7cc0 100644 --- a/examples/scalar.py +++ b/examples/scalar.py @@ -1,19 +1,36 @@ from py_arkworks_bls12381 import Scalar +BLS_MODULUS = 52435875175126190479447740508185965837690552500527637822603658699938581184513 + # Initialisation - The default initialiser for a scalar is an u128 integer scalar = Scalar(12345) +# It should be possible to instantiate BLS_MODULUS - 1 +max_value = Scalar(BLS_MODULUS - 1) +assert max_value + Scalar(2) == Scalar(1) + # Equality -- We override eq and neq operators assert scalar == scalar assert Scalar(1234) != Scalar(4567) -# Scalar Addition/subtraction/Negation -- We override the add/sub/neg operators +# Scalar arithmetic -- We override the mul/div/add/sub/neg operators a = Scalar(3) b = Scalar(4) c = Scalar(5) assert a.square() + b.square() == c.square() assert a * a + b * b == c * c +assert Scalar(12) / Scalar(3) == Scalar(4) + +try: + assert Scalar(12) / Scalar(0) + assert False +except ZeroDivisionError: + pass + +exp = Scalar(0xffff_ffff_ffff_fff) +assert int(Scalar(2).pow(exp)) == pow(2, int(exp), BLS_MODULUS) + neg_a = -a assert a + neg_a == Scalar(0) assert (a + neg_a).is_zero() @@ -21,4 +38,7 @@ # Serialisation compressed_bytes = scalar.to_le_bytes() deserialised_scalar = Scalar.from_le_bytes(compressed_bytes) -assert scalar == deserialised_scalar \ No newline at end of file +assert scalar == deserialised_scalar + +# Conversion to int +assert int(Scalar(12345)) == 12345 \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index e6eaf58..ec818df 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ use wrapper::{G1Point, G2Point, Scalar, GT}; /// A Python module implemented in Rust. #[pymodule] -fn py_arkworks_bls12381(_py: Python, m: &PyModule) -> PyResult<()> { +fn py_arkworks_bls12381(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/wrapper.rs b/src/wrapper.rs index caffa3c..4f25178 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -1,3 +1,4 @@ +use std::str::FromStr; use ark_bls12_381::{G1Affine, G1Projective, G2Affine, G2Projective}; use ark_ec::pairing::{Pairing, PairingOutput}; use ark_ec::{AffineRepr, Group, ScalarMul, VariableBaseMSM}; @@ -6,9 +7,12 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError use num_traits::identities::Zero; use pyo3::{exceptions, pyclass, pymethods, PyErr, PyResult, Python}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use num_bigint::BigUint; + const G1_COMPRESSED_SIZE: usize = 48; const G2_COMPRESSED_SIZE: usize = 96; const SCALAR_SIZE: usize = 32; +const BLS_MODULUS: &str = "52435875175126190479447740508185965837690552500527637822603658699938581184513"; #[derive(Copy, Clone)] #[pyclass] @@ -80,7 +84,7 @@ impl G1Point { py.allow_threads(|| { let points: Vec<_> = points.into_par_iter().map(|point| point.0).collect(); let scalars: Vec<_> = scalars.into_par_iter().map(|scalar| scalar.0).collect(); - + // Convert the points to affine. // TODO: we could have a G1AffinePoint struct and then a G1ProjectivePoint // TODO struct, so that this cost is explicit @@ -162,7 +166,7 @@ impl G2Point { py.allow_threads(|| { let points: Vec<_> = points.into_iter().map(|point| point.0).collect(); let scalars: Vec<_> = scalars.into_iter().map(|scalar| scalar.0).collect(); - + // Convert the points to affine. // TODO: we could have a G2AffinePoint struct and then a G2ProjectivePoint // TODO struct, so that this cost is explicit @@ -180,8 +184,10 @@ pub struct Scalar(ark_bls12_381::Fr); #[pymethods] impl Scalar { #[new] - fn new(integer: u128) -> Self { - Scalar(ark_bls12_381::Fr::from(integer)) + fn new(integer: BigUint) -> PyResult { + let fr = ark_bls12_381::Fr::from_str(&*integer.to_string()) + .map_err(|_| exceptions::PyValueError::new_err("Value is greater than BLS_MODULUS"))?; + Ok(Scalar(fr)) } // Overriding operators @@ -194,6 +200,13 @@ impl Scalar { fn __mul__(&self, rhs: Scalar) -> Scalar { Scalar(self.0 * rhs.0) } + fn __truediv__(&self, rhs: Scalar) -> PyResult { + if rhs.is_zero() { + let message = "Cannot divide by zero"; + return Err(exceptions::PyZeroDivisionError::new_err(message)); + } + Ok(Scalar(self.0 / rhs.0)) + } fn __neg__(&self) -> Scalar { Scalar(-self.0) } @@ -209,7 +222,22 @@ impl Scalar { )), } } + fn __int__(&self) -> BigUint { + // Bug, Fr::to_string will print nothing if the value is zero + BigUint::from_str(&*self.0.to_string()).unwrap_or(BigUint::ZERO) + } + fn pow(&self, exp: Scalar) -> PyResult { + let bls_modulus = BigUint::from_str(BLS_MODULUS).unwrap(); + let base_bigint = BigUint::from_bytes_le(self.to_le_bytes()?.as_slice()); + let exp_bigint = BigUint::from_bytes_le(exp.to_le_bytes()?.as_slice()); + let result = base_bigint.modpow(&exp_bigint, &bls_modulus); + Ok(Scalar( + ark_bls12_381::Fr::from_str(&*result.to_string()).map_err(|_| { + exceptions::PyValueError::new_err("Failed to convert result to scalar") + })?, + )) + } fn square(&self) -> Scalar { use ark_ff::fields::Field; Scalar(self.0.square())