Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add fft poly div #440

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ cabal-dev
.cabal-sandbox/
cabal.sandbox.config
*.prof
*.svg
*.aux
*.hp
*.eventlog
Expand Down
72 changes: 49 additions & 23 deletions symbolic-apps/src/ZkFold/Symbolic/Apps/KYC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,27 @@ import Data.Aeson
import Data.Functor ((<$>))
import Data.Maybe (fromJust)
import GHC.Generics (Generic)
import Prelude (String, error, ($), (.))
import Prelude (String, error, fst, snd, ($), (.))

import ZkFold.Base.Algebra.Basic.Class (FromConstant (fromConstant))
import ZkFold.Base.Algebra.Basic.Field (Zp)
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Data.Vector (Vector, head, tail, toVector)
import ZkFold.Base.Control.HApplicative (HApplicative)
import ZkFold.Base.Data.Vector (Vector, splitAt, toVector)
import ZkFold.Symbolic.Class (Symbolic (BaseField))
import ZkFold.Symbolic.Data.Bool (Bool, not, (&&))
import ZkFold.Symbolic.Data.ByteString (ByteString, Resize (resize), concat, toWords)
import ZkFold.Symbolic.Data.Combinators (Ceil, GetRegisterSize, Iso (..), KnownRegisterSize, KnownRegisters,
RegisterSize (..))
import ZkFold.Symbolic.Data.Class (SymbolicData)
import ZkFold.Symbolic.Data.Combinators (Ceil, GetRegisterSize, Iso (..), KnownRegisterSize,
NumberOfRegisters)
import ZkFold.Symbolic.Data.Eq (Eq ((==)), elem)
import ZkFold.Symbolic.Data.Input (SymbolicInput)
import ZkFold.Symbolic.Data.Ord (Ord ((>=)))
import ZkFold.Symbolic.Data.UInt (OrdWord, UInt)
import ZkFold.Symbolic.Interpreter (Interpreter)

type KYCByteString context = ByteString 256 context

type KYCHash context = UInt 256 Auto context

{-
>>> type Prime256_1 = 28948022309329048855892746252171976963363056481941560715954676764349967630337
>>> :{
Expand All @@ -39,43 +40,68 @@ exKYC = KYCData
>>> encode exKYC
"{\"kycHash\":\"bb8\",\"kycID\":4000,\"kycType\":\"3e8\",\"kycValue\":\"7d0\"}"
-}
data KYCData n context = KYCData
{ kycType :: KYCByteString context
data KYCData n k r context = KYCData
{ kycType :: ByteString k context
, kycID :: UInt k r context
, kycHash :: UInt k r context
, kycValue :: ByteString n context
, kycHash :: KYCHash context
, kycID :: UInt 64 Auto context
} deriving Generic

data User r context = User
{ userAge :: UInt 64 r context
, userCountry :: ByteString 128 context
{ userAge :: UInt 8 r context
, userCountry :: ByteString 10 context
} deriving Generic

instance (Symbolic context) => FromJSON (KYCData 256 context)
instance (Symbolic (Interpreter (Zp p))) => ToJSON (KYCData 256 (Interpreter (Zp p)))
instance ( Symbolic context
, KnownNat n
, KnownNat k
, KnownRegisterSize r
) => FromJSON (KYCData n k r context)

instance ( Symbolic (Interpreter (Zp p))
, KnownNat n
, KnownNat k
, KnownRegisterSize r
) => ToJSON (KYCData n k r (Interpreter (Zp p)))

instance ( HApplicative context
, KnownNat n
, KnownNat k
, KnownRegisterSize r
, KnownNat (NumberOfRegisters (BaseField context) k r)
, Symbolic context
) => SymbolicData (KYCData n k r context)

instance (
Symbolic context
, KnownNat n
, KnownNat k
, KnownRegisterSize r
, KnownNat (NumberOfRegisters (BaseField context) k r)
) => SymbolicInput (KYCData n k r context)

isCitizen :: (Symbolic c) => KYCByteString c -> Vector n (KYCByteString c) -> Bool c
isCitizen = elem

kycExample :: forall n r rsc context . (
kycExample :: forall n k r rsc context . (
Symbolic context
, KnownNat n
, KnownNat rsc
, Eq (Bool context) (KYCHash context)
, KnownRegisterSize r
, KnownRegisters context 64 r
, KnownNat (Ceil (GetRegisterSize (BaseField context) 64 r) OrdWord)
) => KYCData n context -> KYCHash context -> Bool context
, KnownNat (NumberOfRegisters (BaseField context) 8 r)
, KnownNat (Ceil (GetRegisterSize (BaseField context) 8 r) OrdWord)
, KnownNat (NumberOfRegisters (BaseField context) k r)
) => KYCData n k r context -> UInt k r context -> Bool context
kycExample kycData hash =
let
v :: Vector 3 (ByteString 64 context)
v = toWords $ resize $ kycValue kycData
v :: (Vector 8 (ByteString 1 context), Vector 10 (ByteString 1 context))
v = splitAt @8 @10 $ toWords $ resize $ kycValue kycData

correctHash :: Bool context
correctHash = hash == kycHash kycData

user :: User r context
user = User (from $ head v) (concat $ tail v)
user = User (from $ concat $ fst v) (concat $ snd v)

validAge :: Bool context
validAge = userAge user >= fromConstant (18 :: Natural)
Expand Down Expand Up @@ -105,7 +131,7 @@ iso3166 = \case
restrictedCountries :: forall m context . (
Symbolic context
, KnownNat m
) => Vector m (ByteString 128 context)
) => Vector m (ByteString 10 context)
restrictedCountries =
fromJust $ toVector $ fromConstant . iso3166 <$>
[ "FRA"
Expand Down
4 changes: 3 additions & 1 deletion symbolic-base/src/ZkFold/Base/Algebra/EllipticCurve/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

module ZkFold.Base.Algebra.EllipticCurve.Class where

import Data.Aeson (ToJSON)
import Data.Functor ((<&>))
import Data.Kind (Type)
import Data.String (fromString)
Expand Down Expand Up @@ -111,6 +112,7 @@ instance
) => Show (Point curve) where
show (Point x y isInf) = if isInf then "Inf" else "(" ++ show x ++ ", " ++ show y ++ ")"

deriving instance (ToJSON (BaseField curve), ToJSON (BooleanOf curve)) => ToJSON (Point curve)
instance EllipticCurve curve => AdditiveSemigroup (Point curve) where
(+) = add

Expand Down Expand Up @@ -187,7 +189,7 @@ pointNegate (Point x y isInf) = if isInf then pointInf else pointXY x (negate y)
pointMul
:: forall curve s
. EllipticCurve curve
=> BinaryExpansion (s)
=> BinaryExpansion s
=> Bits s ~ [s]
=> P.Eq s
=> s
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ import Control.DeepSeq (NFData (..))
import qualified Data.Vector as V
import GHC.Generics (Generic)
import GHC.IsList (IsList (..))
import Prelude hiding (Num (..), drop, length, product, replicate, sum, take, (/),
(^))
import Prelude hiding (Num (..), drop, length, product, replicate, sum, take,
truncate, (/), (^))
import qualified Prelude as P
import Test.QuickCheck (Arbitrary (..), chooseInt)

Expand Down Expand Up @@ -112,7 +112,7 @@ instance (Ring c, Eq c) => AdditiveMonoid (Poly c) where
instance (Ring c, Eq c) => AdditiveGroup (Poly c) where
negate (P cs) = P $ fmap negate cs

instance (Field c, Eq c) => MultiplicativeSemigroup (Poly c) where
instance {-# OVERLAPPABLE #-} (Field c, Eq c) => MultiplicativeSemigroup (Poly c) where
-- | If it is possible to calculate a primitive root of unity in the field, proceed with FFT multiplication.
-- Otherwise default to Karatsuba multiplication for polynomials of degree higher than 64 or use naive multiplication otherwise.
-- 64 is a threshold determined by benchmarking.
Expand Down Expand Up @@ -329,7 +329,7 @@ instance (Field c, KnownNat size) => Exponent (PolyVec c size) Natural where
instance {-# OVERLAPPING #-} (Field c, KnownNat size) => Scale (PolyVec c size) (PolyVec c size)

-- TODO (Issue #18): check for overflow
instance (Field c, KnownNat size) => MultiplicativeSemigroup (PolyVec c size) where
instance {-# OVERLAPPABLE #-} (Field c, KnownNat size) => MultiplicativeSemigroup (PolyVec c size) where
(PV l) * (PV r) = toPolyVec $ mulAdaptive l r

instance (Field c, KnownNat size) => MultiplicativeMonoid (PolyVec c size) where
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import Data.Word (Word8)
import Numeric.Natural (Natural)
import Prelude hiding (Num ((*)), sum)

import ZkFold.Base.Algebra.Basic.Class (Field, MultiplicativeSemigroup ((*)), sum)
import ZkFold.Base.Algebra.Basic.Class (Field, sum, (*))
import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), Point)
import ZkFold.Base.Algebra.Polynomials.Univariate (Poly, PolyVec, fromPolyVec)
import ZkFold.Base.Algebra.Polynomials.Univariate (Poly, PolyVec, fromPolyVec, qr)
import ZkFold.Base.Data.ByteString

class Monoid ts => ToTranscript ts a where
Expand Down Expand Up @@ -69,8 +69,13 @@ class (EllipticCurve curve) => CoreFunction curve core where

polyMul :: (f ~ ScalarField curve, Field f, Eq f) => Poly f -> Poly f -> Poly f

polyQr :: (f ~ ScalarField curve, Field f, Eq f) => Poly f -> Poly f -> (Poly f, Poly f)

data HaskellCore

instance (EllipticCurve curve, f ~ ScalarField curve) => CoreFunction curve HaskellCore where
msm gs f = sum $ V.zipWith mul (fromPolyVec f) gs

polyMul = (*)

polyQr = qr
2 changes: 1 addition & 1 deletion symbolic-base/src/ZkFold/Base/Protocol/Plonk.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,4 @@ instance forall p i n l c1 c2 (ts :: Type) core .
in (input, proof)

verify :: SetupVerify (Plonk p i n l c1 c2 ts) -> Input (Plonk p i n l c1 c2 ts) -> Proof (Plonk p i n l c1 c2 ts) -> Bool
verify = plonkVerify @p @i @n @l @c1 @c2 @ts
verify = plonkVerify @p @i @n @l @c1 @c2 @ts @core
17 changes: 16 additions & 1 deletion symbolic-base/src/ZkFold/Base/Protocol/Plonk/Prover.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators #-}

module ZkFold.Base.Protocol.Plonk.Prover
( plonkProve
Expand All @@ -14,7 +15,8 @@ import Prelude hiding (Num
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number (KnownNat, Natural, value)
import ZkFold.Base.Algebra.EllipticCurve.Class (CompressedPoint, EllipticCurve (..), compress)
import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr)
import ZkFold.Base.Algebra.Polynomials.Univariate hiding (polyVecDiv, polyVecInLagrangeBasis,
polyVecLagrange, qr)
import ZkFold.Base.Data.Vector ((!!))
import ZkFold.Base.Protocol.NonInteractiveProof
import ZkFold.Base.Protocol.Plonkup (with4n6)
Expand Down Expand Up @@ -50,6 +52,19 @@ plonkProve PlonkupProverSetup {..}
(@) :: forall size . (KnownNat size) => PolyVec (ScalarField c1) size -> PolyVec (ScalarField c1) size -> PolyVec (ScalarField c1) size
(@) a b = poly2vec $ polyMul @c1 @core (vec2poly a) (vec2poly b)

polyVecDiv :: forall c size . (c ~ ScalarField c1, KnownNat size) =>PolyVec c size -> PolyVec c size -> PolyVec c size
polyVecDiv l r = poly2vec $ fst $ (polyQr @c1 @core) (vec2poly l) (vec2poly r)

polyVecLagrange :: forall c m size . (c ~ ScalarField c1, KnownNat m, KnownNat size) =>
Natural -> c -> PolyVec c size
polyVecLagrange i omega' = scalePV (omega'^i // fromConstant (value @m)) $ (polyVecZero @c @m @size - one) `polyVecDiv` polyVecLinear one (negate $ omega'^i)

polyVecInLagrangeBasis :: forall c m size . (c ~ ScalarField c1, KnownNat m, KnownNat size) =>
c -> PolyVec c m -> PolyVec c size
polyVecInLagrangeBasis omega' cs =
let ls = fmap (\i -> polyVecLagrange @c @m @size i omega') (V.generate (V.length (fromPolyVec cs)) (fromIntegral . succ))
in sum $ V.zipWith scalePV (fromPolyVec cs) ls

PlonkupCircuitPolynomials {..} = polynomials
secret i = ps !! (i -! 1)

Expand Down
21 changes: 19 additions & 2 deletions symbolic-base/src/ZkFold/Base/Protocol/Plonk/Verifier.hs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Protocol.Plonk.Verifier
( plonkVerify
) where

import qualified Data.Vector as V
import Data.Word (Word8)
import GHC.IsList (IsList (..))
import Prelude hiding (Num (..), Ord, drop, length, sum, take,
Expand All @@ -13,7 +15,8 @@ import Prelude hiding (Num
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number (KnownNat, Natural, value)
import ZkFold.Base.Algebra.EllipticCurve.Class
import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr)
import ZkFold.Base.Algebra.Polynomials.Univariate hiding (polyVecDiv, polyVecInLagrangeBasis,
polyVecLagrange, qr)
import ZkFold.Base.Protocol.NonInteractiveProof hiding (verify)
import ZkFold.Base.Protocol.Plonkup.Input
import ZkFold.Base.Protocol.Plonkup.Internal
Expand All @@ -23,7 +26,7 @@ import ZkFold.Base.Protocol.Plonkup.Verifier.Setup
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
import ZkFold.Symbolic.Data.Ord

plonkVerify :: forall p i n l c1 c2 ts .
plonkVerify :: forall p i n l c1 c2 ts core .
( KnownNat n
, Foldable l
, Pairing c1 c2
Expand All @@ -34,6 +37,7 @@ plonkVerify :: forall p i n l c1 c2 ts .
, ToTranscript ts (ScalarField c1)
, ToTranscript ts (CompressedPoint c1)
, FromTranscript ts (ScalarField c1)
, CoreFunction c1 core
) => PlonkupVerifierSetup p i n l c1 c2 -> PlonkupInput l c1 -> PlonkupProof c1 -> Bool
plonkVerify
PlonkupVerifierSetup {..}
Expand Down Expand Up @@ -150,3 +154,16 @@ plonkVerify
-- Step 13: Compute the pairing
p1 = pairing (proof1 + eta `mul` proof2) h1
p2 = pairing (xi `mul` proof1 + (eta * xi * omega) `mul` proof2 + f - e) (pointGen @c2)

polyVecDiv :: forall c size . (c ~ ScalarField c1, KnownNat size) =>PolyVec c size -> PolyVec c size -> PolyVec c size
polyVecDiv l r = poly2vec $ fst $ (polyQr @c1 @core) (vec2poly l) (vec2poly r)

polyVecLagrange :: forall c m size . (c ~ ScalarField c1, KnownNat m, KnownNat size) =>
Natural -> c -> PolyVec c size
polyVecLagrange i omega' = scalePV (omega'^i // fromConstant (value @m)) $ (polyVecZero @c @m @size - one) `polyVecDiv` polyVecLinear one (negate $ omega'^i)

polyVecInLagrangeBasis :: forall c m size . (c ~ ScalarField c1, KnownNat m, KnownNat size) =>
c -> PolyVec c m -> PolyVec c size
polyVecInLagrangeBasis omega' cs =
let ls = fmap (\i -> polyVecLagrange @c @m @size i omega') (V.generate (V.length (fromPolyVec cs)) (fromIntegral . succ))
in sum $ V.zipWith scalePV (fromPolyVec cs) ls
2 changes: 1 addition & 1 deletion symbolic-base/src/ZkFold/Base/Protocol/Plonkup.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,5 @@ instance forall p i n l c1 c2 ts core.
in (input, proof)

verify :: SetupVerify (Plonkup p i n l c1 c2 ts) -> Input (Plonkup p i n l c1 c2 ts) -> Proof (Plonkup p i n l c1 c2 ts) -> Bool
verify = with4n6 @n $ plonkupVerify @p @i @n @l @c1 @c2 @ts
verify = with4n6 @n $ plonkupVerify @p @i @n @l @c1 @c2 @ts @core

17 changes: 16 additions & 1 deletion symbolic-base/src/ZkFold/Base/Protocol/Plonkup/Prover.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators #-}

module ZkFold.Base.Protocol.Plonkup.Prover
( module ZkFold.Base.Protocol.Plonkup.Prover.Polynomials
Expand All @@ -17,7 +18,8 @@ import Prelude hiding (Num
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number (KnownNat, Natural, value)
import ZkFold.Base.Algebra.EllipticCurve.Class (CompressedPoint, EllipticCurve (..), compress)
import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr)
import ZkFold.Base.Algebra.Polynomials.Univariate hiding (polyVecDiv, polyVecInLagrangeBasis,
polyVecLagrange, qr)
import ZkFold.Base.Data.Vector ((!!))
import ZkFold.Base.Protocol.NonInteractiveProof
import ZkFold.Base.Protocol.Plonkup.Input
Expand Down Expand Up @@ -53,6 +55,19 @@ plonkupProve PlonkupProverSetup {..}
(@) :: forall size . (KnownNat size) => PolyVec (ScalarField c1) size -> PolyVec (ScalarField c1) size -> PolyVec (ScalarField c1) size
(@) a b = poly2vec $ polyMul @c1 @core (vec2poly a) (vec2poly b)

polyVecDiv :: forall c size . (c ~ ScalarField c1, KnownNat size) =>PolyVec c size -> PolyVec c size -> PolyVec c size
polyVecDiv l r = poly2vec $ fst $ (polyQr @c1 @core) (vec2poly l) (vec2poly r)

polyVecLagrange :: forall c m size . (c ~ ScalarField c1, KnownNat m, KnownNat size) =>
Natural -> c -> PolyVec c size
polyVecLagrange i omega' = scalePV (omega'^i // fromConstant (value @m)) $ (polyVecZero @c @m @size - one) `polyVecDiv` polyVecLinear one (negate $ omega'^i)

polyVecInLagrangeBasis :: forall c m size . (c ~ ScalarField c1, KnownNat m, KnownNat size) =>
c -> PolyVec c m -> PolyVec c size
polyVecInLagrangeBasis omega' cs =
let ls = fmap (\i -> polyVecLagrange @c @m @size i omega') (V.generate (V.length (fromPolyVec cs)) (fromIntegral . succ))
in sum $ V.zipWith scalePV (fromPolyVec cs) ls

PlonkupCircuitPolynomials {..} = polynomials
secret i = ps !! (i -! 1)

Expand Down
17 changes: 16 additions & 1 deletion symbolic-base/src/ZkFold/Base/Protocol/Plonkup/Setup.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Algebra.Basic.Permutations (fromPermutation)
import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), Pairing, Point)
import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr)
import ZkFold.Base.Algebra.Polynomials.Univariate hiding (polyVecDiv, polyVecInLagrangeBasis,
polyVecLagrange, qr)
import ZkFold.Base.Protocol.NonInteractiveProof (CoreFunction (..))
import ZkFold.Base.Protocol.Plonkup.Internal
import ZkFold.Base.Protocol.Plonkup.Prover
Expand Down Expand Up @@ -119,3 +120,17 @@ plonkupSetup Plonkup {..} =
commitments = PlonkupCircuitCommitments {..}

in PlonkupSetup {..}
where
polyVecDiv :: forall c size . (c ~ ScalarField c1, KnownNat size) =>PolyVec c size -> PolyVec c size -> PolyVec c size
polyVecDiv l r = poly2vec $ fst $ (polyQr @c1 @core) (vec2poly l) (vec2poly r)

polyVecLagrange :: forall c m size . (c ~ ScalarField c1, KnownNat m, KnownNat size) =>
Natural -> c -> PolyVec c size
polyVecLagrange i omega' = scalePV (omega'^i // fromConstant (value @m)) $ (polyVecZero @c @m @size - one) `polyVecDiv` polyVecLinear one (negate $ omega'^i)

polyVecInLagrangeBasis :: forall c m size . (c ~ ScalarField c1, KnownNat m, KnownNat size) =>
c -> PolyVec c m -> PolyVec c size
polyVecInLagrangeBasis omega' cs =
let ls = fmap (\i -> polyVecLagrange @c @m @size i omega') (V.generate (V.length (fromPolyVec cs)) (fromIntegral . succ))
in sum $ V.zipWith scalePV (fromPolyVec cs) ls

Loading