Skip to content

Commit

Permalink
Merge pull request #4 from jtraglia/allow-big-ints
Browse files Browse the repository at this point in the history
Change Scalar::new input from u128 to BigUint
  • Loading branch information
kevaundray authored Aug 30, 2024
2 parents c78c9d8 + 7f467a9 commit 7825c7a
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 8 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ name = "py_arkworks_bls12381"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.18.0", features = ["extension-module"] }
pyo3 = { version = "0.22.2", features = ["extension-module", "num-bigint"] }
ark-bls12-381 = "0.4.0"
ark-serialize = "0.4.0"
ark-ec = "0.4.0"
ark-ff = "0.4.0"
rayon = "1.6.1"
hex = "0.4.3"
num-traits = "0.2.15"
num-bigint = "0.4.6"

[features]
default = ["parallel", "asm"]
Expand Down
24 changes: 22 additions & 2 deletions examples/scalar.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,44 @@
from py_arkworks_bls12381 import Scalar

BLS_MODULUS = 52435875175126190479447740508185965837690552500527637822603658699938581184513

# Initialisation - The default initialiser for a scalar is an u128 integer
scalar = Scalar(12345)

# It should be possible to instantiate BLS_MODULUS - 1
max_value = Scalar(BLS_MODULUS - 1)
assert max_value + Scalar(2) == Scalar(1)

# Equality -- We override eq and neq operators
assert scalar == scalar
assert Scalar(1234) != Scalar(4567)

# Scalar Addition/subtraction/Negation -- We override the add/sub/neg operators
# Scalar arithmetic -- We override the mul/div/add/sub/neg operators
a = Scalar(3)
b = Scalar(4)
c = Scalar(5)
assert a.square() + b.square() == c.square()
assert a * a + b * b == c * c

assert Scalar(12) / Scalar(3) == Scalar(4)

try:
assert Scalar(12) / Scalar(0)
assert False
except ZeroDivisionError:
pass

exp = Scalar(0xffff_ffff_ffff_fff)
assert int(Scalar(2).pow(exp)) == pow(2, int(exp), BLS_MODULUS)

neg_a = -a
assert a + neg_a == Scalar(0)
assert (a + neg_a).is_zero()

# Serialisation
compressed_bytes = scalar.to_le_bytes()
deserialised_scalar = Scalar.from_le_bytes(compressed_bytes)
assert scalar == deserialised_scalar
assert scalar == deserialised_scalar

# Conversion to int
assert int(Scalar(12345)) == 12345
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use wrapper::{G1Point, G2Point, Scalar, GT};

/// A Python module implemented in Rust.
#[pymodule]
fn py_arkworks_bls12381(_py: Python, m: &PyModule) -> PyResult<()> {
fn py_arkworks_bls12381(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<G1Point>()?;
m.add_class::<G2Point>()?;
m.add_class::<GT>()?;
Expand Down
36 changes: 32 additions & 4 deletions src/wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::str::FromStr;
use ark_bls12_381::{G1Affine, G1Projective, G2Affine, G2Projective};
use ark_ec::pairing::{Pairing, PairingOutput};
use ark_ec::{AffineRepr, Group, ScalarMul, VariableBaseMSM};
Expand All @@ -6,9 +7,12 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError
use num_traits::identities::Zero;
use pyo3::{exceptions, pyclass, pymethods, PyErr, PyResult, Python};
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use num_bigint::BigUint;

const G1_COMPRESSED_SIZE: usize = 48;
const G2_COMPRESSED_SIZE: usize = 96;
const SCALAR_SIZE: usize = 32;
const BLS_MODULUS: &str = "52435875175126190479447740508185965837690552500527637822603658699938581184513";

#[derive(Copy, Clone)]
#[pyclass]
Expand Down Expand Up @@ -80,7 +84,7 @@ impl G1Point {
py.allow_threads(|| {
let points: Vec<_> = points.into_par_iter().map(|point| point.0).collect();
let scalars: Vec<_> = scalars.into_par_iter().map(|scalar| scalar.0).collect();

// Convert the points to affine.
// TODO: we could have a G1AffinePoint struct and then a G1ProjectivePoint
// TODO struct, so that this cost is explicit
Expand Down Expand Up @@ -162,7 +166,7 @@ impl G2Point {
py.allow_threads(|| {
let points: Vec<_> = points.into_iter().map(|point| point.0).collect();
let scalars: Vec<_> = scalars.into_iter().map(|scalar| scalar.0).collect();

// Convert the points to affine.
// TODO: we could have a G2AffinePoint struct and then a G2ProjectivePoint
// TODO struct, so that this cost is explicit
Expand All @@ -180,8 +184,10 @@ pub struct Scalar(ark_bls12_381::Fr);
#[pymethods]
impl Scalar {
#[new]
fn new(integer: u128) -> Self {
Scalar(ark_bls12_381::Fr::from(integer))
fn new(integer: BigUint) -> PyResult<Self> {
let fr = ark_bls12_381::Fr::from_str(&*integer.to_string())
.map_err(|_| exceptions::PyValueError::new_err("Value is greater than BLS_MODULUS"))?;
Ok(Scalar(fr))
}

// Overriding operators
Expand All @@ -194,6 +200,13 @@ impl Scalar {
fn __mul__(&self, rhs: Scalar) -> Scalar {
Scalar(self.0 * rhs.0)
}
fn __truediv__(&self, rhs: Scalar) -> PyResult<Scalar> {
if rhs.is_zero() {
let message = "Cannot divide by zero";
return Err(exceptions::PyZeroDivisionError::new_err(message));
}
Ok(Scalar(self.0 / rhs.0))
}
fn __neg__(&self) -> Scalar {
Scalar(-self.0)
}
Expand All @@ -209,7 +222,22 @@ impl Scalar {
)),
}
}
fn __int__(&self) -> BigUint {
// Bug, Fr::to_string will print nothing if the value is zero
BigUint::from_str(&*self.0.to_string()).unwrap_or(BigUint::ZERO)
}

fn pow(&self, exp: Scalar) -> PyResult<Scalar> {
let bls_modulus = BigUint::from_str(BLS_MODULUS).unwrap();
let base_bigint = BigUint::from_bytes_le(self.to_le_bytes()?.as_slice());
let exp_bigint = BigUint::from_bytes_le(exp.to_le_bytes()?.as_slice());
let result = base_bigint.modpow(&exp_bigint, &bls_modulus);
Ok(Scalar(
ark_bls12_381::Fr::from_str(&*result.to_string()).map_err(|_| {
exceptions::PyValueError::new_err("Failed to convert result to scalar")
})?,
))
}
fn square(&self) -> Scalar {
use ark_ff::fields::Field;
Scalar(self.0.square())
Expand Down

0 comments on commit 7825c7a

Please sign in to comment.