Skip to content

Commit

Permalink
Merge pull request #183 from zkFold/turtlepu-hide-field
Browse files Browse the repository at this point in the history
Hide field parameter in backend usage sites
  • Loading branch information
vlasin authored Jul 16, 2024
2 parents 19f7940 + afac8f3 commit dc1fdd5
Show file tree
Hide file tree
Showing 63 changed files with 814 additions and 785 deletions.
30 changes: 14 additions & 16 deletions bench/BenchAC.hs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE NoGeneralisedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

{-# OPTIONS_GHC -freduction-depth=0 #-}

module Main where
Expand All @@ -27,48 +26,47 @@ import ZkFold.Symbolic.Data.ByteString
import ZkFold.Symbolic.Data.Combinators
import ZkFold.Symbolic.Data.UInt

evalBS :: forall a n . ByteString n ArithmeticCircuit a -> Vector n a
evalBS :: forall a n . ByteString n (ArithmeticCircuit a) -> Vector n a
evalBS (ByteString xs) = eval xs M.empty

evalUInt :: forall a n . UInt n ArithmeticCircuit a -> Vector (NumberOfRegisters a n) a
evalUInt :: forall a n . UInt n (ArithmeticCircuit a) -> Vector (NumberOfRegisters a n) a
evalUInt (UInt xs) = eval xs M.empty


hashCircuit
:: forall n p
. PrimeField (Zp p)
=> SHA2 "SHA256" ArithmeticCircuit (Zp p) n
=> IO (ByteString 256 ArithmeticCircuit (Zp p))
=> SHA2 "SHA256" (ArithmeticCircuit (Zp p)) n
=> IO (ByteString 256 (ArithmeticCircuit (Zp p)))
hashCircuit = do
x <- randomIO
let acX = fromConstant (x :: Integer) :: ByteString n ArithmeticCircuit (Zp p)
h = sha2 @"SHA256" @ArithmeticCircuit acX
let acX = fromConstant (x :: Integer) :: ByteString n (ArithmeticCircuit (Zp p))
h = sha2 @"SHA256" @(ArithmeticCircuit (Zp p)) acX

evaluate . force $ h

-- | Generate random addition circuit of given size
--
additionCircuit :: forall n p. (KnownNat n, PrimeField (Zp p)) => IO (ByteString n ArithmeticCircuit (Zp p))
additionCircuit :: forall n p. (KnownNat n, PrimeField (Zp p)) => IO (ByteString n (ArithmeticCircuit (Zp p)))
additionCircuit = do
x <- randomIO
y <- randomIO
let acX = fromConstant (x :: Integer) :: ByteString n ArithmeticCircuit (Zp p)
acY = fromConstant (y :: Integer) :: ByteString n ArithmeticCircuit (Zp p)

acZ = from (from acX + from acY :: UInt n ArithmeticCircuit (Zp p))
let acX = fromConstant (x :: Integer) :: ByteString n (ArithmeticCircuit (Zp p))
acY = fromConstant (y :: Integer) :: ByteString n (ArithmeticCircuit (Zp p))
acZ = from (from acX + from acY :: UInt n (ArithmeticCircuit (Zp p)))

evaluate . force $ acZ

benchOps :: forall n p. (KnownNat n, PrimeField (Zp p)) => Benchmark
benchOps = env (additionCircuit @n @p) $ \ ~ac ->
benchOps = env (additionCircuit @n @p) $ \ac ->
bench ("Adding ByteStrings of size " <> show (value @n) <> " via UInt") $ nf evalBS ac

benchHash
:: forall n p
. PrimeField (Zp p)
=> SHA2 "SHA256" ArithmeticCircuit (Zp p) n
=> SHA2 "SHA256" (ArithmeticCircuit (Zp p)) n
=> Benchmark
benchHash = env (hashCircuit @n @p) $ \ ~ac ->
benchHash = env (hashCircuit @n @p) $ \ac ->
bench ("Calculating SHA2 512/364 of a bytestring of length " <> show (value @n)) $ nf evalBS ac

main :: IO ()
Expand Down
4 changes: 2 additions & 2 deletions examples/Examples/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ exampleByteStringExtend = do
let k = show $ natVal (Proxy @k)
putStrLn $ "\nExample: Extending a bytestring of length " ++ n ++ " to length " ++ k
let file = "compiled_scripts/bytestring" ++ n ++ "_to_" ++ k ++ ".json"
compileIO @(Zp BLS12_381_Scalar) file $ extend @(ByteString n ArithmeticCircuit (Zp BLS12_381_Scalar)) @(ByteString k ArithmeticCircuit (Zp BLS12_381_Scalar))
compileIO @(Zp BLS12_381_Scalar) file $ extend @(ByteString n (ArithmeticCircuit (Zp BLS12_381_Scalar))) @(ByteString k (ArithmeticCircuit (Zp BLS12_381_Scalar)))

type Binary a = a -> a -> a

type UBinary n = Binary (ByteString n ArithmeticCircuit (Zp BLS12_381_Scalar))
type UBinary n = Binary (ByteString n (ArithmeticCircuit (Zp BLS12_381_Scalar)))

makeExample :: forall n . (KnownNat n, KnownNat (n + n)) => String -> String -> UBinary n -> IO ()
makeExample shortName name op = do
Expand Down
2 changes: 1 addition & 1 deletion examples/Examples/Conditional.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import ZkFold.Symbolic.Data.Bool (Bool)
import ZkFold.Symbolic.Data.Conditional (Conditional (..))

type F = Zp BLS12_381_Scalar
type A = ArithmeticCircuit 1 F
type A = ArithmeticCircuit F 1
type B = Bool A

exampleConditional :: IO ()
Expand Down
2 changes: 1 addition & 1 deletion examples/Examples/Eq.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ exampleEq = do

putStrLn "\nExample: (==) operation\n"

compileIO @(Zp BLS12_381_Scalar) file (eq @(ArithmeticCircuit 1 (Zp BLS12_381_Scalar)))
compileIO @(Zp BLS12_381_Scalar) file (eq @(ArithmeticCircuit (Zp BLS12_381_Scalar) 1))
2 changes: 1 addition & 1 deletion examples/Examples/FFA.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ exampleFFAmul = makeExample @p "*" "mul" (*)

type Binary a = a -> a -> a

makeExample :: forall p. KnownNat p => String -> String -> Binary (FFA p ArithmeticCircuit (Zp BLS12_381_Scalar)) -> IO ()
makeExample :: forall p. KnownNat p => String -> String -> Binary (FFA p (ArithmeticCircuit (Zp BLS12_381_Scalar))) -> IO ()
makeExample shortName name op = do
let p = show $ value @p
putStrLn $ "\nExample: (" ++ shortName ++ ") operation on FFA " ++ p
Expand Down
2 changes: 1 addition & 1 deletion examples/Examples/Fibonacci.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ exampleFibonacci = do

putStrLn "\nExample: Fibonacci index function\n"

compileIO @(Zp BLS12_381_Scalar) file (fibonacciIndex @(ArithmeticCircuit 1 (Zp BLS12_381_Scalar)) nMax)
compileIO @(Zp BLS12_381_Scalar) file (fibonacciIndex @(ArithmeticCircuit (Zp BLS12_381_Scalar) 1) nMax)
2 changes: 1 addition & 1 deletion examples/Examples/LEQ.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ exampleLEQ = do

putStrLn "\nExample: (<=) operation\n"

compileIO @(Zp BLS12_381_Scalar) file (leq @(ArithmeticCircuit 1 (Zp BLS12_381_Scalar)))
compileIO @(Zp BLS12_381_Scalar) file (leq @(ArithmeticCircuit (Zp BLS12_381_Scalar) 1))
2 changes: 1 addition & 1 deletion examples/Examples/MiMCHash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ exampleMiMC = do

putStrLn "\nExample: MiMC hash function\n"

compileIO @F file (mimcHash2 @F @(ArithmeticCircuit 1 F) mimcConstants zero)
compileIO @F file (mimcHash2 @F @(ArithmeticCircuit F 1) mimcConstants zero)
3 changes: 1 addition & 2 deletions examples/Examples/ReverseList.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,4 @@ exampleReverseList = do

putStrLn "\nExample: Reverse List function\n"

compileIO @(Zp BLS12_381_Scalar) file (reverseList @(ArithmeticCircuit 1 (Zp BLS12_381_Scalar)) @32)

compileIO @(Zp BLS12_381_Scalar) file (reverseList @(ArithmeticCircuit (Zp BLS12_381_Scalar) 1) @32)
2 changes: 1 addition & 1 deletion examples/Examples/UInt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ exampleUIntStrictMul = makeExample @n "strictMul" "strict_mul" strictMul

type Binary a = a -> a -> a

type UBinary n = Binary (UInt n ArithmeticCircuit (Zp BLS12_381_Scalar))
type UBinary n = Binary (UInt n (ArithmeticCircuit (Zp BLS12_381_Scalar)))

makeExample
:: forall n r
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Base/Protocol/ARK/Plonk.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ data Plonk (n :: Natural) (l :: Natural) curve1 curve2 transcript = Plonk {
k1 :: ScalarField curve1,
k2 :: ScalarField curve1,
iPub :: Vector l Natural,
ac :: ArithmeticCircuit 1 (ScalarField curve1),
ac :: ArithmeticCircuit (ScalarField curve1) 1,
x :: ScalarField curve1
}
instance (Show (ScalarField c1), Arithmetic (ScalarField c1)) => Show (Plonk n l c1 c2 t) where
Expand Down
3 changes: 2 additions & 1 deletion src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE TypeOperators #-}

module ZkFold.Base.Protocol.ARK.Plonk.Relation where
Expand Down Expand Up @@ -37,7 +38,7 @@ toPlonkRelation :: forall n l a .
=> Scale a a
=> FromConstant a a
=> Vector l Natural
-> ArithmeticCircuit 1 a
-> ArithmeticCircuit a 1
-> Maybe (PlonkRelation n l a)
toPlonkRelation xPub ac0 =
let ac = desugarRanges ac0
Expand Down
126 changes: 63 additions & 63 deletions src/ZkFold/Symbolic/Algorithms/Hash/Blake2b.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,43 +35,43 @@ import ZkFold.Symbolic.Data.UInt (UInt (..))
-- | BLAKE2b Cryptographic hash. Reference:
-- https://tools.ietf.org/html/rfc7693

type Blake2bSig b a =
( Iso (UInt 64 b a) (ByteString 64 b a)
, ShiftBits (ByteString 64 b a)
, Concat (ByteString 64 b a) (ByteString 512 b a)
, ReverseEndianness 64 (ByteString 512 b a)
, BoolType (ByteString 64 b a)
, AdditiveGroup (UInt 64 b a)
, FromConstant Natural (UInt 64 b a)
, MultiplicativeMonoid (UInt 64 b a)
type Blake2bSig b =
( Iso (UInt 64 b) (ByteString 64 b)
, ShiftBits (ByteString 64 b)
, Concat (ByteString 64 b) (ByteString 512 b)
, ReverseEndianness 64 (ByteString 512 b)
, BoolType (ByteString 64 b)
, AdditiveGroup (UInt 64 b)
, FromConstant Natural (UInt 64 b)
, MultiplicativeMonoid (UInt 64 b)
)

pow2 :: forall a . FromConstant Natural a => Natural -> a
pow2 = fromConstant @Natural . (2 ^)

shiftUIntR :: forall b a . Blake2bSig b a => UInt 64 b a -> Natural -> UInt 64 b a
shiftUIntR u n = from @_ @(UInt 64 b a) $ from @_ @(ByteString 64 b a) u `shiftBitsR` n
shiftUIntR :: forall b . Blake2bSig b => UInt 64 b -> Natural -> UInt 64 b
shiftUIntR u n = from @_ @(UInt 64 b) $ from @_ @(ByteString 64 b) u `shiftBitsR` n

shiftUIntL :: forall b a . Blake2bSig b a => UInt 64 b a -> Natural -> UInt 64 b a
shiftUIntL :: forall b . Blake2bSig b => UInt 64 b -> Natural -> UInt 64 b
shiftUIntL u n = u * pow2 n

xorUInt :: forall a b . Blake2bSig b a => UInt 64 b a -> UInt 64 b a -> UInt 64 b a
xorUInt u1 u2 = from @(ByteString 64 b a) @(UInt 64 b a) $ from u1 `xor` from u2
xorUInt :: forall b . Blake2bSig b => UInt 64 b -> UInt 64 b -> UInt 64 b
xorUInt u1 u2 = from @(ByteString 64 b) @(UInt 64 b) $ from u1 `xor` from u2

-- | state context
data Blake2bCtx a b = Blake2bCtx
{ h :: V.Vector (UInt 64 b a) -- chained state 8
, m :: V.Vector (UInt 64 b a) -- input buffer 16
data Blake2bCtx b = Blake2bCtx
{ h :: V.Vector (UInt 64 b) -- chained state 8
, m :: V.Vector (UInt 64 b) -- input buffer 16
, t :: (Natural, Natural) -- total number of bytes
}

-- | Cyclic right rotation.
rotr64 :: Blake2bSig b a => (UInt 64 b a, Natural) -> UInt 64 b a
rotr64 :: Blake2bSig b => (UInt 64 b, Natural) -> UInt 64 b
rotr64 (x, y) = (x `shiftUIntR` y) `xorUInt` (x `shiftUIntL` (64 -! y))

-- | Little-endian byte access.
b2b_g :: forall b a . Blake2bSig b a =>
V.Vector (UInt 64 b a) -> (Int, Int, Int, Int, UInt 64 b a, UInt 64 b a) -> V.Vector (UInt 64 b a)
b2b_g :: forall b . Blake2bSig b =>
V.Vector (UInt 64 b) -> (Int, Int, Int, Int, UInt 64 b, UInt 64 b) -> V.Vector (UInt 64 b)
b2b_g v (a, b, c, d, x, y) =
let va1 = (v ! a) + (v ! b) + x -- v[a] = v[a] + v[b] + x; \
vd1 = rotr64 ((v ! d) `xorUInt` va1, 32) -- v[d] = ROTR64(v[d] ^ v[a], 32); \
Expand All @@ -84,8 +84,8 @@ b2b_g v (a, b, c, d, x, y) =
in v // [(a, va2), (b, vb2), (c, vc2), (d, vd2)]

-- | Compression function. "last" flag indicates the last block.
blake2b_compress :: forall a b . Blake2bSig b a =>
Blake2bCtx a b -> Bool -> V.Vector (UInt 64 b a)
blake2b_compress :: forall b . Blake2bSig b =>
Blake2bCtx b -> Bool -> V.Vector (UInt 64 b)
blake2b_compress Blake2bCtx{h, m, t} lastBlock =
let v' = h V.++ blake2b_iv -- init work variables
v'' = v' V.// [ (12, (v' ! 12) `xorUInt` fromConstant (fst t)) -- low word of the offset
Expand All @@ -110,14 +110,14 @@ blake2b_compress Blake2bCtx{h, m, t} lastBlock =
v1 = V.foldl hashRound v0 $ fromList [0..11] -- twelve rounds
in fmap (\(i, hi) -> hi `xorUInt` (v1 ! i) `xorUInt` (v1 ! (i GHC.+ 8))) (V.zip (fromList [0..7]) h)

blake2b' :: forall bb' kk' ll' nn' a b .
blake2b' :: forall bb' kk' ll' nn' b .
( KnownNat bb'
, KnownNat kk'
, KnownNat ll'
, KnownNat nn'
, Truncate (ByteString 512 b a) (ByteString (8 * nn') b a)
, Blake2bSig b a
) => [V.Vector (UInt 64 b a)] -> ByteString (8 * nn') b a
, Truncate (ByteString 512 b) (ByteString (8 * nn') b)
, Blake2bSig b
) => [V.Vector (UInt 64 b)] -> ByteString (8 * nn') b
blake2b' d =
let bb = value @bb'
ll = value @ll'
Expand All @@ -128,7 +128,7 @@ blake2b' d =
toOffset :: forall x . (FromConstant Natural x) => Natural -> (x, x)
toOffset x = let (hi, lo) = x `divMod` pow2 64 in (fromConstant lo, fromConstant hi)

h = blake2b_iv :: V.Vector (UInt 64 b a)
h = blake2b_iv :: V.Vector (UInt 64 b)

-- Parameter block p[0]
h' = h // [(0, (h ! 0) `xorUInt` fromConstant @Natural 0x01010000 `xorUInt` (fromConstant kk `shiftUIntR` 8) `xorUInt` fromConstant nn)]
Expand All @@ -143,31 +143,31 @@ blake2b' d =
then blake2b_compress (Blake2bCtx h'' (d !! (dd -! 1)) (toOffset @Natural $ ll)) True
else blake2b_compress (Blake2bCtx h'' (d !! (dd -! 1)) (toOffset @Natural $ ll + bb)) True

bs = reverseEndianness @64 $ concat @(ByteString 64 b a) $ map from $ toList h''' :: ByteString (64 * 8) b a
bs = reverseEndianness @64 $ concat @(ByteString 64 b) $ map from $ toList h''' :: ByteString (64 * 8) b
in truncate bs

type ExtensionBits inputLen = 8 * (128 - Mod inputLen 128)
type ExtendedInputByteString inputLen b a = ByteString (8 * inputLen + ExtensionBits inputLen) b a
type ExtendedInputByteString inputLen b = ByteString (8 * inputLen + ExtensionBits inputLen) b

blake2b :: forall keyLen inputLen outputLen a b .
blake2b :: forall keyLen inputLen outputLen b .
( KnownNat keyLen
, KnownNat inputLen
, KnownNat outputLen
, KnownNat (ExtensionBits inputLen)
, Extend (ByteString (8 * inputLen) b a) (ExtendedInputByteString inputLen b a)
, ShiftBits (ExtendedInputByteString inputLen b a)
, ReverseEndianness 64 (ExtendedInputByteString inputLen b a)
, ToWords (ExtendedInputByteString inputLen b a) (ByteString 64 b a)
, Truncate (ByteString 512 b a) (ByteString (8 * outputLen) b a)
, Blake2bSig b a
) => Natural -> ByteString (8 * inputLen) b a -> ByteString (8 * outputLen) b a
, Extend (ByteString (8 * inputLen) b) (ExtendedInputByteString inputLen b)
, ShiftBits (ExtendedInputByteString inputLen b)
, ReverseEndianness 64 (ExtendedInputByteString inputLen b)
, ToWords (ExtendedInputByteString inputLen b) (ByteString 64 b)
, Truncate (ByteString 512 b) (ByteString (8 * outputLen) b)
, Blake2bSig b
) => Natural -> ByteString (8 * inputLen) b -> ByteString (8 * outputLen) b
blake2b key input =
let input' = map from (toWords $
reverseEndianness @64 $
flip rotateBitsL (value @(ExtensionBits inputLen)) $
extend @_ @(ExtendedInputByteString inputLen b a) input :: [ByteString 64 b a])
extend @_ @(ExtendedInputByteString inputLen b) input :: [ByteString 64 b])

key' = fromConstant @_ key :: UInt 64 b a
key' = fromConstant @_ key :: UInt 64 b
input'' = if value @keyLen > 0
then key' : input'
else input'
Expand All @@ -187,40 +187,40 @@ blake2b key input =
d

-- | Hash a `ByteString` using the Blake2b-224 hash function.
blake2b_224 :: forall inputLen b a .
blake2b_224 :: forall inputLen b .
( KnownNat inputLen
, KnownNat (ExtensionBits inputLen)
, Extend (ByteString (8 * inputLen) b a) (ExtendedInputByteString inputLen b a)
, ShiftBits (ExtendedInputByteString inputLen b a)
, ReverseEndianness 64 (ExtendedInputByteString inputLen b a)
, ToWords (ExtendedInputByteString inputLen b a) (ByteString 64 b a)
, Truncate (ByteString 512 b a) (ByteString 224 b a)
, Blake2bSig b a
) => ByteString (8 * inputLen) b a -> ByteString 224 b a
, Extend (ByteString (8 * inputLen) b) (ExtendedInputByteString inputLen b)
, ShiftBits (ExtendedInputByteString inputLen b)
, ReverseEndianness 64 (ExtendedInputByteString inputLen b)
, ToWords (ExtendedInputByteString inputLen b) (ByteString 64 b)
, Truncate (ByteString 512 b) (ByteString 224 b)
, Blake2bSig b
) => ByteString (8 * inputLen) b -> ByteString 224 b
blake2b_224 = blake2b @0 @inputLen @28 (fromConstant @Natural 0)

-- | Hash a `ByteString` using the Blake2b-256 hash function.
blake2b_256 :: forall inputLen b a .
blake2b_256 :: forall inputLen b .
( KnownNat inputLen
, KnownNat (ExtensionBits inputLen)
, Extend (ByteString (8 * inputLen) b a) (ExtendedInputByteString inputLen b a)
, ShiftBits (ExtendedInputByteString inputLen b a)
, ReverseEndianness 64 (ExtendedInputByteString inputLen b a)
, ToWords (ExtendedInputByteString inputLen b a) (ByteString 64 b a)
, Truncate (ByteString 512 b a) (ByteString 256 b a)
, Blake2bSig b a
) => ByteString (8 * inputLen) b a -> ByteString 256 b a
, Extend (ByteString (8 * inputLen) b) (ExtendedInputByteString inputLen b)
, ShiftBits (ExtendedInputByteString inputLen b)
, ReverseEndianness 64 (ExtendedInputByteString inputLen b)
, ToWords (ExtendedInputByteString inputLen b) (ByteString 64 b)
, Truncate (ByteString 512 b) (ByteString 256 b)
, Blake2bSig b
) => ByteString (8 * inputLen) b -> ByteString 256 b
blake2b_256 = blake2b @0 @inputLen @32 (fromConstant @Natural 0)

-- | Hash a `ByteString` using the Blake2b-256 hash function.
blake2b_512 :: forall inputLen b a .
blake2b_512 :: forall inputLen b .
( KnownNat inputLen
, KnownNat (ExtensionBits inputLen)
, Extend (ByteString (8 * inputLen) b a) (ExtendedInputByteString inputLen b a)
, ShiftBits (ExtendedInputByteString inputLen b a)
, ReverseEndianness 64 (ExtendedInputByteString inputLen b a)
, ToWords (ExtendedInputByteString inputLen b a) (ByteString 64 b a)
, Truncate (ByteString 512 b a) (ByteString 512 b a)
, Blake2bSig b a
) => ByteString (8 * inputLen) b a -> ByteString 512 b a
, Extend (ByteString (8 * inputLen) b) (ExtendedInputByteString inputLen b)
, ShiftBits (ExtendedInputByteString inputLen b)
, ReverseEndianness 64 (ExtendedInputByteString inputLen b)
, ToWords (ExtendedInputByteString inputLen b) (ByteString 64 b)
, Truncate (ByteString 512 b) (ByteString 512 b)
, Blake2bSig b
) => ByteString (8 * inputLen) b -> ByteString 512 b
blake2b_512 = blake2b @0 @inputLen @64 (fromConstant @Natural 0)
Loading

0 comments on commit dc1fdd5

Please sign in to comment.