Skip to content

Commit

Permalink
Crandall primes (#445)
Browse files Browse the repository at this point in the history
* feat(special primes accel): Support Crandall primes / Pseudo-Mersenne Prime fast reduction - closes #11

* feat(special primes accel): refactoring: p-1 support ompiles on 64-bit, renaming of lazy reduction both in Montgomery and Crandall to lazyReduction

* feat(special primes accel): support 32-bit

* chore: lazyReduction->lazyReduce

* fix: fp mulsquare test

* feat: Crandall exponentiation

* feat: initial commit assembly for Crandall reduction, passing secp256k1, failing edwards25519

* feat(asm-crandall): actually use the assembly

* feat(asm-crandall): fix sqrt test and short immediate

* feat(crandall reduction): x86-adx reduction

* feat(crandall reduction): add final reduce, deactivate adx partial reduce temporarily

* feat(crandall reduction): fix adx partial reduce

* feat(crandall reduction): prevent asm for mul on 32-bit

* feat(bench): check overhead of field calls

* feat(crandall reduction): prevent asm for mul on 32-bit reloaded

* crandall-primes: reactivate tests for secp256k1
  • Loading branch information
mratsim authored Dec 3, 2024
1 parent 68a6cbb commit 585f803
Show file tree
Hide file tree
Showing 51 changed files with 2,080 additions and 498 deletions.
3 changes: 1 addition & 2 deletions PLANNING.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ Other tracks are stretch goals, contributions towards them are accepted.

- ARM assembly
- Finish Nvidia GPU codegenerator up to MSM
- Implement a backend for prime moduli of special form with fast reduction
that don't need Montgomery form
- Implement a backend for Solinas prime like P256
- Implement an unsaturated finite fields backend for Risc-V, WASM, WebGPU, AMD GPU, Apple Metal, Vulkan, ...
- ideally in LLVM IR so that pristine Risc-V assembly can be generated
and used in zkVMs without any risk of C stdlib or syscalls being used
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/bench_elliptic_parallel_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export bench_elliptic_template
#
# ############################################################

proc multiAddParallelBench*(EC: typedesc, numInputs: int, iters: int) =
proc multiAddParallelBench*(EC: typedesc, numInputs: int, iters: int) {.noinline.} =
var points = newSeq[EC_ShortW_Aff[EC.F, EC.G]](numInputs)

for i in 0 ..< numInputs:
Expand All @@ -59,7 +59,7 @@ type BenchMsmContext*[EC] = object
coefs: seq[getBigInt(EC.getName(), kScalarField)]
points: seq[affine(EC)]

proc createBenchMsmContext*(EC: typedesc, inputSizes: openArray[int]): BenchMsmContext[EC] =
proc createBenchMsmContext*(EC: typedesc, inputSizes: openArray[int]): BenchMsmContext[EC] {.noinline.} =
result.tp = Threadpool.new()
let maxNumInputs = inputSizes.max()

Expand Down
36 changes: 18 additions & 18 deletions benchmarks/bench_elliptic_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func `+=`[F; G: static Subgroup](P: var EC_ShortW_JacExt[F, G], Q: EC_ShortW_Jac
func `+=`[F; G: static Subgroup](P: var EC_ShortW_JacExt[F, G], Q: EC_ShortW_Aff[F, G]) {.inline.}=
P.mixedSum_vartime(P, Q)

proc addBench*(EC: typedesc, iters: int) =
proc addBench*(EC: typedesc, iters: int) {.noinline.} =
var r {.noInit.}: EC
let P = rng.random_unsafe(EC)
let Q = rng.random_unsafe(EC)
Expand All @@ -88,7 +88,7 @@ proc addBench*(EC: typedesc, iters: int) =
bench("EC Add vartime " & $EC.G, EC, iters):
r.sum_vartime(P, Q)

proc mixedAddBench*(EC: typedesc, iters: int) =
proc mixedAddBench*(EC: typedesc, iters: int) {.noinline.} =
var r {.noInit.}: EC
let P = rng.random_unsafe(EC)
let Q = rng.random_unsafe(EC)
Expand All @@ -106,25 +106,25 @@ proc mixedAddBench*(EC: typedesc, iters: int) =
bench("EC Mixed Addition vartime " & $EC.G, EC, iters):
r.mixedSum_vartime(P, Qaff)

proc doublingBench*(EC: typedesc, iters: int) =
proc doublingBench*(EC: typedesc, iters: int) {.noinline.} =
var r {.noInit.}: EC
let P = rng.random_unsafe(EC)
bench("EC Double " & $EC.G, EC, iters):
r.double(P)

proc affFromProjBench*(EC: typedesc, iters: int) =
proc affFromProjBench*(EC: typedesc, iters: int) {.noinline.} =
var r {.noInit.}: EC_ShortW_Aff[EC.F, EC.G]
let P = rng.random_unsafe(EC)
bench("EC Projective to Affine " & $EC.G, EC, iters):
r.affine(P)

proc affFromJacBench*(EC: typedesc, iters: int) =
proc affFromJacBench*(EC: typedesc, iters: int) {.noinline.} =
var r {.noInit.}: EC_ShortW_Aff[EC.F, EC.G]
let P = rng.random_unsafe(EC)
bench("EC Jacobian to Affine " & $EC.G, EC, iters):
r.affine(P)

proc affFromProjBatchBench*(EC: typedesc, numPoints: int, useBatching: bool, iters: int) =
proc affFromProjBatchBench*(EC: typedesc, numPoints: int, useBatching: bool, iters: int) {.noinline.} =
var r = newSeq[affine(EC)](numPoints)
var points = newSeq[EC](numPoints)

Expand All @@ -139,7 +139,7 @@ proc affFromProjBatchBench*(EC: typedesc, numPoints: int, useBatching: bool, ite
for i in 0 ..< numPoints:
r[i].affine(points[i])

proc affFromJacBatchBench*(EC: typedesc, numPoints: int, useBatching: bool, iters: int) =
proc affFromJacBatchBench*(EC: typedesc, numPoints: int, useBatching: bool, iters: int) {.noinline.} =
var r = newSeq[affine(EC)](numPoints)
var points = newSeq[EC](numPoints)

Expand All @@ -154,7 +154,7 @@ proc affFromJacBatchBench*(EC: typedesc, numPoints: int, useBatching: bool, iter
for i in 0 ..< numPoints:
r[i].affine(points[i])

proc scalarMulGenericBench*(EC: typedesc, bits, window: static int, iters: int) =
proc scalarMulGenericBench*(EC: typedesc, bits, window: static int, iters: int) {.noinline.} =
var r {.noInit.}: EC
var P = rng.random_unsafe(EC)
P.clearCofactor()
Expand All @@ -165,7 +165,7 @@ proc scalarMulGenericBench*(EC: typedesc, bits, window: static int, iters: int)
r = P
r.scalarMulGeneric(exponent, window)

proc scalarMulEndo*(EC: typedesc, bits: static int, iters: int) =
proc scalarMulEndo*(EC: typedesc, bits: static int, iters: int) {.noinline.} =
var r {.noInit.}: EC
var P = rng.random_unsafe(EC)
P.clearCofactor()
Expand All @@ -176,7 +176,7 @@ proc scalarMulEndo*(EC: typedesc, bits: static int, iters: int) =
r = P
r.scalarMulEndo(exponent)

proc scalarMulEndoWindow*(EC: typedesc, bits: static int, iters: int) =
proc scalarMulEndoWindow*(EC: typedesc, bits: static int, iters: int) {.noinline.} =
var r {.noInit.}: EC
var P = rng.random_unsafe(EC)
P.clearCofactor()
Expand All @@ -190,7 +190,7 @@ proc scalarMulEndoWindow*(EC: typedesc, bits: static int, iters: int) =
else:
{.error: "Not implemented".}

proc scalarMulVartimeDoubleAddBench*(EC: typedesc, bits: static int, iters: int) =
proc scalarMulVartimeDoubleAddBench*(EC: typedesc, bits: static int, iters: int) {.noinline.} =
var r {.noInit.}: EC
var P = rng.random_unsafe(EC)
P.clearCofactor()
Expand All @@ -201,7 +201,7 @@ proc scalarMulVartimeDoubleAddBench*(EC: typedesc, bits: static int, iters: int)
r = P
r.scalarMul_doubleAdd_vartime(exponent)

proc scalarMulVartimeMinHammingWeightRecodingBench*(EC: typedesc, bits: static int, iters: int) =
proc scalarMulVartimeMinHammingWeightRecodingBench*(EC: typedesc, bits: static int, iters: int) {.noinline.} =
var r {.noInit.}: EC
var P = rng.random_unsafe(EC)
P.clearCofactor()
Expand All @@ -212,7 +212,7 @@ proc scalarMulVartimeMinHammingWeightRecodingBench*(EC: typedesc, bits: static i
r = P
r.scalarMul_jy00_vartime(exponent)

proc scalarMulVartimeWNAFBench*(EC: typedesc, bits, window: static int, iters: int) =
proc scalarMulVartimeWNAFBench*(EC: typedesc, bits, window: static int, iters: int) {.noinline.} =
var r {.noInit.}: EC
var P = rng.random_unsafe(EC)
P.clearCofactor()
Expand All @@ -223,7 +223,7 @@ proc scalarMulVartimeWNAFBench*(EC: typedesc, bits, window: static int, iters: i
r = P
r.scalarMul_wNAF_vartime(exponent, window)

proc scalarMulVartimeEndoWNAFBench*(EC: typedesc, bits, window: static int, iters: int) =
proc scalarMulVartimeEndoWNAFBench*(EC: typedesc, bits, window: static int, iters: int) {.noinline.} =
var r {.noInit.}: EC
var P = rng.random_unsafe(EC)
P.clearCofactor()
Expand All @@ -234,14 +234,14 @@ proc scalarMulVartimeEndoWNAFBench*(EC: typedesc, bits, window: static int, iter
r = P
r.scalarMulEndo_wNAF_vartime(exponent, window)

proc subgroupCheckBench*(EC: typedesc, iters: int) =
proc subgroupCheckBench*(EC: typedesc, iters: int) {.noinline.} =
var P = rng.random_unsafe(EC)
P.clearCofactor()

bench("Subgroup check", EC, iters):
discard P.isInSubgroup()

proc subgroupCheckScalarMulVartimeEndoWNAFBench*(EC: typedesc, bits, window: static int, iters: int) =
proc subgroupCheckScalarMulVartimeEndoWNAFBench*(EC: typedesc, bits, window: static int, iters: int) {.noinline.} =
var r {.noInit.}: EC
var P = rng.random_unsafe(EC)
P.clearCofactor()
Expand All @@ -253,7 +253,7 @@ proc subgroupCheckScalarMulVartimeEndoWNAFBench*(EC: typedesc, bits, window: sta
discard r.isInSubgroup()
r.scalarMulEndo_wNAF_vartime(exponent, window)

proc multiAddBench*(EC: typedesc, numPoints: int, useBatching: bool, iters: int) =
proc multiAddBench*(EC: typedesc, numPoints: int, useBatching: bool, iters: int) {.noinline.} =
var points = newSeq[EC_ShortW_Aff[EC.F, EC.G]](numPoints)

for i in 0 ..< numPoints:
Expand All @@ -271,7 +271,7 @@ proc multiAddBench*(EC: typedesc, numPoints: int, useBatching: bool, iters: int)
r += points[i]


proc msmBench*(EC: typedesc, numPoints: int, iters: int) =
proc msmBench*(EC: typedesc, numPoints: int, iters: int) {.noinline.} =
const bits = EC.getScalarField().bits()
var points = newSeq[EC_ShortW_Aff[EC.F, EC.G]](numPoints)
var scalars = newSeq[BigInt[bits]](numPoints)
Expand Down
53 changes: 31 additions & 22 deletions benchmarks/bench_fields_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -65,75 +65,84 @@ func random_unsafe(rng: var RngState, a: var ExtensionField2x) =
for i in 0 ..< a.coords.len:
rng.random_unsafe(a.coords[i])

proc addBench*(T: typedesc, iters: int) =
proc addBench*(T: typedesc, iters: int) {.noinline.} =
var x = rng.random_unsafe(T)
let y = rng.random_unsafe(T)
bench("Addition", T, iters):
x += y

proc subBench*(T: typedesc, iters: int) =
proc add10Bench*(T: typedesc, iters: int) {.noinline.} =
var xs: array[10, T]
for x in xs.mitems():
x = rng.random_unsafe(T)
let y = rng.random_unsafe(T)
bench("Additions (10)", T, iters):
staticFor i, 0, 10:
xs[i] += y

proc subBench*(T: typedesc, iters: int) {.noinline.} =
var x = rng.random_unsafe(T)
let y = rng.random_unsafe(T)
preventOptimAway(x)
bench("Substraction", T, iters):
x -= y

proc negBench*(T: typedesc, iters: int) =
proc negBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
let x = rng.random_unsafe(T)
bench("Negation", T, iters):
r.neg(x)

proc ccopyBench*(T: typedesc, iters: int) =
proc ccopyBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
let x = rng.random_unsafe(T)
bench("Conditional Copy", T, iters):
r.ccopy(x, CtFalse)

proc div2Bench*(T: typedesc, iters: int) =
proc div2Bench*(T: typedesc, iters: int) {.noinline.} =
var x = rng.random_unsafe(T)
bench("Division by 2", T, iters):
x.div2()

proc mulBench*(T: typedesc, iters: int) =
proc mulBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
let x = rng.random_unsafe(T)
let y = rng.random_unsafe(T)
preventOptimAway(r)
bench("Multiplication", T, iters):
r.prod(x, y)

proc sqrBench*(T: typedesc, iters: int) =
proc sqrBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
let x = rng.random_unsafe(T)
preventOptimAway(r)
bench("Squaring", T, iters):
r.square(x)

proc mul2xUnrBench*(T: typedesc, iters: int) =
proc mul2xUnrBench*(T: typedesc, iters: int) {.noinline.} =
var r: doublePrec(T)
let x = rng.random_unsafe(T)
let y = rng.random_unsafe(T)
preventOptimAway(r)
bench("Multiplication 2x unreduced", T, iters):
r.prod2x(x, y)

proc sqr2xUnrBench*(T: typedesc, iters: int) =
proc sqr2xUnrBench*(T: typedesc, iters: int) {.noinline.} =
var r: doublePrec(T)
let x = rng.random_unsafe(T)
preventOptimAway(r)
bench("Squaring 2x unreduced", T, iters):
r.square2x(x)

proc rdc2xBench*(T: typedesc, iters: int) =
proc rdc2xBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
var t: doublePrec(T)
rng.random_unsafe(t)
preventOptimAway(r)
bench("Redc 2x", T, iters):
r.redc2x(t)

proc sumprodBench*(T: typedesc, iters: int) =
proc sumprodBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
let a = rng.random_unsafe(T)
let b = rng.random_unsafe(T)
Expand All @@ -143,40 +152,40 @@ proc sumprodBench*(T: typedesc, iters: int) =
bench("Linear combination", T, iters):
r.sumprod([a, b], [u, v])

proc toBigBench*(T: typedesc, iters: int) =
proc toBigBench*(T: typedesc, iters: int) {.noinline.} =
var r: T.getBigInt()
let x = rng.random_unsafe(T)
preventOptimAway(r)
bench("BigInt <- field conversion", T, iters):
r.fromField(x)

proc toFieldBench*(T: typedesc, iters: int) =
proc toFieldBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
let x = rng.random_unsafe(T.getBigInt())
preventOptimAway(r)
bench("BigInt -> field conversion", T, iters):
r.fromBig(x)

proc invBench*(T: typedesc, iters: int) =
proc invBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
let x = rng.random_unsafe(T)
preventOptimAway(r)
bench("Inversion (constant-time)", T, iters):
r.inv(x)

proc invVartimeBench*(T: typedesc, iters: int) =
proc invVartimeBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
let x = rng.random_unsafe(T)
preventOptimAway(r)
bench("Inversion (variable-time)", T, iters):
r.inv_vartime(x)

proc isSquareBench*(T: typedesc, iters: int) =
proc isSquareBench*(T: typedesc, iters: int) {.noinline.} =
let x = rng.random_unsafe(T)
bench("isSquare (constant-time)", T, iters):
let qrt = x.isSquare()

proc sqrtBench*(T: typedesc, iters: int) =
proc sqrtBench*(T: typedesc, iters: int) {.noinline.} =
let x = rng.random_unsafe(T)

const algoType = block:
Expand All @@ -196,14 +205,14 @@ proc sqrtBench*(T: typedesc, iters: int) =
var r = x
discard r.sqrt_if_square()

proc sqrtRatioBench*(T: typedesc, iters: int) =
proc sqrtRatioBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
let u = rng.random_unsafe(T)
let v = rng.random_unsafe(T)
bench("Fused SquareRoot+Division+isSquare sqrt(u/v)", T, iters):
let isSquare = r.sqrt_ratio_if_square(u, v)

proc sqrtVartimeBench*(T: typedesc, iters: int) =
proc sqrtVartimeBench*(T: typedesc, iters: int) {.noinline.} =
let x = rng.random_unsafe(T)

const algoType = block:
Expand All @@ -223,21 +232,21 @@ proc sqrtVartimeBench*(T: typedesc, iters: int) =
var r = x
discard r.sqrt_if_square_vartime()

proc sqrtRatioVartimeBench*(T: typedesc, iters: int) =
proc sqrtRatioVartimeBench*(T: typedesc, iters: int) {.noinline.} =
var r: T
let u = rng.random_unsafe(T)
let v = rng.random_unsafe(T)
bench("Fused SquareRoot+Division+isSquare sqrt_vartime(u/v)", T, iters):
let isSquare = r.sqrt_ratio_if_square_vartime(u, v)

proc powBench*(T: typedesc, iters: int) =
proc powBench*(T: typedesc, iters: int) {.noinline.} =
let x = rng.random_unsafe(T)
let exponent = rng.random_unsafe(BigInt[Fr[T.Name].bits()])
var r = x
bench("Exp curve order (constant-time) - " & $exponent.bits & "-bit", T, iters):
r.pow(exponent)

proc powVartimeBench*(T: typedesc, iters: int) =
proc powVartimeBench*(T: typedesc, iters: int) {.noinline.} =
let x = rng.random_unsafe(T)
let exponent = rng.random_unsafe(BigInt[Fr[T.Name].bits()])
var r = x
Expand Down
6 changes: 4 additions & 2 deletions benchmarks/bench_fp.nim
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ proc main() =
staticFor i, 0, AvailableCurves.len:
const curve = AvailableCurves[i]
addBench(Fp[curve], Iters)
add10Bench(Fp[curve], Iters)
subBench(Fp[curve], Iters)
negBench(Fp[curve], Iters)
ccopyBench(Fp[curve], Iters)
Expand All @@ -55,8 +56,9 @@ proc main() =
sqr2xUnrBench(Fp[curve], Iters)
rdc2xBench(Fp[curve], Iters)
smallSeparator()
sumprodBench(Fp[curve], Iters)
smallSeparator()
when not Fp[curve].isCrandallPrimeField():
sumprodBench(Fp[curve], Iters)
smallSeparator()
toBigBench(Fp[curve], Iters)
toFieldBench(Fp[curve], Iters)
smallSeparator()
Expand Down
Loading

0 comments on commit 585f803

Please sign in to comment.