Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored Plonk(up) to use arbitrary functors #394

Merged
merged 2 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions symbolic-base/src/ZkFold/Base/Data/Vector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Control.Monad.State.Strict (runState, state)
import Data.Aeson (ToJSON (..))
import Data.Distributive (Distributive (..))
import Data.Foldable (fold)
import Data.Functor.Classes (Show1)
import Data.Functor.Rep (Representable (..), collectRep, distributeRep, mzipRep, pureRep)
import Data.These (These (..))
import qualified Data.Vector as V
Expand All @@ -20,7 +21,7 @@ import GHC.IsList (IsList (..))
import Prelude hiding (concat, drop, head, length, mod, negate, replicate, sum, tail,
take, unzip, zip, zipWith, (*), (+), (-))
import System.Random (Random (..))
import Test.QuickCheck (Arbitrary (..))
import Test.QuickCheck (Arbitrary (..), Arbitrary1 (..), arbitrary1)

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Field
Expand All @@ -29,7 +30,7 @@ import ZkFold.Base.Data.ByteString (Binary (..))
import ZkFold.Prelude (length)

newtype Vector (size :: Natural) a = Vector {toV :: V.Vector a}
deriving (Show, Eq, Functor, Foldable, Traversable, Generic, NFData, Ord)
deriving (Show, Show1, Eq, Functor, Foldable, Traversable, Generic, NFData, Ord)

-- helper
knownNat :: forall size n . (KnownNat size, Integral n) => n
Expand Down Expand Up @@ -173,7 +174,10 @@ instance Unzip (Vector size) where
unzip v = (fst <$> v, snd <$> v)

instance (Arbitrary a, KnownNat size) => Arbitrary (Vector size a) where
arbitrary = sequenceA (pureRep arbitrary)
arbitrary = arbitrary1

instance KnownNat size => Arbitrary1 (Vector size) where
liftArbitrary = sequenceA . pureRep

instance (Random a, KnownNat size) => Random (Vector size a) where
random = runState (sequenceA (pureRep (state random)))
Expand Down
18 changes: 10 additions & 8 deletions symbolic-base/src/ZkFold/Base/Protocol/Plonk.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module ZkFold.Base.Protocol.Plonk (
) where

import Data.Binary (Binary)
import Data.Functor.Classes (Show1)
import Data.Functor.Rep (Rep)
import Data.Kind (Type)
import Data.Word (Word8)
Expand All @@ -17,7 +18,6 @@ import Test.QuickCheck (Arbitrary
import ZkFold.Base.Algebra.Basic.Class (AdditiveGroup)
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), Pairing, PointCompressed)
import ZkFold.Base.Data.Vector (Vector (..))
import ZkFold.Base.Protocol.NonInteractiveProof
import ZkFold.Base.Protocol.Plonk.Prover (plonkProve)
import ZkFold.Base.Protocol.Plonk.Verifier (plonkVerify)
Expand All @@ -32,31 +32,32 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

{-| Based on the paper https://eprint.iacr.org/2019/953.pdf -}

data Plonk p (i :: Natural) (n :: Natural) (l :: Natural) curve1 curve2 transcript = Plonk {
data Plonk p i (n :: Natural) l curve1 curve2 transcript = Plonk {
omega :: ScalarField curve1,
k1 :: ScalarField curve1,
k2 :: ScalarField curve1,
ac :: ArithmeticCircuit (ScalarField curve1) p (Vector i) (Vector l),
ac :: ArithmeticCircuit (ScalarField curve1) p i l,
x :: ScalarField curve1
}

fromPlonkup ::
( KnownNat i
, Arithmetic (ScalarField c1)
( Arithmetic (ScalarField c1)
, Binary (ScalarField c1)
, Binary (Rep p)
, Binary (Rep i)
, Ord (Rep i)
) => Plonkup p i n l c1 c2 ts -> Plonk p i n l c1 c2 ts
fromPlonkup Plonkup {..} = Plonk { ac = desugarRanges ac, ..}

toPlonkup :: Plonk p i n l c1 c2 ts -> Plonkup p i n l c1 c2 ts
toPlonkup Plonk {..} = Plonkup {..}

instance (Show (ScalarField c1), Arithmetic (ScalarField c1), KnownNat l, KnownNat i) => Show (Plonk p i n l c1 c2 t) where
instance (Show1 l, Show (Rep i), Show (ScalarField c1), Ord (Rep i)) => Show (Plonk p i n l c1 c2 t) where
show Plonk {..} =
"Plonk: " ++ show omega ++ " " ++ show k1 ++ " " ++ show k2 ++ " " ++ show (acOutput ac) ++ " " ++ show ac ++ " " ++ show x

instance ( KnownNat i, Arithmetic (ScalarField c1)
, Binary (ScalarField c1), Binary (Rep p)
instance ( Arithmetic (ScalarField c1), Binary (ScalarField c1)
, Binary (Rep p), Binary (Rep i), Ord (Rep i)
, Arbitrary (Plonkup p i n l c1 c2 t))
=> Arbitrary (Plonk p i n l c1 c2 t) where
arbitrary = fromPlonkup <$> arbitrary
Expand All @@ -69,6 +70,7 @@ instance forall p i n l c1 c2 (ts :: Type) core .
, Input (Plonkup p i n l c1 c2 ts) ~ PlonkupInput l c1
, Proof (Plonkup p i n l c1 c2 ts) ~ PlonkupProof c1
, KnownNat n
, Foldable l
, Ord (BaseField c1)
, AdditiveGroup (BaseField c1)
, Pairing c1 c2
Expand Down
5 changes: 3 additions & 2 deletions symbolic-base/src/ZkFold/Base/Protocol/Plonk/Prover.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number (KnownNat, Natural, value)
import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), PointCompressed, compress)
import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr)
import ZkFold.Base.Data.Vector (fromVector, (!!))
import ZkFold.Base.Data.Vector ((!!))
import ZkFold.Base.Protocol.NonInteractiveProof
import ZkFold.Base.Protocol.Plonkup (with4n6)
import ZkFold.Base.Protocol.Plonkup.Input
Expand All @@ -32,6 +32,7 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

plonkProve :: forall p i n l c1 c2 ts core .
( KnownNat n
, Foldable l
, Ord (BaseField c1)
, AdditiveGroup (BaseField c1)
, Arithmetic (ScalarField c1)
Expand Down Expand Up @@ -61,7 +62,7 @@ plonkProve PlonkupProverSetup {..}
w2X = with4n6 @n $ polyVecInLagrangeBasis omega w2 :: PlonkupPolyExtended n c1
w3X = with4n6 @n $ polyVecInLagrangeBasis omega w3 :: PlonkupPolyExtended n c1

pi = toPolyVec @_ @n $ fromList $ fromVector (negate <$> wPub)
pi = toPolyVec @_ @n $ fromList $ foldMap (\x -> [negate x]) wPub
piX = with4n6 @n $ polyVecInLagrangeBasis omega pi :: PlonkupPolyExtended n c1

-- Round 1
Expand Down
4 changes: 2 additions & 2 deletions symbolic-base/src/ZkFold/Base/Protocol/Plonk/Verifier.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number (KnownNat, Natural, value)
import ZkFold.Base.Algebra.EllipticCurve.Class
import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr)
import ZkFold.Base.Data.Vector (fromVector)
import ZkFold.Base.Protocol.NonInteractiveProof hiding (verify)
import ZkFold.Base.Protocol.Plonkup.Input
import ZkFold.Base.Protocol.Plonkup.Internal
Expand All @@ -25,6 +24,7 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

plonkVerify :: forall p i n l c1 c2 ts .
( KnownNat n
, Foldable l
, Pairing c1 c2
, Ord (BaseField c1)
, AdditiveGroup (BaseField c1)
Expand Down Expand Up @@ -103,7 +103,7 @@ plonkVerify

-- Step 7: Compute public polynomial evaluation
pi_xi = with4n6 @n $ polyVecInLagrangeBasis @(ScalarField c1) @n @(PlonkupPolyExtendedLength n) omega
(toPolyVec $ fromList $ fromVector (negate <$> wPub))
(toPolyVec $ fromList $ foldMap (\x -> [negate x]) wPub)
`evalPolyVec` xi

-- Step 8: Compute the public table commitment
Expand Down
10 changes: 6 additions & 4 deletions symbolic-base/src/ZkFold/Base/Protocol/Plonkup.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module ZkFold.Base.Protocol.Plonkup (
Plonkup (..)
) where

import Data.Functor.Rep (Representable)
import Data.Functor.Rep (Rep, Representable)
import Data.Word (Word8)
import Prelude hiding (Num (..), div, drop, length, replicate,
sum, take, (!!), (/), (^))
Expand All @@ -29,10 +29,12 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
{-| Based on the paper https://eprint.iacr.org/2022/086.pdf -}

instance forall p i n l c1 c2 ts core.
( KnownNat i
, KnownNat n
, KnownNat l
( KnownNat n
, Representable p
, Representable i
, Representable l
, Foldable l
, Ord (Rep i)
, Ord (BaseField c1)
, AdditiveGroup (BaseField c1)
, Pairing c1 c2
Expand Down
23 changes: 12 additions & 11 deletions symbolic-base/src/ZkFold/Base/Protocol/Plonkup/Input.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,25 @@

module ZkFold.Base.Protocol.Plonkup.Input where

import Prelude hiding (Num (..), drop, length, sum, take, (!!), (/), (^))
import Data.Function (($))
import Data.Functor (Functor, (<$>))
import Data.Functor.Classes (Show1)
import Data.List ((++))
import Test.QuickCheck (Arbitrary (..))
import Text.Show (Show, show)

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..))
import ZkFold.Base.Data.Vector (Vector (..), unsafeToVector)
import ZkFold.Prelude (take)
import ZkFold.Symbolic.Compiler ()

newtype PlonkupInput l c = PlonkupInput { unPlonkupInput :: Vector l (ScalarField c) }
newtype PlonkupInput l c = PlonkupInput { unPlonkupInput :: l (ScalarField c) }

instance Show (ScalarField c) => Show (PlonkupInput l c) where
instance (Show1 l, Show (ScalarField c)) => Show (PlonkupInput l c) where
show (PlonkupInput v) = "Plonkup Input: " ++ show v

instance (KnownNat l, Arbitrary (ScalarField c)) => Arbitrary (PlonkupInput l c) where
arbitrary = do
PlonkupInput . unsafeToVector . take (value @l) <$> arbitrary
instance (Arbitrary (l (ScalarField c))) => Arbitrary (PlonkupInput l c) where
arbitrary = PlonkupInput <$> arbitrary

plonkupVerifierInput :: Field (ScalarField c) => Vector l (ScalarField c) -> PlonkupInput l c
plonkupVerifierInput input = PlonkupInput $ fmap negate input
plonkupVerifierInput ::
(Functor l, Field (ScalarField c)) => l (ScalarField c) -> PlonkupInput l c
plonkupVerifierInput input = PlonkupInput $ negate <$> input
15 changes: 6 additions & 9 deletions symbolic-base/src/ZkFold/Base/Protocol/Plonkup/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

module ZkFold.Base.Protocol.Plonkup.Internal where

import Data.Binary (Binary)
import Data.Constraint (withDict)
import Data.Constraint.Nat (plusNat, timesNat)
import Data.Functor.Classes (Show1)
import Data.Functor.Rep (Rep)
import Prelude hiding (Num (..), drop, length, sum, take, (!!),
(/), (^))
Expand All @@ -14,7 +14,6 @@ import Test.QuickCheck (Arbitrary
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..))
import ZkFold.Base.Algebra.Polynomials.Univariate (PolyVec)
import ZkFold.Base.Data.Vector (Vector (..))
import ZkFold.Base.Protocol.Plonkup.Utils
import ZkFold.Symbolic.Compiler ()
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
Expand All @@ -24,11 +23,11 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
Additionally, we don't want this library to depend on Cardano libraries.
-}

data Plonkup p (i :: Natural) (n :: Natural) (l :: Natural) curve1 curve2 transcript = Plonkup {
data Plonkup p i (n :: Natural) l curve1 curve2 transcript = Plonkup {
vlasin marked this conversation as resolved.
Show resolved Hide resolved
omega :: ScalarField curve1,
k1 :: ScalarField curve1,
k2 :: ScalarField curve1,
ac :: ArithmeticCircuit (ScalarField curve1) p (Vector i) (Vector l),
ac :: ArithmeticCircuit (ScalarField curve1) p i l,
x :: ScalarField curve1
}

Expand All @@ -37,20 +36,18 @@ type PlonkupPermutationSize n = 3 * n
-- The maximum degree of the polynomials we need in the protocol is `4 * n + 5`.
type PlonkupPolyExtendedLength n = 4 * n + 6


with4n6 :: forall n {r}. KnownNat n => (KnownNat (4 * n + 6) => r) -> r
with4n6 f = withDict (timesNat @4 @n) (withDict (plusNat @(4 * n) @6) f)

type PlonkupPolyExtended n c = PolyVec (ScalarField c) (PlonkupPolyExtendedLength n)

instance (Show (ScalarField c1), Arithmetic (ScalarField c1), KnownNat l, KnownNat i) => Show (Plonkup p i n l c1 c2 t) where
instance (Show (ScalarField c1), Show (Rep i), Show1 l, Ord (Rep i)) => Show (Plonkup p i n l c1 c2 t) where
show Plonkup {..} =
"Plonkup: " ++ show omega ++ " " ++ show k1 ++ " " ++ show k2 ++ " " ++ show (acOutput ac) ++ " " ++ show ac ++ " " ++ show x

instance
( KnownNat i, KnownNat n, KnownNat l
, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1), Binary (ScalarField c1)
, Binary (Rep p)
( KnownNat n, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1)
, Arbitrary (ArithmeticCircuit (ScalarField c1) p i l)
) => Arbitrary (Plonkup p i n l c1 c2 t) where
arbitrary = do
ac <- arbitrary
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Protocol.Plonkup.LookupConstraint where

import Data.Binary (Binary)
import Data.ByteString (ByteString)
import Data.Functor.Rep (Rep)
import Prelude hiding (Num (..), drop, length, sum, take, (!!),
(/), (^))
import Test.QuickCheck (Arbitrary (..))

import ZkFold.Base.Data.ByteString (toByteString)
import ZkFold.Base.Data.Vector (Vector)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

newtype LookupConstraint i a = LookupConstraint { lkVar :: SysVar (Vector i) }
deriving (Show, Eq)
newtype LookupConstraint i a = LookupConstraint { lkVar :: SysVar i }

deriving instance Show (Rep i) => Show (LookupConstraint i a)
deriving instance Eq (Rep i) => Eq (LookupConstraint i a)

instance (Arbitrary a, Binary a) => Arbitrary (LookupConstraint i a) where
arbitrary = LookupConstraint . NewVar . toByteString @a <$> arbitrary
Expand Down
25 changes: 13 additions & 12 deletions symbolic-base/src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import Data.Containers.ListUtils (nubOrd)
import Data.Eq (Eq (..))
import Data.Function (($), (.))
import Data.Functor ((<$>))
import Data.Functor.Rep (Rep)
import Data.List (find, head, map, permutations, sort, (!!), (++))
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (Maybe (..), fromMaybe, mapMaybe)
import Data.Ord (Ord)
import GHC.IsList (IsList (..))
import GHC.TypeNats (KnownNat)
import Numeric.Natural (Natural)
import Test.QuickCheck (Arbitrary (..))
import Text.Show (Show)
Expand All @@ -24,7 +24,6 @@ import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Polynomials.Multivariate (Poly, evalMonomial, evalPolynomial, polynomial,
var, variables)
import ZkFold.Base.Data.ByteString (toByteString)
import ZkFold.Base.Data.Vector (Vector)
import ZkFold.Prelude (length, take)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

Expand All @@ -34,13 +33,15 @@ data PlonkConstraint i a = PlonkConstraint
, qr :: a
, qo :: a
, qc :: a
, x1 :: Var a (Vector i)
, x2 :: Var a (Vector i)
, x3 :: Var a (Vector i)
, x1 :: Var a i
, x2 :: Var a i
, x3 :: Var a i
}
deriving (Show, Eq)

instance (Ord a, Arbitrary a, Binary a, KnownNat i) => Arbitrary (PlonkConstraint i a) where
deriving instance (Show a, Show (Rep i)) => Show (PlonkConstraint i a)
deriving instance (Eq a, Eq (Rep i)) => Eq (PlonkConstraint i a)

instance (Ord a, Arbitrary a, Binary a, Ord (Rep i)) => Arbitrary (PlonkConstraint i a) where
arbitrary = do
qm <- arbitrary
ql <- arbitrary
Expand All @@ -49,10 +50,10 @@ instance (Ord a, Arbitrary a, Binary a, KnownNat i) => Arbitrary (PlonkConstrain
qc <- arbitrary
let arbitraryNewVar = SysVar . NewVar . toByteString @a <$> arbitrary
xs <- sort <$> replicateM 3 arbitraryNewVar
let x1 = xs !! 0; x2 = xs !! 1; x3 = xs !! 2
let x1 = head xs; x2 = xs !! 1; x3 = xs !! 2
return $ PlonkConstraint qm ql qr qo qc x1 x2 x3

toPlonkConstraint :: forall a i . (Ord a, FiniteField a, KnownNat i) => Poly a (Var a (Vector i)) Natural -> PlonkConstraint i a
toPlonkConstraint :: forall a i . (Ord a, FiniteField a, Ord (Rep i)) => Poly a (Var a i) Natural -> PlonkConstraint i a
toPlonkConstraint p =
let xs = Just <$> toList (variables p)
perms = nubOrd $ map (take 3) $ permutations $ case length xs of
Expand All @@ -61,12 +62,12 @@ toPlonkConstraint p =
2 -> [Nothing] ++ xs ++ xs
_ -> xs ++ xs

getCoef :: Map (Maybe (Var a (Vector i))) Natural -> a
getCoef :: Map (Maybe (Var a i)) Natural -> a
getCoef m = case find (\(_, as) -> m == Map.mapKeys Just as) (toList p) of
Just (c, _) -> c
_ -> zero

getCoefs :: [Maybe (Var a (Vector i))] -> Maybe (PlonkConstraint i a)
getCoefs :: [Maybe (Var a i)] -> Maybe (PlonkConstraint i a)
getCoefs [a, b, c] = do
let xa = [(a, 1)]
xb = [(b, 1)]
Expand All @@ -89,7 +90,7 @@ toPlonkConstraint p =
[] -> toPlonkConstraint zero
_ -> head $ mapMaybe getCoefs perms

fromPlonkConstraint :: (Ord a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var a (Vector i)) Natural
fromPlonkConstraint :: (Ord a, Field a, Ord (Rep i)) => PlonkConstraint i a -> Poly a (Var a i) Natural
fromPlonkConstraint (PlonkConstraint qm ql qr qo qc a b c) =
let xa = var a
xb = var b
Expand Down
Loading