Skip to content

Commit

Permalink
Merge pull request #194 from zkFold/vks4git/protostar
Browse files Browse the repository at this point in the history
Fixed division benchmark
  • Loading branch information
TurtlePU authored Jul 27, 2024
2 parents 528d646 + f1413e9 commit 4390c2a
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions bench/BenchDiv.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ module Main where

import Control.DeepSeq (force)
import Control.Exception (evaluate)
import qualified Data.Map as M
import Data.Time.Clock (getCurrentTime)
import Prelude hiding (divMod, not, sum, (&&), (*), (+), (-), (/), (^),
(||))
Expand All @@ -26,54 +25,58 @@ import ZkFold.Symbolic.Compiler
import ZkFold.Symbolic.Data.Combinators
import ZkFold.Symbolic.Data.UInt

evalUInt :: forall a n . UInt n ArithmeticCircuit a -> Vector (NumberOfRegisters a n) a
evalUInt (UInt xs) = eval xs M.empty
evalUInt :: forall a n r . UInt n r (ArithmeticCircuit a)-> Vector (NumberOfRegisters a n r) a
evalUInt (UInt v) = exec v

-- | Generate random addition circuit of given size
--
divisionCircuit
:: forall n p r
:: forall n p r rs
. KnownNat n
=> KnownRegisterSize r
=> rs ~ NumberOfRegisters (Zp p) n r
=> KnownNat rs
=> KnownNat (rs - 1)
=> KnownNat (rs + rs)
=> 1 + (rs - 1) ~ rs
=> (rs - 1) + 1 ~ rs
=> 1 <= rs
=> PrimeField (Zp p)
=> r ~ NumberOfRegisters (Zp p) n
=> KnownNat r
=> KnownNat (r - 1)
=> KnownNat (r + r)
=> 1 + (r - 1) ~ r
=> 1 <= r
=> IO (UInt n ArithmeticCircuit (Zp p), UInt n ArithmeticCircuit (Zp p))
=> IO (UInt n r (ArithmeticCircuit (Zp p)), UInt n r (ArithmeticCircuit (Zp p)))
divisionCircuit = do
x <- randomIO
y <- randomIO
let acX = fromConstant (x :: Integer) :: UInt n ArithmeticCircuit (Zp p)
acY = fromConstant (y :: Integer) :: UInt n ArithmeticCircuit (Zp p)
let acX = fromConstant (x :: Integer) :: UInt n r (ArithmeticCircuit (Zp p))
acY = fromConstant (y :: Integer) :: UInt n r (ArithmeticCircuit (Zp p))

acZ = acX `divMod` acY

evaluate . force $ acZ

benchOps
:: forall n p r
:: forall n p (r :: RegisterSize) rs
. KnownNat n
=> PrimeField (Zp p)
=> r ~ NumberOfRegisters (Zp p) n
=> KnownNat r
=> KnownNat (r - 1)
=> KnownNat (r + r)
=> 1 + (r - 1) ~ r
=> 1 <= r
=> KnownRegisterSize r
=> rs ~ NumberOfRegisters (Zp p) n r
=> KnownNat rs
=> KnownNat (rs - 1)
=> KnownNat (rs + rs)
=> 1 + (rs - 1) ~ rs
=> (rs - 1) + 1 ~ rs
=> 1 <= rs
=> Benchmark
benchOps = env (divisionCircuit @n @p) $ \ ~ac ->
benchOps = env (divisionCircuit @n @p @r) $ \ ~ac ->
bench ("Dividing UInts of size " <> show (value @n)) $ nf (\(a, b) -> (evalUInt a, evalUInt b)) ac

main :: IO ()
main = do
getCurrentTime >>= print
(UInt ac32q, UInt ac32r) <- divisionCircuit @32 @BLS12_381_Scalar
(UInt ac32q, UInt ac32r) <- divisionCircuit @32 @BLS12_381_Scalar @Auto
getCurrentTime >>= print
(UInt ac64q, UInt ac64r) <- divisionCircuit @64 @BLS12_381_Scalar
(UInt ac64q, UInt ac64r) <- divisionCircuit @64 @BLS12_381_Scalar @Auto
getCurrentTime >>= print
(UInt ac128q, UInt ac128r) <- divisionCircuit @128 @BLS12_381_Scalar
(UInt ac128q, UInt ac128r) <- divisionCircuit @128 @BLS12_381_Scalar @Auto
getCurrentTime >>= print

putStrLn "Sizes"
Expand All @@ -95,8 +98,8 @@ main = do
getCurrentTime >>= print

defaultMain
[ benchOps @32 @BLS12_381_Scalar
, benchOps @64 @BLS12_381_Scalar
, benchOps @128 @BLS12_381_Scalar
[ benchOps @32 @BLS12_381_Scalar @Auto
, benchOps @64 @BLS12_381_Scalar @Auto
, benchOps @128 @BLS12_381_Scalar @Auto
]

0 comments on commit 4390c2a

Please sign in to comment.