Skip to content

Commit

Permalink
Merge pull request #176 from zkFold/zlonast-refactoring-plonk
Browse files Browse the repository at this point in the history
Small refactoring plonk
  • Loading branch information
vlasin authored Aug 6, 2024
2 parents 353ab13 + 5b4c2e3 commit 97b3ff3
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 41 deletions.
11 changes: 5 additions & 6 deletions src/ZkFold/Base/Protocol/ARK/Plonk.hs
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,9 @@ instance forall n l c1 c2 t plonk f g1.
omega'' = omega
k1'' = k1
k2'' = k2
g0'' = gen
h0'' = gen
h1'' = x `mul` gen
x2'' = x `mul` gen
pow'' = log2 $ value @n
n'' = fromIntegral $ value @n

pr = fromJust $ toPlonkRelation @n @l @f iPub ac
perm = plonkPermutation plonk pr
Expand Down Expand Up @@ -363,7 +362,7 @@ instance forall n l c1 c2 t plonk f g1.
+ v * v * v * v * s1_xi
+ v * v * v * v * v * s2_xi
+ u * z_xi
) `mul` g0''
) `mul` gen

p1 = pairing @c1 @c2 (xi `mul` proof1 + (u * xi * omega'') `mul` proof2 + f - e) h0''
p2 = pairing (proof1 + u `mul` proof2) h1''
p1 = pairing @c1 @c2 (xi `mul` proof1 + (u * xi * omega'') `mul` proof2 + f - e) (gen :: Point c2)
p2 = pairing (proof1 + u `mul` proof2) x2''
26 changes: 13 additions & 13 deletions src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import Data.Bifunctor (first)
import Data.Bool (bool)
import qualified Data.Map as Map
import qualified Data.Vector as V
import GHC.Generics (Generic)
import GHC.IsList (IsList (..))
import Prelude hiding (Num (..), drop, length, sum, take, (!!), (/), (^))
import System.Random (RandomGen, mkStdGen, uniformR)
Expand Down Expand Up @@ -73,22 +74,20 @@ data PlonkSetupParamsVerify c1 c2 = PlonkSetupParamsVerify {
omega'' :: ScalarField c1,
k1'' :: ScalarField c1,
k2'' :: ScalarField c1,
g0'' :: Point c1,
h0'' :: Point c2,
h1'' :: Point c2,
pow'' :: Integer
x2'' :: Point c2,
pow'' :: Integer,
n'' :: Integer
}
instance (Show (ScalarField c1), Show (BaseField c1), Show (BaseField c2),
EllipticCurve c1, EllipticCurve c2) => Show (PlonkSetupParamsVerify c1 c2) where
show (PlonkSetupParamsVerify omega'' k1'' k2'' g0'' h0'' h1'' pow'') =
show (PlonkSetupParamsVerify omega'' k1'' k2'' x2'' pow'' n'') =
"Setup Parameters (Verify): "
++ show omega'' ++ " "
++ show k1'' ++ " "
++ show k2'' ++ " "
++ show g0'' ++ " "
++ show h0'' ++ " "
++ show h1'' ++ " "
++ show pow''
++ show k1'' ++ " "
++ show k2'' ++ " "
++ show x2'' ++ " "
++ show pow'' ++ " "
++ show n''

data PlonkPermutation n c = PlonkPermutation {
s1 :: PolyVec (ScalarField c) n,
Expand Down Expand Up @@ -161,7 +160,8 @@ data PlonkProverSecret c = PlonkProverSecret {
b9 :: ScalarField c,
b10 :: ScalarField c,
b11 :: ScalarField c
}
} deriving Generic

instance Show (ScalarField c) => Show (PlonkProverSecret c) where
show (PlonkProverSecret b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11) =
"Prover Secret: "
Expand All @@ -182,7 +182,7 @@ instance Arbitrary (ScalarField c) => Arbitrary (PlonkProverSecret c) where
arbitrary <*> arbitrary <*> arbitrary <*> arbitrary <*> arbitrary <*> arbitrary
<*> arbitrary <*> arbitrary <*> arbitrary <*> arbitrary <*> arbitrary

newtype PlonkInput c = PlonkInput (V.Vector (ScalarField c))
newtype PlonkInput c = PlonkInput { unPlonkInput :: V.Vector (ScalarField c) }
instance Show (ScalarField c) => Show (PlonkInput c) where
show (PlonkInput v) = "Input: " ++ show v

Expand Down
40 changes: 31 additions & 9 deletions src/ZkFold/Symbolic/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,23 @@
module ZkFold.Symbolic.Compiler (
module ZkFold.Symbolic.Compiler.ArithmeticCircuit,
compile,
compileIO
compileIO,
compileSafeZero
) where

import Data.Aeson (ToJSON)
import Data.Eq (Eq)
import Data.Function (const, (.))
import Prelude (FilePath, IO, Monoid (mempty), Show (..), putStrLn,
type (~), ($), (++))
import Data.Aeson (ToJSON)
import Data.Eq (Eq)
import Data.Function (const, (.))
import Prelude (FilePath, IO, Monoid (mempty), Ord, Show (..),
putStrLn, type (~), ($), (++))

import ZkFold.Base.Algebra.Basic.Class (MultiplicativeMonoid)
import ZkFold.Base.Algebra.Basic.Class (BinaryExpansion (..), Field, Finite,
MultiplicativeMonoid)
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Data.Vector (Vector, unsafeToVector)
import ZkFold.Prelude (writeFileJSON)
import ZkFold.Base.Data.Vector (Vector, unsafeToVector)
import ZkFold.Prelude (writeFileJSON)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (safeZero)
import ZkFold.Symbolic.Data.Class

{-
Expand Down Expand Up @@ -48,6 +51,25 @@ solder f = pieces f (restore @c @(Support c f) $ const inputC)
inputList = [1..(typeSize @c @(Support c f))]
inputC = mempty { acInput = inputList, acOutput = unsafeToVector inputList }

-- | Compiles function `f` into an arithmetic circuit with all outputs are zero.
compileSafeZero ::
forall a c f y .
( c ~ ArithmeticCircuit a
, SymbolicData c f
, SymbolicData c (Support c f)
, Support c (Support c f) ~ ()
, KnownNat (TypeSize c (Support c f))
, SymbolicData c y
, Support c y ~ ()
, TypeSize c f ~ TypeSize c y
, Finite a
, Field a
, BinaryExpansion a
, Bits a ~ [a]
, Ord a
) => f -> y
compileSafeZero = restore @c . const . optimize . safeZero . solder @a

-- | Compiles function `f` into an arithmetic circuit.
compile ::
forall a c f y .
Expand Down
7 changes: 7 additions & 0 deletions src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (
splitExpansion,
horner,
desugarRange,
safeZero,
isZeroC,
invertC,
foldCircuit,
Expand All @@ -23,6 +24,7 @@ import Control.Monad (fold
import Data.Containers.ListUtils (nubOrd)
import Data.Eq ((==))
import Data.Foldable (foldlM)
import Data.Functor (($>))
import Data.List (sort)
import Data.Map (elems)
import Data.Traversable (for)
Expand Down Expand Up @@ -120,6 +122,11 @@ desugarRange i b
| c == zero = ($ j) * (one - ($ k))
| otherwise = one + ($ k) * (($ j) - one)

safeZero :: (Arithmetic a, Traversable f) => ArithmeticCircuit a f -> ArithmeticCircuit a f
safeZero r = circuitF $ do
is' <- runCircuit r
for is' $ \i -> constraint (\x -> x i - one) $> i

isZeroC :: (Arithmetic a, Z.Zip f, Traversable f) => ArithmeticCircuit a f -> ArithmeticCircuit a f
isZeroC r = circuitF $ fst <$> runInvert r

Expand Down
67 changes: 54 additions & 13 deletions tests/Tests/Arithmetization/Test4.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,24 @@ module Tests.Arithmetization.Test4 (specArithmetization4) where

import Data.Map (fromList)
import GHC.Generics (Par1 (unPar1))
import GHC.Num (Natural)
import Prelude hiding (Bool, Eq (..), Num (..), Ord (..), (&&))
import qualified Prelude as Haskell
import Test.Hspec (Spec, describe, it)
import Test.QuickCheck (Testable (..), withMaxSuccess, (==>))
import Tests.NonInteractiveProof.Plonk (PlonkBS)

import ZkFold.Base.Algebra.Basic.Class (FromConstant (..), one, zero)
import ZkFold.Base.Algebra.Basic.Class (FromConstant (..), one, zero, (+))
import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1)
import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..))
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Base.Protocol.ARK.Plonk (Plonk (..), PlonkProverSecret, PlonkWitnessInput (..),
plonkVerifierInput)
import ZkFold.Base.Protocol.ARK.Plonk (Plonk (..), PlonkInput (..), PlonkProverSecret,
PlonkWitnessInput (..), plonkVerifierInput)
import ZkFold.Base.Protocol.ARK.Plonk.Internal (getParams)
import ZkFold.Base.Protocol.NonInteractiveProof (NonInteractiveProof (..))
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), acValue, applyArgs, compile)
import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), acValue, applyArgs, compile,
compileSafeZero)
import ZkFold.Symbolic.Data.Bool (Bool (..))
import ZkFold.Symbolic.Data.Eq (Eq (..))
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
Expand All @@ -44,28 +46,67 @@ testDifferentValue targetValue otherValue =
b = unPar1 $ acValue (applyArgs ac [otherValue])
in b Haskell.== zero

testZKP :: F -> PlonkProverSecret C -> F -> Haskell.Bool
testZKP x ps targetValue =
testOnlyOutputZKP :: F -> PlonkProverSecret C -> F -> Haskell.Bool
testOnlyOutputZKP x ps targetValue =
let Bool ac = compile @F (lockedByTxId @F @(ArithmeticCircuit F) targetValue) :: Bool (ArithmeticCircuit F)

(omega, k1, k2) = getParams 32
inputs = fromList [(1, targetValue), (unPar1 $ acOutput ac, 1)]
plonk = Plonk @32 omega k1 k2 (V.singleton $ unPar1 $ acOutput ac) ac x
witnessInputs = fromList [(1, targetValue), (unPar1 $ acOutput ac, 1)]
indexOutputBool = V.singleton $ unPar1 $ acOutput ac
plonk = Plonk @32 omega k1 k2 indexOutputBool ac x
setupP = setupProve @(PlonkBS N) plonk
setupV = setupVerify @(PlonkBS N) plonk
witness = (PlonkWitnessInput inputs, ps)
(_, proof) = prove @(PlonkBS N) setupP witness
witness = (PlonkWitnessInput witnessInputs, ps)
(input, proof) = prove @(PlonkBS N) setupP witness

-- `one` corresponds to `True`
circuitOutputsTrue = plonkVerifierInput $ V.singleton one

in verify @(PlonkBS N) setupV circuitOutputsTrue proof
in unPlonkInput input Haskell.== unPlonkInput circuitOutputsTrue Haskell.&& verify @(PlonkBS N) setupV circuitOutputsTrue proof

testSafeOneInputZKP :: F -> PlonkProverSecret C -> F -> Haskell.Bool
testSafeOneInputZKP x ps targetValue =
let Bool ac = compileSafeZero @F (lockedByTxId @F @(ArithmeticCircuit F) targetValue) :: Bool (ArithmeticCircuit F)

(omega, k1, k2) = getParams 32
witnessInputs = fromList [(1, targetValue), (unPar1 $ acOutput ac, 1)]
indexTargetValue = V.singleton (1 :: Natural)
plonk = Plonk @32 omega k1 k2 indexTargetValue ac x
setupP = setupProve @(PlonkBS N) plonk
setupV = setupVerify @(PlonkBS N) plonk
witness = (PlonkWitnessInput witnessInputs, ps)
(input, proof) = prove @(PlonkBS N) setupP witness

onePublicInput = plonkVerifierInput $ V.singleton targetValue

in unPlonkInput input Haskell.== unPlonkInput onePublicInput Haskell.&& verify @(PlonkBS N) setupV onePublicInput proof

testAttackSafeOneInputZKP :: F -> PlonkProverSecret C -> F -> Haskell.Bool
testAttackSafeOneInputZKP x ps targetValue =
let Bool ac = compileSafeZero @F (lockedByTxId @F @(ArithmeticCircuit F) targetValue) :: Bool (ArithmeticCircuit F)

(omega, k1, k2) = getParams 32
witnessInputs = fromList [(1, targetValue + 1), (unPar1 $ acOutput ac, 0)]
indexTargetValue = V.singleton (1 :: Natural)
plonk = Plonk @32 omega k1 k2 indexTargetValue ac x
setupP = setupProve @(PlonkBS N) plonk
setupV = setupVerify @(PlonkBS N) plonk
witness = (PlonkWitnessInput witnessInputs, ps)
(input, proof) = prove @(PlonkBS N) setupP witness

onePublicInput = plonkVerifierInput $ V.singleton $ targetValue + 1

in unPlonkInput input Haskell.== unPlonkInput onePublicInput Haskell.&& Haskell.not (verify @(PlonkBS N) setupV onePublicInput proof)

specArithmetization4 :: Spec
specArithmetization4 = do
describe "LockedByTxId arithmetization test 1" $ do
it "should pass" $ property testSameValue
describe "LockedByTxId arithmetization test 2" $ do
it "should pass" $ property $ \x y -> x Haskell./= y ==> testDifferentValue x y
describe "LockedByTxId ZKP test" $ do
it "should pass" $ withMaxSuccess 10 $ property testZKP
describe "LockedByTxId ZKP test only output" $ do
it "should pass" $ withMaxSuccess 10 $ property testOnlyOutputZKP
describe "LockedByTxId ZKP test safe one public input" $ do
it "should pass" $ withMaxSuccess 10 $ property testSafeOneInputZKP
describe "LockedByTxId ZKP test attack safe one public input" $ do
it "should pass" $ withMaxSuccess 10 $ property testAttackSafeOneInputZKP

0 comments on commit 97b3ff3

Please sign in to comment.