Skip to content

Commit

Permalink
Merge pull request #293 from zkFold/TurtlePU/remove-TypeSize
Browse files Browse the repository at this point in the history
Removed TypeSize
  • Loading branch information
vlasin authored Oct 11, 2024
2 parents faedb7a + de82b2c commit 48d6e75
Show file tree
Hide file tree
Showing 39 changed files with 332 additions and 375 deletions.
9 changes: 5 additions & 4 deletions symbolic-base/src/ZkFold/Base/Protocol/Protostar/Commit.hs
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Protocol.Protostar.Commit (Commit (..), HomomorphicCommit (..), PedersonSetup (..)) where

import Data.Foldable (Foldable, toList)
import Prelude (type (~), zipWith, ($), (<$>))
import Data.Functor.Rep (Representable)
import Prelude (Traversable, type (~), zipWith, ($), (<$>))

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Algebra.EllipticCurve.BLS12_381
import ZkFold.Base.Algebra.EllipticCurve.Class as EC
import ZkFold.Base.Algebra.EllipticCurve.Ed25519
import ZkFold.Base.Data.Vector (Vector)
import ZkFold.Base.Protocol.Protostar.Oracle
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Class
Expand Down Expand Up @@ -77,7 +76,9 @@ instance
, SymbolicData (Point c)
, Context (Point c) ~ ctx
, PedersonSetup (Point c)
, Layout (Point c) ~ Vector n
, Layout (Point c) ~ l
, Representable l
, Traversable l
) => HomomorphicCommit (FieldElement ctx) (FieldElement ctx) (Point c) where
hcommit r b = let (g, h) = pedersonGH @(Point c)
in scale b g + scale r h
Expand Down
81 changes: 30 additions & 51 deletions symbolic-base/src/ZkFold/Base/Protocol/Protostar/Fold.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,24 @@ module ZkFold.Base.Protocol.Protostar.Fold where

import Control.DeepSeq (NFData)
import Control.Lens ((^.))
import Data.Binary (Binary)
import Data.Function ((.))
import Data.Functor (fmap)
import Data.Functor.Rep (Rep, Representable)
import Data.Kind (Type)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Ord (Ord)
import Data.Proxy (Proxy)
import GHC.Generics (Generic, Par1)
import GHC.Generics (Generic, Par1 (..), U1 (..), type (:*:) (..),
type (:.:) (..))
import Prelude (type (~), ($), (<$>), (<*>))
import qualified Prelude as P

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import qualified ZkFold.Base.Algebra.Polynomials.Univariate as PU
import ZkFold.Base.Data.HFunctor (hmap)
import ZkFold.Base.Data.Vector (Vector)
import ZkFold.Base.Protocol.Protostar.Accumulator
import qualified ZkFold.Base.Protocol.Protostar.AccumulatorScheme as Acc
Expand Down Expand Up @@ -53,10 +60,10 @@ instance (Ring a, P.Ord key, KnownNat k) => Acc.LinearCombination (Map key a) (M
linearCombination mx ma = M.unionWith (+) (P.flip PU.polyVecLinear zero <$> mx) (PU.polyVecConstant <$> ma)

instance (Ring a, KnownNat n, KnownNat k) => Acc.LinearCombination (Vector n a) (Vector n (PU.PolyVec a k)) where
linearCombination mx ma = (+) <$> (P.flip PU.polyVecLinear zero <$> mx) <*> (PU.polyVecConstant <$> ma)
linearCombination mx ma = (+) . P.flip PU.polyVecLinear zero <$> mx <*> (PU.polyVecConstant <$> ma)

instance (Ring a, KnownNat n) => Acc.LinearCombinationWith a (Vector n a) where
linearCombinationWith coeff a b = (+) <$> (P.fmap (coeff *) a) <*> b
linearCombinationWith coeff a b = (+) <$> P.fmap (coeff *) a <*> b


type C n a = ArithmeticCircuit a (Vector n) (Vector n)
Expand All @@ -82,7 +89,7 @@ toFS
-> C n a
-> Vector n (FieldElement ctx)
-> FS_CM ctx n comm a
toFS ck rc v = FiatShamir (CommitOpen (hcommit ck) rc) v
toFS ck rc = FiatShamir (CommitOpen (hcommit ck) rc)

-- No SymbolicData instances for data
-- all protocols are one-round in case of arithmetic circuits, therefore we can replace lists with elements.
Expand All @@ -94,7 +101,7 @@ ivcVerifier
-> (a, (f, (f, f)), ((i, c, f, c, f), m))
-> Bool ctx
ivcVerifier (i, pi_x, accTuple, acc'Tuple, pf) (a, ckTuple, dkTuple)
= (Acc.verifier @i @f @c @m @ctx @a i [pi_x] acc acc' [pf]) && (Acc.decider @i @f @c @m @ctx @a a ck dk)
= Acc.verifier @i @f @c @m @ctx @a i [pi_x] acc acc' [pf] && Acc.decider @i @f @c @m @ctx @a a ck dk
where
acc = let (x1, x2, x3, x4, x5) = accTuple
in AccumulatorInstance x1 [x2] [x3] x4 x5
Expand All @@ -109,48 +116,14 @@ ivcVerifier (i, pi_x, accTuple, acc'Tuple, pf) (a, ckTuple, dkTuple)

-- TODO: this is insane
ivcVerifierAc
:: forall i f c m ctx a y typeSize ckSize dkSize accSize
:: forall i f c m ctx a y t
. Symbolic ctx
=> TypeSize y ~ 1
=> SymbolicData i
=> SymbolicData f
=> SymbolicData c
=> SymbolicData m
=> SymbolicData a
=> SymbolicData y
=> typeSize ~ ((TypeSize i
+ (TypeSize c
+ ((TypeSize i
+ (TypeSize c + (TypeSize f + (TypeSize c + TypeSize f))))
+ ((TypeSize i
+ (TypeSize c
+ (TypeSize f + (TypeSize c + TypeSize f))))
+ TypeSize c))))
+ (TypeSize a
+ ((TypeSize f + (TypeSize f + TypeSize f))
+ ((TypeSize i
+ (TypeSize c + (TypeSize f + (TypeSize c + TypeSize f))))
+ TypeSize m))))
=> ckSize ~ (TypeSize f + (TypeSize f + TypeSize f))
=> dkSize ~ TypeSize a + ((TypeSize f + (TypeSize f + TypeSize f))
+ ((TypeSize i
+ (TypeSize c + (TypeSize f + (TypeSize c + TypeSize f))))
+ TypeSize m))
=> accSize ~ (TypeSize i + (TypeSize c
+ ((TypeSize i
+ (TypeSize c + (TypeSize f + (TypeSize c + TypeSize f))))
+ ((TypeSize i
+ (TypeSize c + (TypeSize f + (TypeSize c + TypeSize f))))
+ TypeSize c))))
=> KnownNat typeSize
=> KnownNat dkSize
=> KnownNat ckSize
=> KnownNat accSize
=> KnownNat (TypeSize i + (TypeSize c + (TypeSize f + (TypeSize c + TypeSize f))))
=> KnownNat (TypeSize i)
=> KnownNat (TypeSize f)
=> KnownNat (TypeSize c)
=> KnownNat (TypeSize a)
=> Context i ~ ctx
=> Context f ~ ctx
=> Context c ~ ctx
Expand All @@ -163,22 +136,28 @@ ivcVerifierAc
=> Support m ~ Proxy ctx
=> Support a ~ Proxy ctx
=> Support y ~ Proxy ctx
=> Layout i ~ Vector (TypeSize i)
=> Layout f ~ Vector (TypeSize f)
=> Layout c ~ Vector (TypeSize c)
=> Layout m ~ Vector (TypeSize m)
=> Layout a ~ Vector (TypeSize a)
=> Layout y ~ Vector (TypeSize y)
=> ctx ~ ArithmeticCircuit a (Vector typeSize)
=> Representable (Layout i)
=> Representable (Layout c)
=> Representable (Layout f)
=> Representable (Layout a)
=> Representable (Layout m)
=> Ord (Rep (Layout i))
=> Ord (Rep (Layout c))
=> Ord (Rep (Layout f))
=> Ord (Rep (Layout a))
=> Ord (Rep (Layout m))
=> Layout y ~ Par1
=> t ~ ((i,c,(i,c,f,c,f),(i,c,f,c,f),c),(a,(f,f,f),(i,c,f,c,f),m),Proxy ctx)
=> ctx ~ ArithmeticCircuit a (Layout t)
=> Acc.AccumulatorScheme i f c m ctx a
=> y
ivcVerifierAc = compile (ivcVerifier @i @f @c @m @ctx @a)

iterate
:: forall ctx n comm a
. Symbolic ctx
=> KnownNat n
. KnownNat n
=> Arithmetic a
=> Binary a
=> Scale a (BaseField ctx)
=> FromConstant a (BaseField ctx)
=> Eq (Bool ctx) comm
Expand All @@ -194,7 +173,7 @@ iterate
=> Scale (FieldElement ctx) comm
=> ctx ~ ArithmeticCircuit a (Vector n)
=> SPS.Input (FieldElement ctx) (C n a) ~ Vector n (FieldElement ctx)
=> (Vector n (FieldElement ctx) -> Vector n (FieldElement ctx))
=> (forall c. (Symbolic c, BaseField c ~ a) => Vector n (FieldElement c) -> Vector n (FieldElement c))
-> Vector n a
-> Natural
-> ProtostarResult ctx n comm a
Expand All @@ -204,7 +183,7 @@ iterate f i0 n = iteration n ck f ac i0_arith i0 initialAccumulator (Acc.KeyScal
i0_arith = fromConstant <$> i0

ac :: C n a
ac = compile @a f
ac = hmap (fmap unPar1 . unComp1) $ hlmap ((:*: U1) . Comp1 . fmap Par1) (compile @a f)

initE = hcommit ck $ replicate (SPS.outputLength @(FieldElement ctx) ac) (zero :: FieldElement ctx)

Expand Down
9 changes: 5 additions & 4 deletions symbolic-base/src/ZkFold/Symbolic/Algorithms/Hash/MiMC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@

module ZkFold.Symbolic.Algorithms.Hash.MiMC where

import Data.Foldable (toList)
import Data.List.NonEmpty (NonEmpty ((:|)), nonEmpty)
import Data.Proxy (Proxy (..))
import Numeric.Natural (Natural)
import Prelude hiding (Eq (..), Num (..), any, length, not, (!!), (/),
(^), (||))

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Data.HFunctor (hmap)
import ZkFold.Base.Data.Package (unpacked)
import ZkFold.Base.Data.Vector (Vector, fromVector)
import ZkFold.Symbolic.Algorithms.Hash.MiMC.Constants (mimcConstants)
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Class
Expand Down Expand Up @@ -39,12 +40,12 @@ mimcHashN xs k = go
[zL, zR] -> mimcHash2 xs k zL zR
(zL:zR:zs') -> go (mimcHash2 xs k zL zR : zs')

hash :: forall context x a size .
hash :: forall context x a .
( Symbolic context
, SymbolicData x
, BaseField context ~ a
, Context x ~ context
, Support x ~ Proxy context
, Layout x ~ Vector size
, Foldable (Layout x)
) => x -> FieldElement context
hash = mimcHashN mimcConstants (zero :: a) . fromVector . fmap FieldElement . unpacked . flip pieces Proxy
hash = mimcHashN mimcConstants (zero :: a) . fmap FieldElement . unpacked . hmap toList . flip pieces Proxy
87 changes: 47 additions & 40 deletions symbolic-base/src/ZkFold/Symbolic/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ import Data.Aeson (FromJSON, ToJSON)
import Data.Binary (Binary)
import Data.Function (const, (.))
import Data.Functor (($>))
import Data.Functor.Rep (Rep, Representable)
import Data.Ord (Ord)
import Data.Proxy (Proxy)
import Data.Traversable (for)
import Prelude (FilePath, IO, Monoid (mempty), Show (..), Traversable,
putStrLn, type (~), ($), (++))

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Data.Vector (Vector)
import ZkFold.Prelude (writeFileJSON)
import ZkFold.Symbolic.Class (Arithmetic, Symbolic (..))
import ZkFold.Symbolic.Compiler.ArithmeticCircuit
Expand All @@ -43,80 +43,87 @@ forceOne r = fromCircuitF r (\fi -> for fi $ \i -> constraint (\x -> x i - one)

-- | Arithmetizes an argument by feeding an appropriate amount of inputs.
solder ::
forall a c f ni .
( KnownNat ni
, c ~ ArithmeticCircuit a (Vector ni)
forall a c f s l .
( c ~ ArithmeticCircuit a l
, SymbolicData f
, Context f ~ c
, SymbolicData (Support f)
, Context (Support f) ~ c
, Support (Support f) ~ Proxy c
, Layout f ~ Vector (TypeSize f)
, Layout (Support f) ~ Vector ni
) => f -> c (Vector (TypeSize f))
, Support f ~ s
, SymbolicData s
, Context s ~ c
, Support s ~ Proxy c
, Layout s ~ l
, Representable l
, Ord (Rep l)
) => f -> c (Layout f)
solder f = pieces f (restore @(Support f) $ const inputC)
where
inputC = mempty { acOutput = acInput }

-- | Compiles function `f` into an arithmetic circuit with all outputs equal to 1.
compileForceOne ::
forall a c f y ni .
( KnownNat ni
, c ~ ArithmeticCircuit a (Vector ni)
forall a c f s l y .
( c ~ ArithmeticCircuit a l
, Arithmetic a
, Binary a
, SymbolicData f
, Context f ~ c
, SymbolicData (Support f)
, Context (Support f) ~ c
, Support (Support f) ~ Proxy c
, Support f ~ s
, SymbolicData s
, Context s ~ c
, Support s ~ Proxy c
, Layout s ~ l
, Representable l
, Binary (Rep l)
, Ord (Rep l)
, SymbolicData y
, Context y ~ c
, Support y ~ Proxy c
, TypeSize f ~ TypeSize y
, Layout f ~ Vector (TypeSize y)
, Layout y ~ Vector (TypeSize y)
, Layout (Support f) ~ Vector ni
, Layout f ~ Layout y
, Traversable (Layout y)
) => f -> y
compileForceOne = restore . const . optimize . forceOne . solder @a

-- | Compiles function `f` into an arithmetic circuit.
compile ::
forall a c f y ni .
( KnownNat ni
, c ~ ArithmeticCircuit a (Vector ni)
forall a c f s l y .
( c ~ ArithmeticCircuit a l
, SymbolicData f
, Context f ~ c
, SymbolicData (Support f)
, Context (Support f) ~ c
, Support (Support f) ~ Proxy c
, Support f ~ s
, SymbolicData s
, Context s ~ c
, Support s ~ Proxy c
, Layout s ~ l
, Representable l
, Ord (Rep l)
, SymbolicData y
, Context y ~ c
, Support y ~ Proxy c
, TypeSize f ~ TypeSize y
, Layout f ~ Vector (TypeSize y)
, Layout y ~ Vector (TypeSize y)
, Layout (Support f) ~ Vector ni
, Layout f ~ Layout y
) => f -> y
compile = restore . const . optimize . solder @a

-- | Compiles a function `f` into an arithmetic circuit. Writes the result to a file.
compileIO ::
forall a c f ni .
( KnownNat ni
, c ~ ArithmeticCircuit a (Vector ni)
forall a c f s l .
( c ~ ArithmeticCircuit a l
, FromJSON a
, ToJSON a
, SymbolicData f
, Context f ~ c
, SymbolicData (Support f)
, Context (Support f) ~ c
, Support (Support f) ~ Proxy c
, Layout f ~ Vector (TypeSize f)
, Layout (Support f) ~ Vector ni
, Support f ~ s
, ToJSON (Layout f (Var a l))
, SymbolicData s
, Context s ~ c
, Support s ~ Proxy c
, Layout s ~ l
, Representable l
, Ord (Rep l)
, FromJSON (Rep l)
, ToJSON (Rep l)
) => FilePath -> f -> IO ()
compileIO scriptFile f = do
let ac = optimize (solder @a f) :: c (Vector (TypeSize f))
let ac = optimize (solder @a f) :: c (Layout f)

putStrLn "\nCompiling the script...\n"

Expand Down
Loading

0 comments on commit 48d6e75

Please sign in to comment.