llvm-arm64: stash experiments to try to improve stack usage
mratsim committed Jan 17, 2025
1 parent 55c6954 commit d1e6001
Showing 12 changed files with 565 additions and 222 deletions.
@@ -78,6 +78,10 @@ macro ccopy_gen[N: static int](a_PIR: var Limbs[N], b_PIR: Limbs[N], ctl: Secret
# Codegen
result.add ctx.generate()

debugEcho "======Transfo====="
debugEcho getImplTransformed(result).repr()
debugEcho "======"

func ccopy_asm*(a: var Limbs, b: Limbs, ctl: SecretBool) =
## Constant-time conditional copy
## If ctl is true: b is copied into a
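
A conditional copy must not branch on the secret `ctl`; the selection is done with a full-word mask. A portable sketch of the semantics, assuming 4×64-bit limbs and modeling `SecretBool` as a plain 0/1 word (illustrative names, not Constantine's API; later sketches on this page reuse the `Limbs4` stand-in):

```nim
type Limbs4 = array[4, uint64]   # stand-in for Limbs[N], assumed for these sketches

proc ccopyRef(a: var Limbs4, b: Limbs4, ctl: uint64) =
  ## If ctl == 1 copy b into a, else leave a unchanged,
  ## without any secret-dependent branch.
  let mask = 0'u64 - ctl         # all-ones when ctl == 1, zero otherwise
  for i in 0 ..< 4:
    a[i] = (b[i] and mask) or (a[i] and not mask)
```
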
@@ -236,6 +236,9 @@ macro mulMont_CIOS_sparebit_gen[N: static int](
ctx.str t[i], r[i]

result.add ctx.generate()
debugEcho "======Transfo====="
debugEcho getImplTransformed(result).repr()
debugEcho "======"

func mulMont_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, lazyReduce: static bool = false) =
## Constant-time Montgomery multiplication
184 changes: 157 additions & 27 deletions constantine/math_compiler/impl_fields_isa_arm64.nim
@@ -7,8 +7,9 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
constantine/platforms/llvm/llvm,
./ir
constantine/platforms/llvm/[llvm, super_instructions],
./ir,
./impl_fields_globals

import
constantine/platforms/llvm/asm_arm64
@@ -31,7 +32,7 @@ import

const SectionName = "ctt,fields"

proc finalSubMayOverflow_arm64(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) =
proc finalSubMayOverflow_arm64(asy: Assembler_LLVM, fd: FieldDescriptor, r: var Array, a, M: Array) =
## If a >= Modulus: r <- a-M
## else: r <- a

@@ -40,7 +41,7 @@ proc finalSubMayOverflow_arm64(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M
# due to LLVM adding extra instructions (from 1, 2 to 33% or 66% more): https://github.com/mratsim/constantine/issues/357

let N = fd.numWords
let t = asy.makeArray(fd.fieldTy)
var t = asy.makeArray(fd.fieldTy)

# Contains 0x0001 (if overflowed limbs) or 0x0000
let overflowedLimbs = asy.br.arm64_add_ci(0'u32, 0'u32)
@@ -58,6 +59,33 @@
for i in 0 ..< N:
r[i] = asy.br.arm64_csel_cc(a[i], t[i])
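
Here `arm64_add_ci(0'u32, 0'u32)` materializes the carry flag left by the limb addition into a 0/1 word, and `arm64_csel_cc` selects on the borrow of the trial subtraction. A portable model of the whole selection, reusing the `Limbs4` stand-in from the sketch above (`finalSubMayOverflowRef` is an illustrative name, not Constantine's API):

```nim
proc finalSubMayOverflowRef(a, M: Limbs4, carryIn: uint64): Limbs4 =
  ## r <- a - M if the addition overflowed the limbs or a >= M, else a.
  ## `carryIn` plays the role of `overflowedLimbs` above.
  var t: Limbs4
  var borrow = 0'u64
  for i in 0 ..< 4:                     # t = a - M, propagating the borrow
    let d1 = a[i] - M[i]
    let b1 = uint64(ord(a[i] < M[i]))
    t[i] = d1 - borrow
    let b2 = uint64(ord(d1 < borrow))
    borrow = b1 or b2                   # at most one of b1, b2 is set
  # Keep `a` only when the limbs did not overflow AND a < M
  # (the `overflowedLimbs - borrow` underflow test in the code above).
  let keepA = (1'u64 - carryIn) and borrow
  let mask = 0'u64 - keepA              # all-ones when keeping a
  for i in 0 ..< 4:
    result[i] = (a[i] and mask) or (t[i] and not mask)
```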

proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r: var Array, a, M: Array) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
## This is constant-time straightline code.
## Due to warp divergence, the overhead of doing a comparison with short-circuiting might not be worth it on GPU.
##
## To be used when the modulus does not use the full bitwidth of the storing words
## (say using 255 bits for the modulus out of 256 available in words)

# We use word-level arithmetic instead of llvm_sub_overflow.u256 or llvm_sub_overflow.u384
# due to LLVM adding extra instructions (from 1, 2 to 33% or 66% more): https://github.com/mratsim/constantine/issues/357

var t = asy.makeArray(fd.fieldTy)

# Now subtract the modulus, and test a < M
# (underflow) with the last borrow
var B = fd.zero_i1
for i in 0 ..< fd.numWords:
(B, t[i]) = asy.br.subborrow(a[i], M[i], B)

# If it underflows here, it means that it was
# smaller than the modulus and we don't need `a-M`
for i in 0 ..< fd.numWords:
t[i] = asy.br.select(B, a[i], t[i])
asy.store(r, t)
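
The variant added above runs the same borrow chain but selects on the borrow alone, since there is no limb carry to fold in. In the same portable model, a minimal sketch:

```nim
proc finalSubNoOverflowRef(a, M: Limbs4): Limbs4 =
  ## r <- a - M if a >= M, else a (valid when M leaves a spare bit).
  finalSubMayOverflowRef(a, M, carryIn = 0'u64)
```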

proc modadd_sat_fullbits_arm64*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) =
## Generate an optimized modular addition kernel
## with parameters `a, b, modulus: Limbs -> Limbs`
Expand All @@ -75,11 +103,11 @@ proc modadd_sat_fullbits_arm64*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a,
let (rr, aa, bb, MM) = llvmParams

# Pointers are opaque in LLVM now
let r = asy.asArray(rr, fd.fieldTy)
var r = asy.asArray(rr, fd.fieldTy)
let a = asy.asArray(aa, fd.fieldTy)
let b = asy.asArray(bb, fd.fieldTy)
let M = asy.asArray(MM, fd.fieldTy)
let apb = asy.makeArray(fd.fieldTy)
var apb = asy.makeArray(fd.fieldTy)

apb[0] = asy.br.arm64_add_co(a[0], b[0])
for i in 1 ..< fd.numWords:
@@ -91,24 +91,126 @@ proc modadd_sat_fullbits_arm64*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a,

asy.callFn(name, [r, a, b, M])
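
The addition kernel chains add-with-carry across the limbs, then reduces with the may-overflow final subtraction (the tail of the function sits outside the visible hunk). A portable sketch under the same `Limbs4` model, assuming a, b < M:

```nim
proc modaddRef(a, b, M: Limbs4): Limbs4 =
  ## (a + b) mod M for a, b < M; portable model of the kernel above.
  var apb: Limbs4
  var carry = 0'u64
  for i in 0 ..< 4:                    # apb = a + b with carry propagation
    let s1 = a[i] + b[i]
    let c1 = uint64(ord(s1 < a[i]))
    apb[i] = s1 + carry
    let c2 = uint64(ord(apb[i] < s1))
    carry = c1 or c2                   # at most one can be set
  finalSubMayOverflowRef(apb, M, carry)
```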

proc mtymul_sat_CIOS_sparebit_mulhi_arm64(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef, finalReduce: bool) =
## Generate an optimized modular multiplication kernel
## with parameters `a, b, modulus: Limbs -> Limbs`
##
## Specialization for ARM64
## While the computing instruction count is the same between generic and optimized assembly,
## there are significantly more loads/stores and higher stack usage:
## On 6 limbs (CodeGenLevelDefault):
## - 64 bytes stack vs 368
## - 4 stp vs 23
## - 10 ldp vs 35
## - 6 ldr vs 61
## - 6 str vs 43
## - 6 mov vs 24
## - 78 mul vs 78
## - 72 umulh vs 72
## - 17 adds vs 17
## - 103 adcs vs 103
## - 23 adc vs 12
## - 6 cmn vs 6
## - 0 cset vs 11
# template mulloadd_co(ctx, lhs, rhs, addend): ValueRef =
# let t = ctx.mul(lhs, rhs)
# ctx.arm64_add_co(addend, t)
# template mulloadd_cio(ctx, lhs, rhs, addend): ValueRef =
# let t = ctx.mul(lhs, rhs)
# ctx.arm64_add_cio(addend, t)

# template mulhiadd_co(ctx, lhs, rhs, addend): ValueRef =
# let t = ctx.mulhi(lhs, rhs)
# ctx.arm64_add_co(addend, t)
# template mulhiadd_cio(ctx, lhs, rhs, addend): ValueRef =
# let t = ctx.mulhi(lhs, rhs)
# ctx.arm64_add_cio(addend, t)
# template mulhiadd_ci(ctx, lhs, rhs, addend): ValueRef =
# let t = ctx.mulhi(lhs, rhs)
# ctx.arm64_add_ci(addend, t)
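
These commented helpers pair `mul` (the low word of the product) with `mulhi` (ARM64 `umulh`, the high word) and fold the product into an accumulator through the carry chain. Their arithmetic meaning, sketched on 32-bit words so the high half fits in a uint64 (the real code uses 64-bit words):

```nim
func mullo32(x, y: uint32): uint32 =
  uint32((uint64(x) * uint64(y)) and 0xFFFF_FFFF'u64)  # what mul keeps
func mulhi32(x, y: uint32): uint32 =
  uint32((uint64(x) * uint64(y)) shr 32)               # umulh analogue
```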

# proc mtymul_sat_CIOS_sparebit_arm64*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef, finalReduce: bool) =
# ## Generate an optimized modular multiplication kernel
# ## with parameters `a, b, modulus: Limbs -> Limbs`
# ##
# ## Specialization for ARM64
# ## While the computing instruction count is the same between generic and optimized assembly,
# ## there are significantly more loads/stores and higher stack usage:
# ## On 6 limbs (CodeGenLevelDefault):
# ## - 64 bytes stack vs 368
# ## - 4 stp vs 23
# ## - 10 ldp vs 35
# ## - 6 ldr vs 61
# ## - 6 str vs 43
# ## - 6 mov vs 24
# ## - 78 mul vs 78
# ## - 72 umulh vs 72
# ## - 17 adds vs 17
# ## - 103 adcs vs 103
# ## - 23 adc vs 12
# ## - 6 cmn vs 6
# ## - 0 cset vs 11

# let name =
# if not finalReduce and fd.spareBits >= 2:
# "_mty_mulur.u" & $fd.w & "x" & $fd.numWords & "b2"
# else:
# doAssert fd.spareBits >= 1
# "_mty_mul.u" & $fd.w & "x" & $fd.numWords & "b1"

# asy.llvmInternalFnDef(
# name, SectionName,
# asy.void_t, toTypes([r, a, b, M]) & fd.wordTy,
# {kHot}):

# tagParameter(1, "sret")

# let (rr, aa, bb, MM, m0ninv) = llvmParams

# # Pointers are opaque in LLVM now
# let r = asy.asArray(rr, fd.fieldTy)
# let b = asy.asArray(bb, fd.fieldTy)

# # Explicitly allocate on the stack
# # the local variable.
# # Unfortunately despite optimization passes
# # stack usage is 5.75× higher than with manual register allocation otherwise
# # so we help the compiler with register lifetimes
# # and imitate C local variable declaration/allocation
# let a = asy.toLocalArray(aa, fd.fieldTy, "a")
# let M = asy.toLocalArray(MM, fd.fieldTy, "M")
# let t = asy.makeArray(fd.fieldTy, "t")
# let N = fd.numWords

# let A = asy.localVar(fd.wordTy, "A")
# let bi = asy.localVar(fd.wordTy, "bi")

# doAssert N >= 2
# for i in 0 ..< N:
# # Multiplication
# # -------------------------------
# # for j=0 to N-1
# # (A,t[j]) := t[j] + a[j]*b[i] + A
# bi[] = b[i]
# A[] = fd.zero
# if i == 0:
# for j in 0 ..< N:
# t[j] = asy.br.mul(a[j], bi[])
# else:
# t[0] = asy.br.mulloadd_co(a[0], bi[], t[0])
# for j in 1 ..< N:
# t[j] = asy.br.mulloadd_cio(a[j], bi[], t[j])
# A[] = asy.br.arm64_cset_cs()

# t[1] = asy.br.mulhiadd_co(a[0], bi[], t[1])
# for j in 2 ..< N:
# t[j] = asy.br.mulhiadd_cio(a[j-1], bi[], t[j])
# A[] = asy.br.mulhiadd_ci(a[N-1], bi[], A[])

# # Reduction
# # -------------------------------
# # m := t[0]*m0ninv mod W
# #
# # C,_ := t[0] + m*M[0]
# # for j=1 to N-1
# # (C,t[j-1]) := t[j] + m*M[j] + C
# # t[N-1] = C + A
# let m = asy.br.mul(t[0], m0ninv)
# let u = asy.br.mul(m, M[0])
# discard asy.br.arm64_cmn(t[0], u)
# for j in 1 ..< N:
# t[j-1] = asy.br.mulloadd_cio(m, M[j], t[j])
# t[N-1] = asy.br.arm64_add_ci(A[], fd.zero)

# t[0] = asy.br.mulhiadd_co(m, M[0], t[0])
# for j in 1 ..< N-1:
# t[j] = asy.br.mulhiadd_cio(m, M[j], t[j])
# t[N-1] = asy.br.mulhiadd_ci(m, M[N-1], t[N-1])

# if finalReduce:
# asy.finalSubNoOverflow(fd, t, t, M)

# asy.store(r, t)
# asy.br.retVoid()

# let m0ninv = asy.getM0ninv(fd)
# asy.callFn(name, [r, a, b, M, m0ninv])
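
For reference, the CIOS-with-spare-bit algorithm the commented kernel implements, as a self-contained model on 8×32-bit limbs so products fit in uint64 (the real code uses 64-bit words with mul/umulh; here m0ninv = -M[0]^-1 mod 2^32 and all names are illustrative, not Constantine's API):

```nim
const NW = 8                          # 8 x 32-bit limbs = 256 bits
type LimbsW = array[NW, uint32]

proc mulMontCIOSRef(a, b, M: LimbsW, m0ninv: uint32): LimbsW =
  ## Montgomery product a*b*R^-1 mod M with R = 2^(32*NW),
  ## assuming M has at least one spare bit in its top word.
  var t: LimbsW                       # zero-initialized partial product
  for i in 0 ..< NW:
    # Multiplication: (A, t) := t + a * b[i]
    var C = 0'u64
    for j in 0 ..< NW:
      let s = uint64(t[j]) + uint64(a[j]) * uint64(b[i]) + C
      t[j] = uint32(s and 0xFFFF_FFFF'u64)
      C = s shr 32
    let A = uint32(C)                 # the extra accumulator word
    # Reduction: pick m so that t[0] + m*M[0] == 0 mod 2^32,
    # then fold m*M into t, shifting t one word to the right.
    let m = t[0] * m0ninv             # mod 2^32
    C = (uint64(t[0]) + uint64(m) * uint64(M[0])) shr 32
    for j in 1 ..< NW:
      let s = uint64(t[j]) + uint64(m) * uint64(M[j]) + C
      t[j-1] = uint32(s and 0xFFFF_FFFF'u64)
      C = s shr 32
    t[NW-1] = uint32(C) + A           # cannot overflow thanks to the spare bit
  result = t
```

When `finalReduce` is set, the commented kernel then applies `finalSubNoOverflow` to this result, which matches `finalSubNoOverflowRef` sketched earlier.
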
28 changes: 14 additions & 14 deletions constantine/math_compiler/impl_fields_isa_nvidia.nim
@@ -49,7 +49,7 @@ import

const SectionName = "ctt,fields"

proc finalSubMayOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) =
proc finalSubMayOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r: var Array, a, M: Array) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
@@ -59,7 +59,7 @@
## To be used when the final subtraction can
## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256)
let N = fd.numWords
let t = asy.makeArray(fd.fieldTy)
var t = asy.makeArray(fd.fieldTy)

# Contains 0x0001 (if overflowed limbs) or 0x0000
let overflowedLimbs = asy.br.add_ci(0'u32, 0'u32)
@@ -78,7 +78,7 @@ proc finalSubMayOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Arra
for i in 0 ..< N:
r[i] = asy.br.slct(t[i], a[i], underflowedModulus)

proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) =
proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r: var Array, a, M: Array) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
@@ -88,18 +88,18 @@ proc finalSubNoOverflow(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array
## To be used when the modulus does not use the full bitwidth of the storing words
## (say using 255 bits for the modulus out of 256 available in words)
let N = fd.numWords
let scratch = asy.makeArray(fd.fieldTy)
var t = asy.makeArray(fd.fieldTy)

# Now subtract the modulus, and test a < M with the last borrow
scratch[0] = asy.br.sub_bo(a[0], M[0])
t[0] = asy.br.sub_bo(a[0], M[0])
for i in 1 ..< N:
scratch[i] = asy.br.sub_bio(a[i], M[i])
t[i] = asy.br.sub_bio(a[i], M[i])

# If it underflows here, `a` was smaller than the modulus, which is what we want
let underflowedModulus = asy.br.sub_bi(0'u32, 0'u32)

for i in 0 ..< N:
r[i] = asy.br.slct(scratch[i], a[i], underflowedModulus)
r[i] = asy.br.slct(t[i], a[i], underflowedModulus)

proc modadd_nvidia(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) {.used.} =
## Generate an optimized modular addition kernel
@@ -118,12 +118,12 @@ proc modadd_nvidia(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRe
let (rr, aa, bb, MM) = llvmParams

# Pointers are opaque in LLVM now
let r = asy.asArray(rr, fd.fieldTy)
var r = asy.asArray(rr, fd.fieldTy)
let a = asy.asArray(aa, fd.fieldTy)
let b = asy.asArray(bb, fd.fieldTy)
let M = asy.asArray(MM, fd.fieldTy)

let t = asy.makeArray(fd.fieldTy)
var t = asy.makeArray(fd.fieldTy)
let N = fd.numWords

t[0] = asy.br.add_co(a[0], b[0])
@@ -155,12 +155,12 @@ proc modsub_nvidia(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRe
let (rr, aa, bb, MM) = llvmParams

# Pointers are opaque in LLVM now
let r = asy.asArray(rr, fd.fieldTy)
var r = asy.asArray(rr, fd.fieldTy)
let a = asy.asArray(aa, fd.fieldTy)
let b = asy.asArray(bb, fd.fieldTy)
let M = asy.asArray(MM, fd.fieldTy)

let t = asy.makeArray(fd.fieldTy)
var t = asy.makeArray(fd.fieldTy)
let N = fd.numWords

t[0] = asy.br.sub_bo(a[0], b[0])
Expand All @@ -171,7 +171,7 @@ proc modsub_nvidia(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRe

# If underflow
# TODO: predicated mov instead?
let maskedM = asy.makeArray(fd.fieldTy)
var maskedM = asy.makeArray(fd.fieldTy)
for i in 0 ..< N:
maskedM[i] = asy.br.`and`(M[i], underflowMask)
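
The subtraction kernel subtracts with borrow, turns the final borrow into a full-word mask, and adds the masked modulus back (the add-back itself sits outside the visible hunk). A portable sketch in the same `Limbs4` model:

```nim
proc modsubRef(a, b, M: Limbs4): Limbs4 =
  ## (a - b) mod M for a, b < M; models the kernel above.
  var t: Limbs4
  var borrow = 0'u64
  for i in 0 ..< 4:                    # t = a - b with borrow propagation
    let d1 = a[i] - b[i]
    let b1 = uint64(ord(a[i] < b[i]))
    t[i] = d1 - borrow
    let b2 = uint64(ord(d1 < borrow))
    borrow = b1 or b2
  let underflowMask = 0'u64 - borrow   # all-ones iff a < b
  var carry = 0'u64
  for i in 0 ..< 4:                    # add back M only on underflow
    let m = M[i] and underflowMask
    let s1 = t[i] + m
    let c1 = uint64(ord(s1 < m))
    result[i] = s1 + carry
    let c2 = uint64(ord(result[i] < s1))
    carry = c1 or c2
```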

@@ -208,12 +208,12 @@ proc mtymul_CIOS_sparebit(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M:
let (rr, aa, bb, MM) = llvmParams

# Pointers are opaque in LLVM now
let r = asy.asArray(rr, fd.fieldTy)
var r = asy.asArray(rr, fd.fieldTy)
let a = asy.asArray(aa, fd.fieldTy)
let b = asy.asArray(bb, fd.fieldTy)
let M = asy.asArray(MM, fd.fieldTy)

let t = asy.makeArray(fd.fieldTy)
var t = asy.makeArray(fd.fieldTy)
let N = fd.numWords
let m0ninv = asy.getM0ninv(fd)
