From 252e9f08dca29c36d63493e5a7eca4fd658d89b3 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Wed, 7 Aug 2024 11:00:15 -0700 Subject: [PATCH 01/48] almost there structured i/o --- .../Polynomials/Multivariate/Polynomial.hs | 6 +- src/ZkFold/Base/Data/Vector.hs | 11 ++ src/ZkFold/Base/Protocol/ARK/Protostar.hs | 2 +- src/ZkFold/Symbolic/Compiler.hs | 13 +- .../Symbolic/Compiler/ArithmeticCircuit.hs | 60 ++++--- .../Compiler/ArithmeticCircuit/Combinators.hs | 49 +++--- .../Compiler/ArithmeticCircuit/Instance.hs | 23 ++- .../Compiler/ArithmeticCircuit/Internal.hs | 163 +++++++++++------- .../Compiler/ArithmeticCircuit/Map.hs | 43 +++-- .../ArithmeticCircuit/MonadBlueprint.hs | 17 +- src/ZkFold/Symbolic/Data/ByteString.hs | 51 +++--- src/ZkFold/Symbolic/Data/Combinators.hs | 10 +- src/ZkFold/Symbolic/Data/UInt.hs | 127 ++++++++------ zkfold-base.cabal | 1 + 14 files changed, 322 insertions(+), 254 deletions(-) diff --git a/src/ZkFold/Base/Algebra/Polynomials/Multivariate/Polynomial.hs b/src/ZkFold/Base/Algebra/Polynomials/Multivariate/Polynomial.hs index 505cc41dc..4d1dec0df 100644 --- a/src/ZkFold/Base/Algebra/Polynomials/Multivariate/Polynomial.hs +++ b/src/ZkFold/Base/Algebra/Polynomials/Multivariate/Polynomial.hs @@ -48,9 +48,9 @@ evalPolynomial -> b evalPolynomial e f (P p) = foldr (\(c, m) x -> x + scale c (e f m)) zero p -variables :: forall c . - MultiplicativeMonoid c => - Poly c Natural Natural -> Set Natural +variables :: forall c v . + (Ord v, MultiplicativeMonoid c) => + Poly c v Natural -> Set v variables = runSources . evalPolynomial evalMonomial (Sources @c . singleton) mapVarPolynomial :: Variable i => Map i i-> Poly c i j -> Poly c i j diff --git a/src/ZkFold/Base/Data/Vector.hs b/src/ZkFold/Base/Data/Vector.hs index 52f08a674..a7c17fccd 100644 --- a/src/ZkFold/Base/Data/Vector.hs +++ b/src/ZkFold/Base/Data/Vector.hs @@ -9,6 +9,8 @@ import qualified Control.Monad as M import Control.Parallel.Strategies (parMap, rpar) import Data.Aeson (ToJSON (..)) import Data.Bifunctor (first) +import Data.Distributive (Distributive (..)) +import Data.Functor.Rep (Representable (..), distributeRep, collectRep) import qualified Data.List as List import Data.List.Split (chunksOf) import Data.These (These (..)) @@ -29,6 +31,15 @@ import ZkFold.Prelude (length, replicate) newtype Vector (size :: Natural) a = Vector [a] deriving (Show, Eq, Functor, Foldable, Traversable, Generic, NFData) +instance KnownNat size => Representable (Vector size) where + type Rep (Vector size) = Int + index (Vector v) ix = v Prelude.!! ix + tabulate f = Vector [f ix | ix <- [0 .. fromIntegral (value @size) Prelude.- 1]] + +instance KnownNat size => Distributive (Vector size) where + distribute = distributeRep + collect = collectRep + parFmap :: (a -> b) -> Vector size a -> Vector size b parFmap f (Vector lst) = Vector $ parMap rpar f lst diff --git a/src/ZkFold/Base/Protocol/ARK/Protostar.hs b/src/ZkFold/Base/Protocol/ARK/Protostar.hs index 456648f7d..679245ce3 100644 --- a/src/ZkFold/Base/Protocol/ARK/Protostar.hs +++ b/src/ZkFold/Base/Protocol/ARK/Protostar.hs @@ -42,7 +42,7 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal data RecursiveCircuit n a = RecursiveCircuit { iterations :: Natural - , circuit :: ArithmeticCircuit a (Vector n) + , circuit :: ArithmeticCircuit a (Vector n) (Vector n) } deriving (Generic, NFData) instance Arithmetic a => SpecialSoundProtocol a (RecursiveCircuit n a) where diff --git a/src/ZkFold/Symbolic/Compiler.hs b/src/ZkFold/Symbolic/Compiler.hs index 9d41151c8..02ad4b119 100644 --- a/src/ZkFold/Symbolic/Compiler.hs +++ b/src/ZkFold/Symbolic/Compiler.hs @@ -17,7 +17,7 @@ import Prelude (FilePat import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number -import ZkFold.Base.Data.Vector (Vector, unsafeToVector) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Prelude (writeFileJSON) import ZkFold.Symbolic.Class (Arithmetic) import ZkFold.Symbolic.Compiler.ArithmeticCircuit @@ -40,7 +40,7 @@ solder :: forall a c f . ( Eq a , MultiplicativeMonoid a - , c ~ ArithmeticCircuit a + , c ~ ArithmeticCircuit a (Vector (TypeSize c (Support c f))) , SymbolicData c f , SymbolicData c (Support c f) , Support c (Support c f) ~ () @@ -48,13 +48,12 @@ solder :: ) => f -> c (Vector (TypeSize c f)) solder f = pieces f (restore @c @(Support c f) $ const inputC) where - inputList = [1..(typeSize @c @(Support c f))] - inputC = mempty { acInput = inputList, acOutput = unsafeToVector inputList } + inputC = mempty { acOutput = acInput } -- | Compiles function `f` into an arithmetic circuit with all outputs equal to 1. compileForceOne :: forall a c f y . - ( c ~ ArithmeticCircuit a + ( c ~ ArithmeticCircuit a (Vector (TypeSize c (Support c f))) , Arithmetic a , SymbolicData c f , SymbolicData c (Support c f) @@ -71,7 +70,7 @@ compile :: forall a c f y . ( Eq a , MultiplicativeMonoid a - , c ~ ArithmeticCircuit a + , c ~ ArithmeticCircuit a (Vector (TypeSize c (Support c f))) , SymbolicData c f , SymbolicData c (Support c f) , Support c (Support c f) ~ () @@ -87,7 +86,7 @@ compileIO :: forall a c f . ( Eq a , MultiplicativeMonoid a - , c ~ ArithmeticCircuit a + , c ~ ArithmeticCircuit a (Vector (TypeSize c (Support c f))) , ToJSON a , SymbolicData c f , SymbolicData c (Support c f) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index 219d2621d..86e8b5fb1 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -1,11 +1,11 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE TypeOperators #-} module ZkFold.Symbolic.Compiler.ArithmeticCircuit ( ArithmeticCircuit, Constraint, witnessGenerator, -- high-level functions - applyArgs, optimize, desugarRanges, -- low-level functions @@ -32,14 +32,16 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit ( ) where import Control.Monad.State (execState) +import Data.Functor.Rep (Representable (..)) import Data.Map hiding (drop, foldl, foldr, map, null, splitAt, take) +import Data.Void (absurd) import GHC.Generics (U1 (..)) import Numeric.Natural (Natural) import Prelude hiding (Num (..), drop, length, product, splitAt, sum, take, (!!), (^)) -import Test.QuickCheck (Arbitrary, Property, conjoin, property, vector, - withMaxSuccess, (===)) +import Test.QuickCheck (Arbitrary, Property, conjoin, property, + withMaxSuccess, (===), arbitrary) import Text.Pretty.Simple (pPrint) import ZkFold.Base.Algebra.Basic.Class @@ -48,52 +50,50 @@ import ZkFold.Prelude (length) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (desugarRange) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance () import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint, - apply, eval, eval1, exec, exec1, - witnessGenerator) + eval, eval1, exec, exec1, + witnessGenerator, Var (..), acInput) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map --------------------------------- High-level functions -------------------------------- -- TODO: make this work for different input types. -applyArgs :: ArithmeticCircuit a f -> [a] -> ArithmeticCircuit a f -applyArgs r args = - (execState (apply args) r {acOutput = U1}) {acOutput = acOutput r} +-- applyArgs :: ArithmeticCircuit a (i :*: j) o -> i a -> ArithmeticCircuit a j o +-- applyArgs r args = (apply args r{acOutput = U1}) {acOutput = fmap _ (acOutput r)} -- | Optimizes the constraint system. -- -- TODO: Implement nontrivial optimizations. -optimize :: ArithmeticCircuit a f -> ArithmeticCircuit a f +optimize :: ArithmeticCircuit a i o -> ArithmeticCircuit a i o optimize = id -- | Desugars range constraints into polynomial constraints -desugarRanges :: Arithmetic a => ArithmeticCircuit a f -> ArithmeticCircuit a f +desugarRanges :: (Arithmetic a, Ord (Rep i), Representable i) => ArithmeticCircuit a i o -> ArithmeticCircuit a i o desugarRanges c = - let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ toList (acRange c) + let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ [(NewVar k, v) | (k,v) <- toList (acRange c)] in r' { acRange = mempty, acOutput = acOutput c } ----------------------------------- Information ----------------------------------- -- | Calculates the number of constraints in the system. -acSizeN :: ArithmeticCircuit a f -> Natural +acSizeN :: ArithmeticCircuit a i o -> Natural acSizeN = length . acSystem -- | Calculates the number of variables in the system. -- The constant `1` is not counted. -acSizeM :: ArithmeticCircuit a f -> Natural +acSizeM :: ArithmeticCircuit a i o -> Natural acSizeM = length . acVarOrder -acValue :: Functor f => ArithmeticCircuit a f -> f a -acValue r = eval r mempty +acValue :: Functor o => ArithmeticCircuit a U1 o -> o a +acValue r = eval r U1 -- | Prints the constraint system, the witness, and the output. -- -- TODO: Move this elsewhere (?) -- TODO: Check that all arguments have been applied. -acPrint :: (Show a, Show (f Natural), Show (f a), Functor f) => ArithmeticCircuit a f -> IO () +acPrint :: (Show a, Show (o (Var U1)), Show (o a), Functor o) => ArithmeticCircuit a U1 o -> IO () acPrint ac = do let m = elems (acSystem ac) - i = acInput ac - w = witnessGenerator ac empty + w = witnessGenerator ac U1 v = acValue ac vo = acVarOrder ac o = acOutput ac @@ -103,8 +103,6 @@ acPrint ac = do pPrint $ acSizeM ac putStr "Matrices: " pPrint m - putStr "Input: " - pPrint i putStr "Witness: " pPrint w putStr "Variable order: " @@ -121,23 +119,29 @@ checkClosedCircuit . Arithmetic a => Scale a a => Show a - => ArithmeticCircuit a n + => ArithmeticCircuit a U1 n -> Property checkClosedCircuit c = withMaxSuccess 1 $ conjoin [ testPoly p | p <- elems (acSystem c) ] where - w = witnessGenerator c empty - testPoly p = evalPolynomial evalMonomial (w !) p === zero + w = witnessGenerator c U1 + testPoly p = evalPolynomial evalMonomial varF p === zero + varF (NewVar v) = w ! v + varF (InVar v) = absurd v checkCircuit - :: Arbitrary a + :: Arbitrary (i a) => Arithmetic a => Scale a a => Show a - => ArithmeticCircuit a n + => Representable i + => ArithmeticCircuit a i n -> Property checkCircuit c = conjoin [ property (testPoly p) | p <- elems (acSystem c) ] where testPoly p = do - ins <- vector . fromIntegral $ length (acInput c) - let w = witnessGenerator c . fromList $ zip (acInput c) ins - return $ evalPolynomial evalMonomial (w !) p === zero + ins <- arbitrary + let w = witnessGenerator c ins + varF (NewVar v) = w ! v + varF (InVar v) = index ins v + return $ evalPolynomial evalMonomial varF p === zero + diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs index 9e985e6ff..928b2ecaf 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs @@ -24,14 +24,15 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators ( import Control.Monad (foldM, replicateM) import Data.Containers.ListUtils (nubOrd) import Data.Eq ((==)) -import Data.Foldable (foldlM) +import Data.Foldable (foldlM, toList) import Data.Functor (($>)) +import Data.Functor.Rep (Representable (..)) import Data.List (sort) import Data.Map (elems) import Data.Traversable (for) import qualified Data.Zip as Z import GHC.Generics (Par1) -import GHC.IsList (IsList (..)) +-- import GHC.IsList (IsList (..)) import Prelude hiding (Bool, Eq (..), drop, length, negate, splitAt, take, (!!), (*), (+), (-), (^)) @@ -41,17 +42,17 @@ import ZkFold.Base.Algebra.Polynomials.Multivariate (vari import qualified ZkFold.Base.Data.Vector as V import ZkFold.Base.Data.Vector (Vector (..)) import ZkFold.Prelude (drop, length, take, (!!)) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (ArithmeticCircuit (..), acInput) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (ArithmeticCircuit (..), acInput, Var (..)) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint import ZkFold.Symbolic.MonadCircuit -boolCheckC :: (Arithmetic a, Traversable f) => ArithmeticCircuit a f -> ArithmeticCircuit a f +boolCheckC :: (Arithmetic a, Traversable f, Ord (Rep i), Representable i) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f -- ^ @boolCheckC r@ computes @r (r - 1)@ in one PLONK constraint. boolCheckC r = circuitF $ do is <- runCircuit r for is $ \i -> newAssigned (\x -> let xi = x i in xi * (xi - one)) -foldCircuit :: forall n a. Arithmetic a => (forall i m . MonadBlueprint i a m => i -> i -> m i) -> ArithmeticCircuit a (Vector n) -> ArithmeticCircuit a Par1 +foldCircuit :: forall n i a. (Arithmetic a, Ord (Rep i), Representable i) => (forall v m . MonadBlueprint i v a m => v -> v -> m v) -> ArithmeticCircuit a i (Vector n) -> ArithmeticCircuit a i Par1 foldCircuit f c = circuit $ do outputs <- runCircuit c let (element, rest) = V.uncons outputs @@ -59,19 +60,19 @@ foldCircuit f c = circuit $ do -- | TODO: Think about circuits with multiple outputs -- -embed :: Arithmetic a => a -> ArithmeticCircuit a Par1 +embed :: (Arithmetic a, Ord (Rep i), Representable i) => a -> ArithmeticCircuit a i Par1 embed x = circuit $ newAssigned $ const (fromConstant x) -embedV :: (Arithmetic a, Traversable f) => f a -> ArithmeticCircuit a f +embedV :: (Arithmetic a, Traversable f, Ord (Rep i), Representable i) => f a -> ArithmeticCircuit a i f embedV v = circuitF $ for v $ \x -> newAssigned $ const (fromConstant x) -embedVar :: forall a . a -> (forall i m . MonadBlueprint i a m => m i) +embedVar :: forall a . a -> (forall i v m . MonadBlueprint i v a m => m v) embedVar x = newAssigned $ const (fromConstant x) -embedAll :: forall a n . (Arithmetic a, KnownNat n) => a -> ArithmeticCircuit a (Vector n) +embedAll :: forall a i n . (Arithmetic a, KnownNat n, Ord (Rep i), Representable i) => a -> ArithmeticCircuit a i (Vector n) embedAll x = circuitF $ Vector <$> replicateM (fromIntegral $ value @n) (newAssigned $ const (fromConstant x)) -expansion :: MonadCircuit i a m => Natural -> i -> m [i] +expansion :: MonadCircuit v a m => Natural -> v -> m [v] -- ^ @expansion n k@ computes a binary expansion of @k@ if it fits in @n@ bits. expansion n k = do bits <- bitsOf n k @@ -79,7 +80,7 @@ expansion n k = do constraint (\x -> x k - x k') return bits -splitExpansion :: (MonadCircuit i a m, Arithmetic a) => Natural -> Natural -> i -> m (i, i) +splitExpansion :: (MonadCircuit v a m, Arithmetic a) => Natural -> Natural -> v -> m (v, v) -- ^ @splitExpansion n1 n2 k@ computes two values @(l, h)@ such that -- @k = 2^n1 h + l@, @l@ fits in @n1@ bits and @h@ fits in n2 bits (if such -- values exist). @@ -93,7 +94,7 @@ splitExpansion n1 n2 k = do repr :: forall b . (BinaryExpansion b, Bits b ~ [b]) => b -> [b] repr = padBits (n1 + n2) . binaryExpansion -bitsOf :: MonadCircuit i a m => Natural -> i -> m [i] +bitsOf :: MonadCircuit v a m => Natural -> v -> m [v] -- ^ @bitsOf n k@ creates @n@ bits and sets their witnesses equal to @n@ smaller -- bits of @k@. bitsOf n k = for [0 .. n -! 1] $ \j -> @@ -102,14 +103,14 @@ bitsOf n k = for [0 .. n -! 1] $ \j -> repr :: forall b . (BinaryExpansion b, Bits b ~ [b], Finite b) => b -> [b] repr = padBits (numberOfBits @b) . binaryExpansion -horner :: MonadCircuit i a m => [i] -> m i +horner :: MonadCircuit v a m => [v] -> m v -- ^ @horner [b0,...,bn]@ computes the sum @b0 + 2 b1 + ... + 2^n bn@ using -- Horner's scheme. horner xs = case reverse xs of [] -> newAssigned (const zero) (b : bs) -> foldlM (\a i -> newAssigned (\x -> let xa = x a in x i + xa + xa)) b bs -desugarRange :: (Arithmetic a, MonadBlueprint i a m) => i -> a -> m () +desugarRange :: (Arithmetic a, MonadBlueprint i v a m) => v -> a -> m () desugarRange i b | b == negate one = return () | otherwise = do @@ -125,28 +126,28 @@ desugarRange i b | c == zero = ($ j) * (one - ($ k)) | otherwise = one + ($ k) * (($ j) - one) -forceOne :: (Arithmetic a, Traversable f) => ArithmeticCircuit a f -> ArithmeticCircuit a f +forceOne :: (Arithmetic a, Traversable f, Ord (Rep i), Representable i) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f forceOne r = circuitF $ do is' <- runCircuit r for is' $ \i -> constraint (\x -> x i - one) $> i -isZeroC :: (Arithmetic a, Z.Zip f, Traversable f) => ArithmeticCircuit a f -> ArithmeticCircuit a f +isZeroC :: (Arithmetic a, Z.Zip f, Traversable f, Ord (Rep i), Representable i) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f isZeroC r = circuitF $ runCircuit r >>= fmap fst . runInvert -invertC :: (Arithmetic a, Z.Zip f, Traversable f) => ArithmeticCircuit a f -> ArithmeticCircuit a f +invertC :: (Arithmetic a, Z.Zip f, Traversable f, Ord (Rep i), Representable i) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f invertC r = circuitF $ runCircuit r >>= fmap snd . runInvert -runInvert :: (MonadCircuit i a m, Z.Zip f, Traversable f) => f i -> m (f i, f i) +runInvert :: (MonadCircuit v a m, Z.Zip f, Traversable f) => f v -> m (f v, f v) runInvert is = do js <- for is $ \i -> newConstrained (\x j -> x i * x j) (\x -> let xi = x i in one - xi // xi) ks <- for (Z.zip is js) $ \(i, j) -> newConstrained (\x k -> x i * x k + x j - one) (finv . ($ i)) return (js, ks) -embedVarIndex :: Arithmetic a => Natural -> ArithmeticCircuit a Par1 -embedVarIndex n = mempty { acInput = [ n ], acOutput = pure n} +embedVarIndex :: Arithmetic a => Rep i -> ArithmeticCircuit a i Par1 +embedVarIndex n = mempty { acOutput = pure (InVar n)} -embedVarIndexV :: (Arithmetic a, KnownNat n) => Natural -> ArithmeticCircuit a (Vector n) -embedVarIndexV n = mempty { acInput = [ n ], acOutput = pure n} +embedVarIndexV :: (Arithmetic a, KnownNat n) => Rep i -> ArithmeticCircuit a i (Vector n) +embedVarIndexV n = mempty { acOutput = pure (InVar n)} -getAllVars :: MultiplicativeMonoid a => ArithmeticCircuit a o -> [Natural] -getAllVars ac = nubOrd $ sort $ 0 : acInput ac ++ concatMap (toList . variables) (elems $ acSystem ac) +getAllVars :: (MultiplicativeMonoid a, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuit a i o -> [Var i] +getAllVars ac = nubOrd $ sort $ NewVar 0 : toList acInput ++ concatMap (toList . variables) (elems $ acSystem ac) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs index e97b60a12..124897b16 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs @@ -7,15 +7,15 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance where import Data.Aeson hiding (Bool) +import Data.Functor.Rep (Representable (..)) import Data.Map hiding (drop, foldl, foldl', foldr, map, null, splitAt, take) import GHC.Generics (Par1 (..)) -import GHC.Num (integerToNatural) import Prelude (Show, mempty, pure, return, show, ($), (++), (<$>)) import qualified Prelude as Haskell import System.Random (mkStdGen) -import Test.QuickCheck (Arbitrary (arbitrary), Gen, chooseInteger, +import Test.QuickCheck (Arbitrary (arbitrary), Gen, elements) import ZkFold.Base.Algebra.Basic.Class @@ -27,13 +27,13 @@ import ZkFold.Symbolic.Data.FieldElement (FieldEl ------------------------------------- Instances ------------------------------------- -instance (Arithmetic a, Arbitrary a) => Arbitrary (ArithmeticCircuit a Par1) where +instance (Arithmetic a, Arbitrary a, Arbitrary (Rep i), Haskell.Ord (Rep i), Representable i, Haskell.Foldable i) => Arbitrary (ArithmeticCircuit a i Par1) where arbitrary = do - k <- integerToNatural <$> chooseInteger (2, 10) - let ac = mempty { acInput = [1..k], acOutput = pure k } + outVar <- InVar <$> arbitrary + let ac = mempty {acOutput = Par1 outVar} fromFieldElement <$> arbitrary' (FieldElement ac) 10 -arbitrary' :: forall a . (Arithmetic a, Arbitrary a, FromConstant a a) => FieldElement (ArithmeticCircuit a) -> Natural -> Gen (FieldElement (ArithmeticCircuit a)) +arbitrary' :: forall a i . (Arithmetic a, Arbitrary a, FromConstant a a, Haskell.Ord (Rep i), Representable i, Haskell.Foldable i) => FieldElement (ArithmeticCircuit a i) -> Natural -> Gen (FieldElement (ArithmeticCircuit a i)) arbitrary' ac 0 = return ac arbitrary' ac iter = do let vars = getAllVars (fromFieldElement ac) @@ -50,28 +50,25 @@ arbitrary' ac iter = do arbitrary' ac' (iter -! 1) -- TODO: make it more readable -instance (FiniteField a, Haskell.Eq a, Show a, Show (f Natural)) => Show (ArithmeticCircuit a f) where - show r = "ArithmeticCircuit { acInput = " ++ show (acInput r) - ++ "\n, acSystem = " ++ show (acSystem r) ++ "\n, acOutput = " ++ show (acOutput r) ++ "\n, acVarOrder = " ++ show (acVarOrder r) ++ " }" +instance (FiniteField a, Haskell.Eq a, Show a, Show (o (Var i)), Haskell.Ord (Rep i), Show (Var i)) => Show (ArithmeticCircuit a i o) where + show r = "ArithmeticCircuit { acSystem = " ++ show (acSystem r) ++ "\n, acOutput = " ++ show (acOutput r) ++ "\n, acVarOrder = " ++ show (acVarOrder r) ++ " }" -- TODO: add witness generation info to the JSON object -instance (ToJSON a, ToJSON (f Natural)) => ToJSON (ArithmeticCircuit a f) where +instance (ToJSON a, ToJSON (o (Var i)), ToJSONKey (Var i), FromJSONKey (Var i)) => ToJSON (ArithmeticCircuit a i o) where toJSON r = object [ "system" .= acSystem r, - "input" .= acInput r, "output" .= acOutput r, "order" .= acVarOrder r ] -- TODO: properly restore the witness generation function -- TODO: Check that there are exactly N outputs -instance (FromJSON a, FromJSON (f Natural)) => FromJSON (ArithmeticCircuit a f) where +instance (FromJSON a, FromJSON (o (Var i)), ToJSONKey (Var i), FromJSONKey (Var i), Haskell.Ord (Rep i)) => FromJSON (ArithmeticCircuit a i o) where parseJSON = withObject "ArithmeticCircuit" $ \v -> do acSystem <- v .: "system" acRange <- v .: "range" - acInput <- v .: "input" acVarOrder <- v .: "order" acOutput <- v .: "output" let acWitness = empty diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index 2da40d38e..d2bb650f3 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -1,10 +1,13 @@ {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal ( ArithmeticCircuit(..), + Var (..), + acInput, Arithmetic, ConstraintMonomial, Constraint, @@ -24,12 +27,14 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal ( import Control.DeepSeq (NFData, force) import Control.Monad.State (MonadState (..), State, gets, modify, runState) +import Data.Aeson (ToJSON, ToJSONKey, FromJSON, FromJSONKey) import Data.Foldable (fold) +import Data.Functor.Rep (Representable (..), fmapRep) import Data.Map.Strict hiding (drop, foldl, foldr, map, null, splitAt, take) import qualified Data.Map.Strict as M import Data.Semialign (unzipDefault) import qualified Data.Set as S -import GHC.Generics (Generic, Par1 (..), U1 (..)) +import GHC.Generics (Generic, Par1 (..), U1 (..), (:*:) (..)) import Optics import Prelude hiding (Num (..), drop, length, product, splitAt, sum, take, (!!), (^)) @@ -45,63 +50,77 @@ import ZkFold.Base.Algebra.Polynomials.Multivariate (Mono, Poly, evalM import ZkFold.Base.Control.HApplicative import ZkFold.Base.Data.HFunctor import ZkFold.Base.Data.Package -import ZkFold.Prelude (drop, length) import ZkFold.Symbolic.Class import ZkFold.Symbolic.MonadCircuit -- | Arithmetic circuit in the form of a system of polynomial constraints. -data ArithmeticCircuit a o = ArithmeticCircuit +data ArithmeticCircuit a i o = ArithmeticCircuit { - acSystem :: Map Natural (Constraint a), + acSystem :: Map Natural (Constraint a i), -- ^ The system of polynomial constraints acRange :: Map Natural a, -- ^ The range constraints [0, a] for the selected variables - acInput :: [Natural], - -- ^ The input variables - acWitness :: Map Natural (Map Natural a -> a), + acWitness :: Map Natural (i a -> a), -- ^ The witness generation functions acVarOrder :: Map (Natural, Natural) Natural, -- ^ The order of variable assignments acRNG :: StdGen, -- ^ random generator for generating unique variables - acOutput :: o Natural + acOutput :: o (Var i) -- ^ The output variables } deriving (Generic) -deriving instance (NFData a, NFData (o Natural)) - => NFData (ArithmeticCircuit a o) - -witnessGenerator :: ArithmeticCircuit a o -> Map Natural a -> Map Natural a +deriving instance (NFData a, NFData (o (Var i)), NFData (Rep i)) + => NFData (ArithmeticCircuit a i o) + +acInput :: Representable i => i (Var i) +acInput = fmapRep InVar (tabulate id) + +data Var i + = InVar (Rep i) + | NewVar Natural + deriving Generic +deriving anyclass instance FromJSON (Rep i) => FromJSON (Var i) +deriving anyclass instance FromJSON (Rep i) => FromJSONKey (Var i) +deriving anyclass instance ToJSON (Rep i) => ToJSONKey (Var i) +deriving anyclass instance ToJSON (Rep i) => ToJSON (Var i) +deriving stock instance Show (Rep i) => Show (Var i) +deriving stock instance Eq (Rep i) => Eq (Var i) +deriving stock instance Ord (Rep i) => Ord (Var i) +deriving instance NFData (Rep i) => NFData (Var i) + +witnessGenerator :: ArithmeticCircuit a i o -> i a -> Map Natural a witnessGenerator circuit inputs = - let srcs = acWitness circuit - witness = ($ witness) <$> (srcs `union` fmap const inputs) - in witness + fmap ($ inputs) (acWitness circuit) +-- let srcs = acWitness circuit +-- witness = ($ witness) <$> (srcs `union` fmap const inputs) +-- in witness ------------------------------ Symbolic compiler context ---------------------------- -crown :: ArithmeticCircuit a g -> f Natural -> ArithmeticCircuit a f +crown :: ArithmeticCircuit a i g -> f (Var i) -> ArithmeticCircuit a i f crown = flip (set #acOutput) -behead :: ArithmeticCircuit a f -> (ArithmeticCircuit a U1, f Natural) +behead :: ArithmeticCircuit a i f -> (ArithmeticCircuit a i U1, f (Var i)) behead = liftA2 (,) (set #acOutput U1) acOutput -instance HFunctor (ArithmeticCircuit a) where +instance HFunctor (ArithmeticCircuit a i) where hmap = over #acOutput -instance (Eq a, MultiplicativeMonoid a) => HApplicative (ArithmeticCircuit a) where +instance (Eq a, MultiplicativeMonoid a) => HApplicative (ArithmeticCircuit a i) where hpure = crown mempty hliftA2 f (behead -> (c, o)) (behead -> (d, p)) = crown (c <> d) (f o p) -instance (Eq a, MultiplicativeMonoid a) => Package (ArithmeticCircuit a) where +instance (Eq a, MultiplicativeMonoid a) => Package (ArithmeticCircuit a i) where unpackWith f (behead -> (c, o)) = crown c <$> f o packWith f (unzipDefault . fmap behead -> (cs, os)) = crown (fold cs) (f os) -instance Arithmetic a => Symbolic (ArithmeticCircuit a) where - type BaseField (ArithmeticCircuit a) = a +instance (Arithmetic a, Ord (Rep i), Representable i) => Symbolic (ArithmeticCircuit a i) where + type BaseField (ArithmeticCircuit a i) = a symbolicF (behead -> (c, o)) _ f = uncurry (set #acOutput) (runState (f o) c) -------------------------------- MonadCircuit instance ------------------------------ -instance (Arithmetic a, o ~ U1) => MonadCircuit Natural a (State (ArithmeticCircuit a o)) where +instance (Arithmetic a, Ord (Rep i), Representable i, o ~ U1) => MonadCircuit (Var i) a (State (ArithmeticCircuit a i o)) where newRanged upperBound witness = do let s = sources @a witness b = fromConstant upperBound @@ -111,23 +130,31 @@ instance (Arithmetic a, o ~ U1) => MonadCircuit Natural a (State (ArithmeticCirc p i = b * var i * (var i - b) i <- addVariable =<< newVariableWithSource (S.toList s) p rangeConstraint i upperBound - assignment i (\m -> witness (m !)) - return i + currentWitness <- gets acWitness + assignment i $ \m -> witness $ \case + InVar inV -> index m inV + NewVar newV -> (currentWitness ! newV) m + return (NewVar i) newConstrained - :: NewConstraint Natural a - -> Witness Natural a - -> State (ArithmeticCircuit a U1) Natural + :: NewConstraint (Var i) a + -> Witness (Var i) a + -> State (ArithmeticCircuit a i U1) (Var i) newConstrained new witness = do let ws = sources @a witness + varF (NewVar v) = NewVar (v + 1) + varF (InVar v) = InVar v -- | We need a throwaway variable to feed into `new` which definitely would not be present in a witness - x = maximum (S.mapMonotonic (+1) ws <> S.singleton 0) + x = maximum (S.mapMonotonic varF ws <> S.singleton (NewVar 0)) -- | `s` is meant to be a set of variables used in a witness not present in a constraint. s = ws `S.difference` sources @a (`new` x) i <- addVariable =<< newVariableWithSource (S.toList s) (new var) - constraint (`new` i) - assignment i (\m -> witness (m !)) - return i + constraint (`new` (NewVar i)) + currentWitness <- gets acWitness + assignment i $ \m -> witness $ \case + InVar inV -> index m inV + NewVar newV -> (currentWitness ! newV) m + return (NewVar i) constraint p = addConstraint (p var) @@ -136,7 +163,7 @@ sources = runSources . ($ Sources @a . S.singleton) ----------------------------------- Circuit monoid ---------------------------------- -instance (Eq a, o ~ U1) => Semigroup (ArithmeticCircuit a o) where +instance (Eq a, o ~ U1) => Semigroup (ArithmeticCircuit a i o) where c1 <> c2 = ArithmeticCircuit { @@ -144,25 +171,18 @@ instance (Eq a, o ~ U1) => Semigroup (ArithmeticCircuit a o) where , acRange = acRange c1 `union` acRange c2 -- NOTE: is it possible that we get a wrong argument order when doing `apply` because of this concatenation? -- We need a way to ensure the correct order no matter how `(<>)` is used. - , acInput = nubConcat (acInput c1) (acInput c2) , acWitness = acWitness c1 `union` acWitness c2 , acVarOrder = acVarOrder c1 `union` acVarOrder c2 , acRNG = mkStdGen $ fst (uniform (acRNG c1)) Haskell.* fst (uniform (acRNG c2)) , acOutput = U1 } -nubConcat :: Ord a => [a] -> [a] -> [a] -nubConcat l r = l ++ Prelude.filter (`S.notMember` lSet) r - where - lSet = S.fromList l - -instance (Eq a, MultiplicativeMonoid a, o ~ U1) => Monoid (ArithmeticCircuit a o) where +instance (Eq a, MultiplicativeMonoid a, o ~ U1) => Monoid (ArithmeticCircuit a i o) where mempty = ArithmeticCircuit { acSystem = empty, acRange = empty, - acInput = [], acWitness = singleton 0 one, acVarOrder = empty, acRNG = mkStdGen 0, @@ -179,20 +199,22 @@ toField :: Arithmetic a => a -> VarField toField = toZp . fromConstant . fromBinary @Natural . castBits . binaryExpansion -- TODO: Remove the hardcoded constant. -toVar :: Arithmetic a => [Natural] -> Constraint a -> Natural +toVar :: Arithmetic a => [Var i] -> Constraint a i -> Natural toVar srcs c = force $ fromZp ex where r = toZp 903489679376934896793395274328947923579382759823 :: VarField g = toZp 89175291725091202781479751781509570912743212325 :: VarField - v = (+ r) . fromConstant + varF (NewVar w) = w + varF (InVar _) = 0 + v = (+ r) . fromConstant . varF x = g ^ fromZp (evalPolynomial evalMonomial v $ mapCoeffs toField c) - ex = foldr (\p y -> x ^ p + y) x srcs + ex = foldr (\p y -> x ^ (varF p) + y) x srcs -newVariableWithSource :: Arithmetic a => [Natural] -> (Natural -> Constraint a) -> State (ArithmeticCircuit a U1) Natural -newVariableWithSource srcs con = toVar srcs . con . fst <$> do +newVariableWithSource :: Arithmetic a => [Var i] -> (Var i -> Constraint a i) -> State (ArithmeticCircuit a i U1) Natural +newVariableWithSource srcs con = toVar srcs . con . NewVar . fst <$> do zoom #acRNG $ get >>= traverse put . uniformR (0, order @VarField -! 1) -addVariable :: Natural -> State (ArithmeticCircuit a U1) Natural +addVariable :: Natural -> State (ArithmeticCircuit a i U1) Natural addVariable x = do zoom #acVarOrder . modify $ \vo -> insert (Haskell.fromIntegral $ M.size vo, x) x vo @@ -203,43 +225,56 @@ addVariable x = do type ConstraintMonomial = Mono Natural Natural -- | The type that represents a constraint in the arithmetic circuit. -type Constraint c = Poly c Natural Natural +type Constraint c i = Poly c (Var i) Natural -- | Adds a constraint to the arithmetic circuit. -addConstraint :: Arithmetic a => Constraint a -> State (ArithmeticCircuit a U1) () +addConstraint :: Arithmetic a => Constraint a i -> State (ArithmeticCircuit a i U1) () addConstraint c = zoom #acSystem . modify $ insert (toVar [] c) c -rangeConstraint :: Natural -> a -> State (ArithmeticCircuit a U1) () +rangeConstraint :: Natural -> a -> State (ArithmeticCircuit a i U1) () rangeConstraint i b = zoom #acRange . modify $ insert i b -- | Adds a new variable assignment to the arithmetic circuit. -- TODO: forbid reassignment of variables -assignment :: Natural -> (Map Natural a -> a) -> State (ArithmeticCircuit a U1) () +assignment :: Natural -> (i a -> a) -> State (ArithmeticCircuit a i U1) () assignment i f = zoom #acWitness . modify $ insert i f -- | Evaluates the arithmetic circuit with one output using the supplied input map. -eval1 :: ArithmeticCircuit a Par1 -> Map Natural a -> a -eval1 ctx i = witnessGenerator ctx i ! unPar1 (acOutput ctx) +eval1 :: Representable i => ArithmeticCircuit a i Par1 -> i a -> a +eval1 ctx i = case unPar1 (acOutput ctx) of + NewVar k -> witnessGenerator ctx i ! k + InVar j -> index i j -- | Evaluates the arithmetic circuit using the supplied input map. -eval :: Functor o => ArithmeticCircuit a o -> Map Natural a -> o a -eval ctx i = (witnessGenerator ctx i !) <$> acOutput ctx +eval :: (Representable i, Functor o) => ArithmeticCircuit a i o -> i a -> o a +eval ctx i = acOutput ctx <&> \case + NewVar k -> witnessGenerator ctx i ! k + InVar j -> index i j -- | Evaluates the arithmetic circuit with no inputs and one output using the supplied input map. -exec1 :: ArithmeticCircuit a Par1 -> a -exec1 ac = eval1 ac empty +exec1 :: ArithmeticCircuit a U1 Par1 -> a +exec1 ac = eval1 ac U1 -- | Evaluates the arithmetic circuit with no inputs using the supplied input map. -exec :: Functor o => ArithmeticCircuit a o -> o a -exec ac = eval ac empty +exec :: Functor o => ArithmeticCircuit a U1 o -> o a +exec ac = eval ac U1 -- | Applies the values of the first `n` inputs to the arithmetic circuit. -- TODO: make this safe -apply :: [a] -> State (ArithmeticCircuit a U1) () -apply xs = do - inputs <- gets acInput - zoom #acInput . put $ drop (length xs) inputs - zoom #acWitness . modify . union . fromList $ zip inputs (map const xs) +apply :: (Eq a, Field a, Ord (Rep j), Scale a a, FromConstant a a, Representable i) => i a -> ArithmeticCircuit a (i :*: j) U1 -> ArithmeticCircuit a j U1 +apply xs ac = ac + { acSystem = fmap (evalPolynomial evalMonomial varF) (acSystem ac) + , acWitness = fmap witF (acWitness ac) + , acOutput = U1 + } + where + varF (InVar (Left v)) = fromConstant (index xs v) + varF (InVar (Right v)) = var (InVar v) + varF (NewVar v) = var (NewVar v) + witF f j = f (xs :*: j) + + -- let inputs = acInput + -- zoom #acWitness . modify . union . fromList $ zip inputs (map const xs) -- TODO: Add proper symbolic application functions diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs index b789cc352..dabb71d17 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs @@ -6,58 +6,55 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map ( ArithmeticCircuitTest(..) ) where +import Data.Traversable (for) +import Data.Functor.Rep (Representable (..)) import Data.Map hiding (drop, foldl, foldr, fromList, map, null, splitAt, take, toList) import qualified Data.Map as Map import GHC.Generics (Par1) import GHC.IsList (IsList (..)) -import GHC.Natural (naturalToInteger) -import GHC.Num (integerToInt) -import Numeric.Natural (Natural) import Prelude hiding (Num (..), drop, length, product, splitAt, sum, take, (!!), (^)) -import Test.QuickCheck (Arbitrary (arbitrary), Gen, vector) +import Test.QuickCheck (Arbitrary (arbitrary), Gen) -import ZkFold.Base.Algebra.Basic.Class (MultiplicativeMonoid (..)) +import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Polynomials.Multivariate -import ZkFold.Prelude (length) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (getAllVars) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), acInput, Var (..)) -- This module contains functions for mapping variables in arithmetic circuits. -data ArithmeticCircuitTest a f = ArithmeticCircuitTest +data ArithmeticCircuitTest a i o = ArithmeticCircuitTest { - arithmeticCircuit :: ArithmeticCircuit a f - , witnessInput :: Map.Map Natural a + arithmeticCircuit :: ArithmeticCircuit a i o + , witnessInput :: i a } -instance (Show (ArithmeticCircuit a f), Show a) => Show (ArithmeticCircuitTest a f) where +instance (Show (ArithmeticCircuit a i o), Show a, Show (i a)) => Show (ArithmeticCircuitTest a i o) where show (ArithmeticCircuitTest ac wi) = show ac ++ ",\nwitnessInput: " ++ show wi -instance (Arithmetic a, Arbitrary a, Arbitrary (ArithmeticCircuit a Par1)) => Arbitrary (ArithmeticCircuitTest a Par1) where - arbitrary :: Gen (ArithmeticCircuitTest a Par1) +instance (Arithmetic a, Arbitrary a, Arbitrary (ArithmeticCircuit a i Par1), Traversable i, Representable i) => Arbitrary (ArithmeticCircuitTest a i Par1) where + arbitrary :: Gen (ArithmeticCircuitTest a i Par1) arbitrary = do ac <- arbitrary - let keysAC = acInput ac - values <- vector . integerToInt . naturalToInteger . length $ keysAC - let wi = fromList $ zip keysAC values + wi <- for acInput $ \_ -> arbitrary return ArithmeticCircuitTest { arithmeticCircuit = ac , witnessInput = wi } -mapVarArithmeticCircuit :: (MultiplicativeMonoid a, Functor f) => ArithmeticCircuitTest a f -> ArithmeticCircuitTest a f +mapVarArithmeticCircuit :: (Field a, Scale a a, Eq a, Functor o, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuitTest a i o -> ArithmeticCircuitTest a i o mapVarArithmeticCircuit (ArithmeticCircuitTest ac wi) = - let vars = getAllVars ac + let vars = [v | NewVar v <- getAllVars ac] forward = Map.fromAscList $ zip vars [0..] backward = Map.fromAscList $ zip [0..] vars + varF (InVar v) = InVar v + varF (NewVar v) = NewVar (forward ! v) mappedCircuit = ac { - acSystem = fromList $ zip [0..] $ mapVarPolynomial forward <$> elems (acSystem ac), + acSystem = fromList $ zip [0..] $ evalPolynomial evalMonomial (var . varF) <$> elems (acSystem ac), -- TODO: the new arithmetic circuit expects the old input variables! We should make this safer. - acWitness = (`Map.compose` backward) $ (. (`Map.compose` forward)) <$> acWitness ac + acWitness = (`Map.compose` backward) $ acWitness ac } - mappedOutputs = mapVar forward <$> acOutput ac - wi' = wi `Map.compose` backward - in ArithmeticCircuitTest (mappedCircuit {acOutput = mappedOutputs}) wi' + mappedOutputs = varF <$> acOutput ac + in ArithmeticCircuitTest (mappedCircuit {acOutput = mappedOutputs}) wi diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MonadBlueprint.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MonadBlueprint.hs index 577fd836b..1da89284e 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MonadBlueprint.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MonadBlueprint.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE UndecidableInstances #-} + module ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint ( ClosedPoly, MonadBlueprint (..), @@ -12,35 +14,36 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint ( import Control.Applicative (pure) import Control.Monad.State (State, modify, runState) import Data.Functor (Functor, fmap, ($>), (<$>)) +import Data.Functor.Rep (Representable (..)) import Data.Monoid (mempty, (<>)) +import Data.Ord (Ord) import GHC.Generics (Par1, U1 (..)) -import Numeric.Natural (Natural) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal hiding (constraint) import ZkFold.Symbolic.MonadCircuit -- | A @'MonadCircuit'@ with an added capability -- of embedding another arithmetic circuit inside. -class MonadCircuit i a m => MonadBlueprint i a m where +class MonadCircuit v a m => MonadBlueprint i v a m | m -> i where -- | Adds the supplied circuit to the blueprint and returns its output variable. - runCircuit :: ArithmeticCircuit a f -> m (f i) + runCircuit :: ArithmeticCircuit a i o -> m (o v) -instance Arithmetic a => MonadBlueprint Natural a (State (ArithmeticCircuit a U1)) where +instance (Arithmetic a, Ord (Rep i), Representable i) => MonadBlueprint i (Var i) a (State (ArithmeticCircuit a i U1)) where runCircuit r = modify (<> r {acOutput = U1}) $> acOutput r -circuit :: Arithmetic a => (forall i m . MonadBlueprint i a m => m i) -> ArithmeticCircuit a Par1 +circuit :: (Arithmetic a, Ord (Rep i), Representable i) => (forall v m . MonadBlueprint i v a m => m v) -> ArithmeticCircuit a i Par1 -- ^ Builds a circuit from blueprint. A blueprint is a function which, given an -- arbitrary type of variables @i@ and a monad @m@ supporting the 'MonadBlueprint' -- API, computes the output variable of a future circuit. circuit b = circuitF (pure <$> b) -circuitF :: forall a f . Arithmetic a => (forall i m . MonadBlueprint i a m => m (f i)) -> ArithmeticCircuit a f +circuitF :: forall a i o . (Arithmetic a, Ord (Rep i), Representable i) => (forall v m . MonadBlueprint i v a m => m (o v)) -> ArithmeticCircuit a i o -- TODO: I should really rethink this... circuitF b = let (os, r) = runState b mempty in r { acOutput = os } -- TODO: kept for compatibility with @binaryExpansion@ only. Perhaps remove it in the future? -circuits :: forall a f . (Arithmetic a, Functor f) => (forall i m . MonadBlueprint i a m => m (f i)) -> f (ArithmeticCircuit a Par1) +circuits :: forall a i o . (Arithmetic a, Functor o, Ord (Rep i), Representable i) => (forall v m . MonadBlueprint i v a m => m (o v)) -> o (ArithmeticCircuit a i Par1) -- ^ Builds a collection of circuits from one blueprint. A blueprint is a function -- which, given an arbitrary type of variables @i@ and a monad @m@ supporting the -- 'MonadBlueprint' API, computes the collection of output variables of future circuits. diff --git a/src/ZkFold/Symbolic/Data/ByteString.hs b/src/ZkFold/Symbolic/Data/ByteString.hs index 7c111e8e7..8d0b4d8d1 100644 --- a/src/ZkFold/Symbolic/Data/ByteString.hs +++ b/src/ZkFold/Symbolic/Data/ByteString.hs @@ -22,6 +22,7 @@ import Control.DeepSeq (NFDa import Control.Monad (replicateM) import Data.Bits as B import qualified Data.ByteString as Bytes +import Data.Functor.Rep (Representable (..)) import Data.Kind (Type) import Data.List (foldl, reverse, unfoldr) import Data.Maybe (Maybe (..)) @@ -51,7 +52,7 @@ import ZkFold.Symbolic.Data.Class (Symb import ZkFold.Symbolic.Data.Combinators import ZkFold.Symbolic.Interpreter (Interpreter (..)) import ZkFold.Symbolic.MonadCircuit (Arithmetic, newAssigned) - +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Var) -- | A ByteString which stores @n@ bits and uses elements of @a@ as registers, one element per register. -- Bit layout is Big-endian. @@ -163,7 +164,7 @@ instance (KnownNat n, Finite (Zp p)) => FromConstant Natural (ByteString n (Inte instance (KnownNat n, Finite (Zp p)) => FromConstant Integer (ByteString n (Interpreter (Zp p))) where fromConstant = fromConstant . naturalFromInteger . (`Haskell.mod` (2 ^ getNatural @n)) -instance (FromConstant Natural a, Arithmetic a, KnownNat n) => FromConstant Natural (ByteString n (ArithmeticCircuit a)) where +instance (FromConstant Natural a, Arithmetic a, KnownNat n, Haskell.Ord (Rep i), Representable i) => FromConstant Natural (ByteString n (ArithmeticCircuit a i)) where -- | Pack a ByteString using one field element per bit. -- @fromConstant@ discards bits after @n@. @@ -187,7 +188,7 @@ toBase _ 0 = Nothing toBase base b = let (d, m) = b `divMod` base in Just (m, d) -instance (FromConstant Natural a, Arithmetic a, KnownNat n) => FromConstant Integer (ByteString n (ArithmeticCircuit a)) where +instance (FromConstant Natural a, Arithmetic a, KnownNat n, Haskell.Ord (Rep i), Representable i) => FromConstant Integer (ByteString n (ArithmeticCircuit a i)) where fromConstant = fromConstant . naturalFromInteger . (`Haskell.mod` (2 ^ getNatural @n)) instance (Finite (Zp p), KnownNat n) => Arbitrary (ByteString n (Interpreter (Zp p))) where @@ -350,13 +351,13 @@ instance Finite (Zp p) => BitState ByteString n (Interpreter (Zp p)) where -------------------------------------------------------------------------------- -instance (Arithmetic a, KnownNat n) => ShiftBits (ByteString n (ArithmeticCircuit a)) where +instance (Arithmetic a, KnownNat n, Haskell.Ord (Rep i), Representable i) => ShiftBits (ByteString n (ArithmeticCircuit a i)) where shiftBits bs@(ByteString oldBits) s | s == 0 = bs | Haskell.abs s >= Haskell.fromIntegral (getNatural @n) = false | otherwise = ByteString $ circuitF solve where - solve :: forall i m. MonadBlueprint i a m => m (Vector n i) + solve :: forall v m. MonadBlueprint i v a m => m (Vector n v) solve = do bits <- V.fromVector <$> runCircuit oldBits zeros <- replicateM (Haskell.fromIntegral $ Haskell.abs s) $ newAssigned (Haskell.const zero) @@ -376,7 +377,7 @@ instance ( KnownNat wordSize , (Div n wordSize) * wordSize ~ n , (Div wordSize 8) * 8 ~ wordSize - ) => ReverseEndianness wordSize (ByteString n (ArithmeticCircuit a)) where + ) => ReverseEndianness wordSize (ByteString n (ArithmeticCircuit a i)) where reverseEndianness (ByteString v) = ByteString $ v { acOutput = reverseEndianness' @wordSize (acOutput v) } @@ -384,25 +385,25 @@ instance -- TODO: Shall we expose it to users? Can they do something malicious having such function? AFAIK there are checks that constrain each bit to 0 or 1. -- bitwiseOperation - :: forall a n - . Arithmetic a - => ByteString n (ArithmeticCircuit a) - -> ByteString n (ArithmeticCircuit a) - -> (forall i. i -> i -> ClosedPoly i a) - -> ByteString n (ArithmeticCircuit a) + :: forall a i n + . (Arithmetic a, Haskell.Ord (Rep i), Representable i) + => ByteString n (ArithmeticCircuit a i) + -> ByteString n (ArithmeticCircuit a i) + -> (forall v. v -> v -> ClosedPoly v a) + -> ByteString n (ArithmeticCircuit a i) bitwiseOperation (ByteString bits1) (ByteString bits2) cons = ByteString $ circuitF solve where - solve :: forall i m. MonadBlueprint i a m => m (Vector n i) + solve :: forall v m. MonadBlueprint i v a m => m (Vector n v) solve = do varsLeft <- runCircuit bits1 varsRight <- runCircuit bits2 V.zipWithM applyBitwise varsLeft varsRight - applyBitwise :: forall i m . MonadBlueprint i a m => i -> i -> m i + applyBitwise :: forall v m . MonadBlueprint i v a m => v -> v -> m v applyBitwise l r = newAssigned $ cons l r -instance (Arithmetic a, KnownNat n) => BoolType (ByteString n (ArithmeticCircuit a)) where +instance (Arithmetic a, KnownNat n, Haskell.Ord (Rep i), Representable i) => BoolType (ByteString n (ArithmeticCircuit a i)) where false = ByteString $ embedV (pure zero) true = not false @@ -428,7 +429,7 @@ instance ( KnownNat wordSize , Mod n wordSize ~ 0 , (Div n wordSize) * wordSize ~ n - ) => ToWords (ByteString n (ArithmeticCircuit a)) (ByteString wordSize (ArithmeticCircuit a)) where + ) => ToWords (ByteString n (ArithmeticCircuit a i)) (ByteString wordSize (ArithmeticCircuit a i)) where toWords (ByteString bits) = (\o -> ByteString $ bits { acOutput = o} ) <$> V.fromVector (V.chunks @(Div n wordSize) @wordSize $ acOutput bits) @@ -437,19 +438,19 @@ instance ( Mod k m ~ 0 , (Div k m) * m ~ k , Arithmetic a - ) => Concat (ByteString m (ArithmeticCircuit a)) (ByteString k (ArithmeticCircuit a)) where + ) => Concat (ByteString m (ArithmeticCircuit a i)) (ByteString k (ArithmeticCircuit a i)) where concat bs = ByteString $ bsCircuit {acOutput = bsOutputs} where bsCircuit = Haskell.mconcat $ (\(ByteString bits) -> bits {acOutput = U1}) <$> bs - bsOutputs :: Vector k Natural + bsOutputs :: Vector k (Var i) bsOutputs = V.unsafeConcat @(Div k m) $ (\(ByteString bits) -> acOutput bits) <$> bs instance ( KnownNat n , n <= m - ) => Truncate (ByteString m (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where + ) => Truncate (ByteString m (ArithmeticCircuit a i)) (ByteString n (ArithmeticCircuit a i)) where truncate (ByteString bits) = ByteString $ bits { acOutput = V.take @n (acOutput bits) } @@ -457,12 +458,12 @@ instance ( KnownNat m , KnownNat n , m <= n - , Arithmetic a - ) => Extend (ByteString m (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where + , Arithmetic a, Haskell.Ord (Rep i), Representable i + ) => Extend (ByteString m (ArithmeticCircuit a i)) (ByteString n (ArithmeticCircuit a i)) where extend (ByteString oldBits) = ByteString $ circuitF (Vector <$> solve) where - solve :: forall i m'. MonadBlueprint i a m' => m' [i] + solve :: forall v m'. MonadBlueprint i v a m' => m' [v] solve = do bits <- runCircuit oldBits zeros <- replicateM diff $ newAssigned (Haskell.const zero) @@ -472,14 +473,14 @@ instance diff = Haskell.fromIntegral $ getNatural @n Haskell.- getNatural @m -instance Arithmetic a => BitState ByteString n (ArithmeticCircuit a) where +instance (Arithmetic a, Haskell.Ord (Rep i), Representable i) => BitState ByteString n (ArithmeticCircuit a i) where isSet (ByteString v) ix = Bool $ circuit solve where - solve :: forall i m . MonadBlueprint i a m => m i + solve :: forall v m . MonadBlueprint i v a m => m v solve = (!! ix) . V.fromVector <$> runCircuit v isUnset (ByteString v) ix = Bool $ circuit solve where - solve :: forall i m . MonadBlueprint i a m => m i + solve :: forall v m . MonadBlueprint i v a m => m v solve = do i <- (!! ix) . V.fromVector <$> runCircuit v newAssigned $ \p -> one - p i diff --git a/src/ZkFold/Symbolic/Data/Combinators.hs b/src/ZkFold/Symbolic/Data/Combinators.hs index 72cc7f458..0a8602bb0 100644 --- a/src/ZkFold/Symbolic/Data/Combinators.hs +++ b/src/ZkFold/Symbolic/Data/Combinators.hs @@ -47,12 +47,12 @@ class Shrink a b where -- | Convert an @ArithmeticCircuit@ to bits and return their corresponding variables. -- toBits - :: forall i a m - . MonadBlueprint i a m - => [i] + :: forall i v a m + . MonadBlueprint i v a m + => [v] -> Natural -> Natural - -> m [i] + -> m [v] toBits regs hiBits loBits = do let lows = tail regs high = head regs @@ -69,7 +69,7 @@ fromBits :: forall a . Natural -> Natural - -> (forall i m. MonadBlueprint i a m => [i] -> m [i]) + -> (forall i v m. MonadBlueprint i v a m => [v] -> m [v]) fromBits hiBits loBits bits = do let (bitsHighNew, bitsLowNew) = splitAt (Haskell.fromIntegral hiBits) bits let lowVarsNew = chunksOf (Haskell.fromIntegral loBits) bitsLowNew diff --git a/src/ZkFold/Symbolic/Data/UInt.hs b/src/ZkFold/Symbolic/Data/UInt.hs index 2fc4059a0..49b965a23 100644 --- a/src/ZkFold/Symbolic/Data/UInt.hs +++ b/src/ZkFold/Symbolic/Data/UInt.hs @@ -20,6 +20,7 @@ import Control.DeepSeq import Control.Monad.State (StateT (..)) import Data.Foldable (foldr, foldrM, for_) import Data.Functor ((<$>)) +import Data.Functor.Rep (Representable (..)) import Data.Kind (Type) import Data.List (unfoldr, zip) import Data.Map (fromList, (!)) @@ -77,7 +78,9 @@ instance , Arithmetic a , KnownNat n , KnownRegisterSize r - ) => FromConstant Natural (UInt n r (ArithmeticCircuit a)) where + , Haskell.Ord (Rep i) + , Representable i + ) => FromConstant Natural (UInt n r (ArithmeticCircuit a i)) where fromConstant c = let (lo, hi, _) = cast @a @n @r . (`Haskell.mod` (2 ^ getNatural @n)) $ c in UInt $ embedV $ Vector $ fromConstant <$> (lo <> [hi]) @@ -87,7 +90,9 @@ instance , Arithmetic a , KnownNat n , KnownRegisterSize r - ) => FromConstant Integer (UInt n r (ArithmeticCircuit a)) where + , Haskell.Ord (Rep i) + , Representable i + ) => FromConstant Integer (UInt n r (ArithmeticCircuit a i)) where fromConstant = fromConstant . naturalFromInteger . (`Haskell.mod` (2 ^ getNatural @n)) instance (FromConstant Natural (UInt n r b), KnownNat n, MultiplicativeSemigroup (UInt n r b)) => Scale Natural (UInt n r b) @@ -219,10 +224,12 @@ instance , n <= k , from ~ NumberOfRegisters a n r , to ~ NumberOfRegisters a k r - ) => Extend (UInt n r (ArithmeticCircuit a)) (UInt k r (ArithmeticCircuit a)) where + , Haskell.Ord (Rep i) + , Representable i + ) => Extend (UInt n r (ArithmeticCircuit a i)) (UInt k r (ArithmeticCircuit a i)) where extend (UInt ac) = UInt (circuitF solve) where - solve :: forall i m. MonadBlueprint i a m => m (Vector to i) + solve :: forall v m. MonadBlueprint i v a m => m (Vector to v) solve = do regs <- V.fromVector <$> runCircuit ac bsBits <- toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) @@ -238,10 +245,12 @@ instance , k <= n , from ~ NumberOfRegisters a n r , to ~ NumberOfRegisters a k r - ) => Shrink (UInt n r (ArithmeticCircuit a)) (UInt k r (ArithmeticCircuit a)) where + , Haskell.Ord (Rep i) + , Representable i + ) => Shrink (UInt n r (ArithmeticCircuit a i)) (UInt k r (ArithmeticCircuit a i)) where shrink (UInt ac) = UInt (circuitF solve) where - solve :: forall i m. MonadBlueprint i a m => m (Vector to i) + solve :: forall v m. MonadBlueprint i v a m => m (Vector to v) solve = do regs <- V.fromVector <$> runCircuit ac bsBits <- toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) @@ -286,30 +295,30 @@ instance let rs = force $ addBit (r' + r') (value @n -! i -! 1) in bool @(Bool b) (q', rs) (q' + fromConstant ((2 :: Natural) ^ i), rs - d) (rs >= d) -instance (Arithmetic a, KnownNat n, KnownRegisterSize r, KnownNat (NumberOfRegisters a n r)) => Ord (Bool (ArithmeticCircuit a)) (UInt n r (ArithmeticCircuit a)) where +instance (Arithmetic a, KnownNat n, KnownRegisterSize r, KnownNat (NumberOfRegisters a n r), Haskell.Ord (Rep i), Representable i) => Ord (Bool (ArithmeticCircuit a i)) (UInt n r (ArithmeticCircuit a i)) where x <= y = y >= x x < y = y > x u1 >= u2 = - let ByteString rs1 = from u1 :: ByteString n (ArithmeticCircuit a) - ByteString rs2 = from u2 :: ByteString n (ArithmeticCircuit a) + let ByteString rs1 = from u1 :: ByteString n (ArithmeticCircuit a i) + ByteString rs2 = from u2 :: ByteString n (ArithmeticCircuit a i) in bitwiseGE rs1 rs2 u1 > u2 = - let ByteString rs1 = from u1 :: ByteString n (ArithmeticCircuit a) - ByteString rs2 = from u2 :: ByteString n (ArithmeticCircuit a) + let ByteString rs1 = from u1 :: ByteString n (ArithmeticCircuit a i) + ByteString rs2 = from u2 :: ByteString n (ArithmeticCircuit a i) in bitwiseGT rs1 rs2 - max x y = bool @(Bool (ArithmeticCircuit a)) x y $ x < y + max x y = bool @(Bool (ArithmeticCircuit a i)) x y $ x < y - min x y = bool @(Bool (ArithmeticCircuit a)) x y $ x > y + min x y = bool @(Bool (ArithmeticCircuit a i)) x y $ x > y -instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => AdditiveSemigroup (UInt n r (ArithmeticCircuit a)) where +instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => AdditiveSemigroup (UInt n r (ArithmeticCircuit a i)) where UInt x + UInt y = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i a m => m [i] + solve :: MonadBlueprint i v a m => m [v] solve = do j <- newAssigned (Haskell.const zero) xs <- V.fromVector <$> runCircuit x @@ -329,7 +338,9 @@ instance , KnownNat n , KnownNat (NumberOfRegisters a n r) , KnownRegisterSize r - ) => AdditiveMonoid (UInt n r (ArithmeticCircuit a)) where + , Haskell.Ord (Rep i) + , Representable i + ) => AdditiveMonoid (UInt n r (ArithmeticCircuit a i)) where zero = UInt $ embedV (pure zero) instance @@ -337,14 +348,16 @@ instance , KnownNat n , KnownRegisterSize r , KnownNat (NumberOfRegisters a n r) - ) => AdditiveGroup (UInt n r (ArithmeticCircuit a)) where + , Haskell.Ord (Rep i) + , Representable i + ) => AdditiveGroup (UInt n r (ArithmeticCircuit a i)) where UInt x - UInt y = UInt $ circuitF (V.unsafeToVector <$> solve) where t :: a t = (one + one) ^ registerSize @a @n @r - one - solve :: MonadBlueprint i a m => m [i] + solve :: MonadBlueprint i v a m => m [v] solve = do is <- runCircuit x js <- runCircuit y @@ -355,13 +368,13 @@ instance (ris, rjs) = Haskell.unzip $ Haskell.init rest in solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadBlueprint i a m => i -> i -> m [i] + solve1 :: MonadBlueprint i v a m => v -> v -> m [v] solve1 i j = do z0 <- newAssigned (\v -> v i - v j + fromConstant (2 ^ registerSize @a @n @r :: Natural)) (z, _) <- splitExpansion (highRegisterSize @a @n @r) 1 z0 return [z] - solveN :: MonadBlueprint i a m => (i, i) -> ([i], [i]) -> (i, i) -> m [i] + solveN :: MonadBlueprint i v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] solveN (i, j) (is, js) (i', j') = do s <- newAssigned (\v -> v i - v j + fromConstant (t + one)) (k, b0) <- splitExpansion (registerSize @a @n @r) 1 s @@ -371,7 +384,7 @@ instance (s', _) <- splitExpansion (highRegisterSize @a @n @r) 1 s'0 return (k : zs <> [s']) - fullSub :: MonadBlueprint i a m => i -> i -> i -> m (i, i) + fullSub :: MonadBlueprint i v a m => v -> v -> v -> m (v, v) fullSub xk yk b = do d <- newAssigned (\v -> v xk - v yk) s <- newAssigned (\v -> v d + v b + fromConstant t) @@ -379,7 +392,7 @@ instance negate (UInt x) = UInt $ circuitF (V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i a m => m [i] + solve :: MonadBlueprint i v a m => m [v] solve = do j <- newAssigned (Haskell.const zero) @@ -393,16 +406,16 @@ instance (zs, _) <- flip runStateT j $ traverse StateT (Haskell.zipWith negateN ns xs) return zs - negateN :: MonadBlueprint i a m => Natural -> i -> i -> m (i, i) + negateN :: MonadBlueprint i v a m => Natural -> v -> v -> m (v, v) negateN n i b = do r <- newAssigned (\v -> fromConstant n - v i + v b) splitExpansion (registerSize @a @n @r) 1 r -instance (Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters a n rs) => MultiplicativeSemigroup (UInt n rs (ArithmeticCircuit a)) where +instance (Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters a n rs, Haskell.Ord (Rep i), Representable i) => MultiplicativeSemigroup (UInt n rs (ArithmeticCircuit a i)) where UInt x * UInt y = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i a m => m [i] + solve :: MonadBlueprint i v a m => m [v] solve = do is <- runCircuit x js <- runCircuit y @@ -413,12 +426,12 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters (ris, rjs) = Haskell.unzip $ Haskell.init rest in solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadBlueprint i a m => i -> i -> m [i] + solve1 :: MonadBlueprint i v a m => v -> v -> m [v] solve1 i j = do (z, _) <- newAssigned (\v -> v i * v j) >>= splitExpansion (highRegisterSize @a @n @rs) (maxOverflow @a @n @rs) return [z] - solveN :: MonadBlueprint i a m => (i, i) -> ([i], [i]) -> (i, i) -> m [i] + solveN :: MonadBlueprint i v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] solveN (i, j) (is, js) (i', j') = do let cs = fromList $ zip [0..] (i : is ++ [i']) ds = fromList $ zip [0..] (j : js ++ [j']) @@ -450,9 +463,11 @@ instance , KnownNat (NumberOfRegisters a n r - 1) , KnownRegisterSize r , (NumberOfRegisters a n r - 1) + 1 ~ NumberOfRegisters a n r - ) => MultiplicativeMonoid (UInt n r (ArithmeticCircuit a)) where + , Haskell.Ord (Rep i) + , Representable i + ) => MultiplicativeMonoid (UInt n r (ArithmeticCircuit a i)) where - one = UInt $ hliftA2 (\(Par1 h) t -> h V..: t) (embed one :: ArithmeticCircuit a Par1) (embedV (pure zero) :: ArithmeticCircuit a (Vector (NumberOfRegisters a n r - 1))) + one = UInt $ hliftA2 (\(Par1 h) t -> h V..: t) (embed one :: ArithmeticCircuit a i Par1) (embedV (pure zero) :: ArithmeticCircuit a i (Vector (NumberOfRegisters a n r - 1))) instance @@ -462,7 +477,9 @@ instance , KnownNat (NumberOfRegisters a n r - 1) , KnownRegisterSize r , (NumberOfRegisters a n r - 1) + 1 ~ NumberOfRegisters a n r - ) => Semiring (UInt n r (ArithmeticCircuit a)) + , Haskell.Ord (Rep i) + , Representable i + ) => Semiring (UInt n r (ArithmeticCircuit a i)) instance ( Arithmetic a @@ -471,13 +488,15 @@ instance , KnownNat (NumberOfRegisters a n r - 1) , KnownRegisterSize r , (NumberOfRegisters a n r - 1) + 1 ~ NumberOfRegisters a n r - ) => Ring (UInt n r (ArithmeticCircuit a)) + , Haskell.Ord (Rep i) + , Representable i + ) => Ring (UInt n r (ArithmeticCircuit a i)) -deriving via (Structural (UInt n rs (ArithmeticCircuit a))) - instance (Arithmetic a, r ~ NumberOfRegisters a n rs, 1 <= r) => - Eq (Bool (ArithmeticCircuit a)) (UInt n rs (ArithmeticCircuit a)) +deriving via (Structural (UInt n rs (ArithmeticCircuit a i))) + instance (Arithmetic a, r ~ NumberOfRegisters a n rs, 1 <= r, Haskell.Ord (Rep i), Representable i) => + Eq (Bool (ArithmeticCircuit a i)) (UInt n rs (ArithmeticCircuit a i)) -instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => Arbitrary (UInt n r (ArithmeticCircuit a)) where +instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => Arbitrary (UInt n r (ArithmeticCircuit a i)) where arbitrary = do lows <- replicateA (numberOfRegisters @a @n @r -! 1) (toss $ registerSize @a @n @r) hi <- toss (highRegisterSize @a @n @r) @@ -485,18 +504,18 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => Arbitrary (UInt n r where toss b = fromConstant <$> chooseInteger (0, 2 ^ b - 1) -instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => Iso (ByteString n (ArithmeticCircuit a)) (UInt n r (ArithmeticCircuit a)) where +instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => Iso (ByteString n (ArithmeticCircuit a i)) (UInt n r (ArithmeticCircuit a i)) where from (ByteString bits) = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: forall i m. MonadBlueprint i a m => m [i] + solve :: forall v m. MonadBlueprint i v a m => m [v] solve = do bsBits <- V.fromVector <$> runCircuit bits Haskell.reverse <$> fromBits (highRegisterSize @a @n @r) (registerSize @a @n @r) bsBits -instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => Iso (UInt n r (ArithmeticCircuit a)) (ByteString n (ArithmeticCircuit a)) where +instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => Iso (UInt n r (ArithmeticCircuit a i)) (ByteString n (ArithmeticCircuit a i)) where from (UInt ac) = ByteString $ circuitF $ Vector <$> solve where - solve :: forall i m. MonadBlueprint i a m => m [i] + solve :: forall v m. MonadBlueprint i v a m => m [v] solve = do regs <- V.fromVector <$> runCircuit ac toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) @@ -512,7 +531,7 @@ instance (Finite (Zp p), KnownNat n, KnownRegisterSize r) => StrictConv Natural (lo, hi, []) -> UInt $ Interpreter $ V.unsafeToVector $ (toZp . Haskell.fromIntegral <$> lo) <> [toZp . Haskell.fromIntegral $ hi] _ -> error "strictConv: overflow" -instance (FromConstant Natural a, Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters a n rs) => StrictConv Natural (UInt n rs (ArithmeticCircuit a)) where +instance (FromConstant Natural a, Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters a n rs, Haskell.Ord (Rep i), Representable i) => StrictConv Natural (UInt n rs (ArithmeticCircuit a i)) where strictConv n = case cast @a @n @rs n of (lo, hi, []) -> UInt $ embedV $ V.unsafeToVector $ fromConstant <$> (lo <> [hi]) _ -> error "strictConv: overflow" @@ -520,13 +539,13 @@ instance (FromConstant Natural a, Arithmetic a, KnownNat n, KnownRegisterSize rs instance (Finite (Zp p), KnownNat n, KnownRegisterSize r) => StrictConv (Zp p) (UInt n r (Interpreter (Zp p))) where strictConv = strictConv . toConstant @_ @Natural -instance (Finite (Zp p), Prime p, KnownNat n, KnownRegisterSize r) => StrictConv (Zp p) (UInt n r (ArithmeticCircuit (Zp p))) where +instance (Finite (Zp p), Prime p, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => StrictConv (Zp p) (UInt n r (ArithmeticCircuit (Zp p) i)) where strictConv = strictConv . toConstant @_ @Natural -instance (Arithmetic a, KnownNat n, KnownRegisterSize r, NumberOfBits a <= n) => StrictConv (ArithmeticCircuit a Par1) (UInt n r (ArithmeticCircuit a)) where +instance (Arithmetic a, KnownNat n, KnownRegisterSize r, NumberOfBits a <= n, Haskell.Ord (Rep i), Representable i) => StrictConv (ArithmeticCircuit a i Par1) (UInt n r (ArithmeticCircuit a i)) where strictConv a = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i a m => m [i] + solve :: MonadBlueprint i v a m => m [v] solve = do i <- unPar1 <$> runCircuit a let len = Haskell.min (getNatural @n) (numberOfBits @a) @@ -544,10 +563,10 @@ instance (Finite (Zp p), KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r strictSub x y = strictConv $ toConstant x -! toConstant y strictMul x y = strictConv $ toConstant x * toConstant @_ @Natural y -instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r (ArithmeticCircuit a)) where +instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => StrictNum (UInt n r (ArithmeticCircuit a i)) where strictAdd (UInt x) (UInt y) = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i a m => m [i] + solve :: MonadBlueprint i v a m => m [v] solve = do j <- newAssigned (Haskell.const zero) xs <- V.fromVector <$> runCircuit x @@ -568,7 +587,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r t :: a t = (one + one) ^ registerSize @a @n @r - one - solve :: MonadBlueprint i a m => m [i] + solve :: MonadBlueprint i v a m => m [v] solve = do is <- runCircuit x js <- runCircuit y @@ -579,13 +598,13 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r (ris, rjs) = Haskell.unzip $ Haskell.init rest in solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadBlueprint i a m => i -> i -> m [i] + solve1 :: MonadBlueprint i v a m => v -> v -> m [v] solve1 i j = do z <- newAssigned (\v -> v i - v j) _ <- expansion (highRegisterSize @a @n @r) z return [z] - solveN :: MonadBlueprint i a m => (i, i) -> ([i], [i]) -> (i, i) -> m [i] + solveN :: MonadBlueprint i v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] solveN (i, j) (is, js) (i', j') = do s <- newAssigned (\v -> v i - v j + fromConstant (t + one)) (k, b0) <- splitExpansion (registerSize @a @n @r) 1 s @@ -596,7 +615,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r return (k : zs <> [s']) - fullSub :: MonadBlueprint i a m => i -> i -> i -> m (i, i) + fullSub :: MonadBlueprint i v a m => v -> v -> v -> m (v, v) fullSub xk yk b = do k <- newAssigned (\v -> v xk - v yk) s <- newAssigned (\v -> v k + v b + fromConstant t) @@ -604,7 +623,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r strictMul (UInt x) (UInt y) = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i a m => m [i] + solve :: MonadBlueprint i v a m => m [v] solve = do is <- runCircuit x js <- runCircuit y @@ -615,13 +634,13 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r (ris, rjs) = Haskell.unzip $ Haskell.init rest in solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadBlueprint i a m => i -> i -> m [i] + solve1 :: MonadBlueprint i v a m => v -> v -> m [v] solve1 i j = do z <- newAssigned $ \v -> v i * v j _ <- expansion (highRegisterSize @a @n @r) z return [z] - solveN :: MonadBlueprint i a m => (i, i) -> ([i], [i]) -> (i, i) -> m [i] + solveN :: MonadBlueprint i v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] solveN (i, j) (is, js) (i', j') = do let cs = fromList $ zip [0..] (i : is ++ [i']) ds = fromList $ zip [0..] (j : js ++ [j']) @@ -652,10 +671,10 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r -------------------------------------------------------------------------------- -fullAdder :: (Arithmetic a, MonadBlueprint i a m) => Natural -> i -> i -> i -> m (i, i) +fullAdder :: (Arithmetic a, MonadBlueprint i v a m) => Natural -> v -> v -> v -> m (v, v) fullAdder r xk yk c = fullAdded xk yk c >>= splitExpansion r 1 -fullAdded :: MonadBlueprint i a m => i -> i -> i -> m i +fullAdded :: MonadBlueprint i v a m => v -> v -> v -> m v fullAdded i j c = do k <- newAssigned (\v -> v i + v j) newAssigned (\v -> v k + v c) diff --git a/zkfold-base.cabal b/zkfold-base.cabal index 7861a4159..1846eec75 100644 --- a/zkfold-base.cabal +++ b/zkfold-base.cabal @@ -208,6 +208,7 @@ library containers < 0.7, cryptohash-sha256 < 0.12, deepseq <= 1.5.0.0, + distributive , lens , mtl < 2.4, optics < 0.5, From c1acfb5de6fd70a944fce15b1976834492a2b16e Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Thu, 8 Aug 2024 09:44:02 -0700 Subject: [PATCH 02/48] more --- src/ZkFold/Base/Protocol/ARK/Protostar.hs | 2 +- src/ZkFold/Symbolic/Cardano/Types/Address.hs | 6 ++++-- src/ZkFold/Symbolic/Cardano/Types/Output.hs | 11 +++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Protostar.hs b/src/ZkFold/Base/Protocol/ARK/Protostar.hs index 679245ce3..7b618f2c7 100644 --- a/src/ZkFold/Base/Protocol/ARK/Protostar.hs +++ b/src/ZkFold/Base/Protocol/ARK/Protostar.hs @@ -59,7 +59,7 @@ instance Arithmetic a => SpecialSoundProtocol a (RecursiveCircuit n a) where -- The transcript will be empty at this point, it is a one-round protocol -- - prover rc _ i _ = eval (circuit rc) (M.fromList $ P.zip [1..] (V.fromVector i)) + prover rc _ i _ = eval (circuit rc) i -- We can use the polynomial system from the circuit, no need to build it from scratch -- diff --git a/src/ZkFold/Symbolic/Cardano/Types/Address.hs b/src/ZkFold/Symbolic/Cardano/Types/Address.hs index 93ee91a23..82bf1d420 100644 --- a/src/ZkFold/Symbolic/Cardano/Types/Address.hs +++ b/src/ZkFold/Symbolic/Cardano/Types/Address.hs @@ -3,6 +3,7 @@ module ZkFold.Symbolic.Cardano.Types.Address where +import Data.Functor.Rep (Representable (..)) import Prelude hiding (Bool, Eq, length, splitAt, (*), (+)) import qualified Prelude as Haskell @@ -23,8 +24,9 @@ deriving instance (Haskell.Eq (ByteString 4 context), Haskell.Eq (ByteString 224 deriving instance HApplicative context => SymbolicData context (Address context) -deriving via (Structural (Address CtxCompilation)) - instance Eq (Bool CtxCompilation) (Address CtxCompilation) +deriving via (Structural (Address (CtxCompilation i))) + instance (Ord (Rep i), Representable i) + => Eq (Bool (CtxCompilation i)) (Address (CtxCompilation i)) addressType :: Address context -> AddressType context addressType (Address (t, _)) = t diff --git a/src/ZkFold/Symbolic/Cardano/Types/Output.hs b/src/ZkFold/Symbolic/Cardano/Types/Output.hs index 1348bb18b..1075092ef 100644 --- a/src/ZkFold/Symbolic/Cardano/Types/Output.hs +++ b/src/ZkFold/Symbolic/Cardano/Types/Output.hs @@ -11,6 +11,7 @@ module ZkFold.Symbolic.Cardano.Types.Output ( txoDatumHash ) where +import Data.Functor.Rep (Representable (..)) import Prelude hiding (Bool, Eq, length, splitAt, (*), (+)) import qualified Prelude as Haskell @@ -39,13 +40,15 @@ deriving instance , KnownNat tokens ) => SymbolicData context (Output tokens datum context) -deriving via (Structural (Output tokens datum CtxCompilation)) +deriving via (Structural (Output tokens datum (CtxCompilation i))) instance - ( ts ~ TypeSize CtxCompilation (Output tokens datum CtxCompilation) + ( ts ~ TypeSize (CtxCompilation i) (Output tokens datum (CtxCompilation i)) , 1 <= ts , KnownNat tokens - , KnownNat (TypeSize CtxCompilation (Value tokens CtxCompilation)) - ) => Eq (Bool CtxCompilation) (Output tokens datum CtxCompilation) + , KnownNat (TypeSize (CtxCompilation i) (Value tokens (CtxCompilation i))) + , Ord (Rep i) + , Representable i + ) => Eq (Bool (CtxCompilation i)) (Output tokens datum (CtxCompilation i)) txoAddress :: Output tokens datum context -> Address context txoAddress (Output (addr, _)) = addr From 5524fe317f424ab3b55e62b5c47937eed3285328 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Fri, 9 Aug 2024 08:42:54 -0700 Subject: [PATCH 03/48] mawr --- src/ZkFold/Base/Protocol/ARK/Plonk.hs | 14 +++++++------- src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs | 10 +++++----- src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs | 13 ++++++++----- src/ZkFold/Base/Protocol/ARK/Protostar.hs | 15 +++++++++++---- 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index fa226ab4b..3223dbe4a 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -12,6 +12,7 @@ module ZkFold.Base.Protocol.ARK.Plonk ( ) where import Data.Maybe (fromJust) +import Data.Functor.Rep (Representable (..)) import qualified Data.Vector as V import GHC.Generics (Par1) import GHC.IsList (IsList (..)) @@ -30,8 +31,7 @@ import ZkFold.Base.Protocol.ARK.Plonk.Internal import ZkFold.Base.Protocol.ARK.Plonk.Relation (PlonkRelation (..), toPlonkRelation) import ZkFold.Base.Protocol.Commitment.KZG (com) import ZkFold.Base.Protocol.NonInteractiveProof -import ZkFold.Prelude (length, (!)) -import ZkFold.Symbolic.Compiler (ArithmeticCircuit (acInput)) +import ZkFold.Symbolic.Compiler (ArithmeticCircuit) import ZkFold.Symbolic.MonadCircuit (Arithmetic) {- @@ -44,7 +44,7 @@ data Plonk (n :: Natural) (l :: Natural) curve1 curve2 transcript = Plonk { k1 :: ScalarField curve1, k2 :: ScalarField curve1, iPub :: Vector l Natural, - ac :: ArithmeticCircuit (ScalarField curve1) Par1, + ac :: ArithmeticCircuit (ScalarField curve1) (Vector l) Par1, x :: ScalarField curve1 } instance (Show (ScalarField c1), Arithmetic (ScalarField c1)) => Show (Plonk n l c1 c2 t) where @@ -55,7 +55,7 @@ instance (KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (Scalar => Arbitrary (Plonk n l c1 c2 t) where arbitrary = do ac <- arbitrary - let fullInp = length . acInput $ ac + let fullInp = value @l vecPubInp <- genSubset (value @l) fullInp let (omega, k1, k2) = getParams (value @n) Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary @@ -114,9 +114,9 @@ instance forall n l c1 c2 t plonk f g1. , FromTranscript t (ScalarField c1) ) => NonInteractiveProof (Plonk n l c1 c2 t) where type Transcript (Plonk n l c1 c2 t) = t - type SetupProve (Plonk n l c1 c2 t) = (PlonkSetupParamsProve c1 c2, PlonkPermutation n c1, PlonkCircuitPolynomials n c1 , PlonkWitnessMap n c1) + type SetupProve (Plonk n l c1 c2 t) = (PlonkSetupParamsProve c1 c2, PlonkPermutation n c1, PlonkCircuitPolynomials n c1 , PlonkWitnessMap n l c1) type SetupVerify (Plonk n l c1 c2 t) = (PlonkSetupParamsVerify c1 c2, PlonkCircuitCommitments c1) - type Witness (Plonk n l c1 c2 t) = (PlonkWitnessInput c1, PlonkProverSecret c1) + type Witness (Plonk n l c1 c2 t) = (PlonkWitnessInput l c1, PlonkProverSecret c1) type Input (Plonk n l c1 c2 t) = PlonkInput c1 type Proof (Plonk n l c1 c2 t) = PlonkProof c1 @@ -172,7 +172,7 @@ instance forall n l c1 c2 t plonk f g1. (w1, w2, w3) = wmap wInput - wPub = fmap (negate . (wInput !)) iPub' + wPub = fmap (negate . index wInput . P.fromIntegral) iPub' pubPoly = polyVecInLagrangeBasis omega' $ toPolyVec @f @n wPub diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs index dd7e41bc2..b65dc9f4d 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs @@ -5,7 +5,6 @@ module ZkFold.Base.Protocol.ARK.Plonk.Internal where import Data.Bifunctor (first) import Data.Bool (bool) -import qualified Data.Map as Map import qualified Data.Vector as V import GHC.Generics (Generic) import GHC.IsList (IsList (..)) @@ -18,6 +17,7 @@ import ZkFold.Base.Algebra.Basic.Number import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), Point) import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr) import ZkFold.Prelude (take) +import ZkFold.Base.Data.Vector (Vector) log2 :: (Integral a, Integral b) => a -> b log2 = ceiling @Double . logBase 2 . fromIntegral @@ -141,11 +141,11 @@ instance (Show (BaseField c), EllipticCurve c) => Show (PlonkCircuitCommitments ++ show cmS2 ++ " " ++ show cmS3 -newtype PlonkWitnessMap n c = PlonkWitnessMap - (Map.Map Natural (ScalarField c) -> (PolyVec (ScalarField c) n, PolyVec (ScalarField c) n, PolyVec (ScalarField c) n)) +newtype PlonkWitnessMap n l c = PlonkWitnessMap + (Vector l (ScalarField c) -> (PolyVec (ScalarField c) n, PolyVec (ScalarField c) n, PolyVec (ScalarField c) n)) -newtype PlonkWitnessInput c = PlonkWitnessInput (Map.Map Natural (ScalarField c)) -instance Show (ScalarField c) => Show (PlonkWitnessInput c) where +newtype PlonkWitnessInput l c = PlonkWitnessInput (Vector l (ScalarField c)) +instance Show (ScalarField c) => Show (PlonkWitnessInput l c) where show (PlonkWitnessInput m) = "Witness Input: " ++ show m data PlonkProverSecret c = PlonkProverSecret { diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs index 5cbce2a0e..a13649246 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs @@ -3,7 +3,7 @@ module ZkFold.Base.Protocol.ARK.Plonk.Relation where -import Data.Map (Map, elems, (!)) +import Data.Map (elems, (!)) import GHC.Generics (Par1) import GHC.IsList (IsList (..)) import Prelude hiding (Num (..), drop, length, replicate, sum, take, @@ -18,7 +18,7 @@ import ZkFold.Base.Data.Vector (Vector, fromVecto import ZkFold.Base.Protocol.ARK.Plonk.Constraint (PlonkConstraint (..), toPlonkConstraint) import ZkFold.Prelude (replicate) import ZkFold.Symbolic.Compiler -import ZkFold.Symbolic.MonadCircuit (Arithmetic) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal -- Here `n` is the total number of constraints, `l` is the number of public inputs, and `a` is the field type. data PlonkRelation n l a = PlonkRelation @@ -28,7 +28,7 @@ data PlonkRelation n l a = PlonkRelation , qO :: PolyVec a n , qC :: PolyVec a n , sigma :: Permutation (3 * n) - , wmap :: Map Natural a -> (PolyVec a n, PolyVec a n, PolyVec a n) + , wmap :: Vector l a -> (PolyVec a n, PolyVec a n, PolyVec a n) } toPlonkRelation :: forall n l a . @@ -38,11 +38,14 @@ toPlonkRelation :: forall n l a . => Arithmetic a => Scale a a => Vector l Natural - -> ArithmeticCircuit a Par1 + -> ArithmeticCircuit a (Vector l) Par1 -> Maybe (PlonkRelation n l a) toPlonkRelation xPub ac0 = let ac = desugarRanges ac0 - evalX0 = evalPolynomial evalMonomial (\x -> if x == 0 then one else var x) + + varF (NewVar ix) = if ix == 0 then one else var (ix + value @l) + varF (InVar ix) = var (fromIntegral ix) + evalX0 = evalPolynomial evalMonomial varF pubInputConstraints = map var (fromVector xPub) acConstraints = map evalX0 $ elems (acSystem ac) diff --git a/src/ZkFold/Base/Protocol/ARK/Protostar.hs b/src/ZkFold/Base/Protocol/ARK/Protostar.hs index 7b618f2c7..458fb81df 100644 --- a/src/ZkFold/Base/Protocol/ARK/Protostar.hs +++ b/src/ZkFold/Base/Protocol/ARK/Protostar.hs @@ -11,10 +11,10 @@ import Prelude (($), (==)) import qualified Prelude as P import ZkFold.Base.Algebra.Basic.Number -import qualified ZkFold.Base.Data.Vector as V import ZkFold.Base.Data.Vector (Vector) import ZkFold.Base.Protocol.ARK.Protostar.SpecialSound import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal +import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial, var) {-- @@ -45,7 +45,7 @@ data RecursiveCircuit n a , circuit :: ArithmeticCircuit a (Vector n) (Vector n) } deriving (Generic, NFData) -instance Arithmetic a => SpecialSoundProtocol a (RecursiveCircuit n a) where +instance (Arithmetic a, KnownNat n) => SpecialSoundProtocol a (RecursiveCircuit n a) where type Witness a (RecursiveCircuit n a) = Map Natural a type Input a (RecursiveCircuit n a) = Vector n a type ProverMessage a (RecursiveCircuit n a) = Vector n a @@ -63,9 +63,16 @@ instance Arithmetic a => SpecialSoundProtocol a (RecursiveCircuit n a) where -- We can use the polynomial system from the circuit, no need to build it from scratch -- - algebraicMap rc _ _ _ = M.elems $ acSystem (circuit rc) + algebraicMap rc _ _ _ = + let + varF (NewVar ix) = var (ix P.+ value @n) + varF (InVar ix) = var (P.fromIntegral ix) + in + [ evalPolynomial evalMonomial varF poly + | poly <- M.elems $ acSystem (circuit rc) + ] -- The transcript is only one prover message since this is a one-round protocol -- - verifier rc i pm _ = eval (circuit rc) (M.fromList $ P.zip [1..] (V.fromVector i)) == P.head pm + verifier rc i pm _ = eval (circuit rc) i == P.head pm From fcf44d2079b479d836894c6d96d57541d51f73be Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Fri, 9 Aug 2024 10:20:46 -0700 Subject: [PATCH 04/48] lots of editing --- examples/Examples/BatchTransfer.hs | 3 +- examples/Examples/ByteString.hs | 7 ++-- examples/Examples/Conditional.hs | 5 ++- examples/Examples/Eq.hs | 3 +- examples/Examples/FFA.hs | 7 ++-- examples/Examples/Fibonacci.hs | 3 +- examples/Examples/LEQ.hs | 3 +- examples/Examples/MiMCHash.hs | 3 +- examples/Examples/ReverseList.hs | 2 +- examples/Examples/UInt.hs | 15 ++++--- src/ZkFold/Symbolic/Compiler.hs | 21 ++++++---- .../Compiler/ArithmeticCircuit/Internal.hs | 3 -- src/ZkFold/Symbolic/Data/FieldElement.hs | 38 ++--------------- tests/Tests/ArithmeticCircuit.hs | 20 ++++----- tests/Tests/Arithmetization.hs | 12 +++--- tests/Tests/Arithmetization/Test1.hs | 7 ++-- tests/Tests/Arithmetization/Test2.hs | 5 ++- tests/Tests/Arithmetization/Test3.hs | 7 ++-- tests/Tests/Arithmetization/Test4.hs | 30 +++++++------- tests/Tests/Blake2b.hs | 4 +- tests/Tests/ByteString.hs | 41 ++++++++++--------- tests/Tests/FFA.hs | 5 ++- tests/Tests/NonInteractiveProof/Internal.hs | 8 ++-- tests/Tests/NonInteractiveProof/Plonk.hs | 6 +-- tests/Tests/SHA2.hs | 15 +++---- tests/Tests/UInt.hs | 28 ++++++------- zkfold-base.cabal | 1 + 27 files changed, 145 insertions(+), 157 deletions(-) diff --git a/examples/Examples/BatchTransfer.hs b/examples/Examples/BatchTransfer.hs index 2520a1354..984468fc7 100644 --- a/examples/Examples/BatchTransfer.hs +++ b/examples/Examples/BatchTransfer.hs @@ -6,6 +6,7 @@ module Examples.BatchTransfer (exampleBatchTransfer) where import Prelude hiding (Eq (..), Num (..), any, not, (!!), (/), (^), (||)) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Cardano.Contracts.BatchTransfer (batchTransfer) import ZkFold.Symbolic.Cardano.Types import ZkFold.Symbolic.Compiler (compileIO) @@ -16,4 +17,4 @@ exampleBatchTransfer = do putStrLn "\nExample: Batch Transfer smart contract\n" - compileIO @F file (batchTransfer @CtxCompilation) + compileIO @151810 @F file (batchTransfer @(CtxCompilation (Vector 151810))) diff --git a/examples/Examples/ByteString.hs b/examples/Examples/ByteString.hs index 697529bdb..06deb6c20 100644 --- a/examples/Examples/ByteString.hs +++ b/examples/Examples/ByteString.hs @@ -18,6 +18,7 @@ import Text.Show (show) import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler (ArithmeticCircuit, compileIO) import ZkFold.Symbolic.Data.Bool import ZkFold.Symbolic.Data.ByteString @@ -35,15 +36,15 @@ exampleByteStringExtend = do let k = show $ natVal (Proxy @k) putStrLn $ "\nExample: Extending a bytestring of length " ++ n ++ " to length " ++ k let file = "compiled_scripts/bytestring" ++ n ++ "_to_" ++ k ++ ".json" - compileIO @(Zp BLS12_381_Scalar) file $ extend @(ByteString n (ArithmeticCircuit (Zp BLS12_381_Scalar))) @(ByteString k (ArithmeticCircuit (Zp BLS12_381_Scalar))) + compileIO @n @(Zp BLS12_381_Scalar) file $ extend @(ByteString n (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector n))) @(ByteString k (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector n))) type Binary a = a -> a -> a -type UBinary n = Binary (ByteString n (ArithmeticCircuit (Zp BLS12_381_Scalar))) +type UBinary n = Binary (ByteString n (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector (n + n)))) makeExample :: forall n . (KnownNat n, KnownNat (n + n)) => String -> String -> UBinary n -> IO () makeExample shortName name op = do let n = show $ natVal (Proxy @n) putStrLn $ "\nExample: (" ++ shortName ++ ") operation on ByteString" ++ n let file = "compiled_scripts/bytestring" ++ n ++ "_" ++ name ++ ".json" - compileIO @(Zp BLS12_381_Scalar) file op + compileIO @(n+n) @(Zp BLS12_381_Scalar) file op diff --git a/examples/Examples/Conditional.hs b/examples/Examples/Conditional.hs index 443d8efed..0c746cf1a 100644 --- a/examples/Examples/Conditional.hs +++ b/examples/Examples/Conditional.hs @@ -7,12 +7,13 @@ import Prelude (IO, putStrLn) import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool) import ZkFold.Symbolic.Data.Conditional (Conditional (..)) type F = Zp BLS12_381_Scalar -type A = ArithmeticCircuit F +type A = ArithmeticCircuit F (Vector 3) type B = Bool A exampleConditional :: IO () @@ -21,4 +22,4 @@ exampleConditional = do putStrLn "\nExample: conditional\n" - compileIO @F file (bool @B @(A Par1)) + compileIO @3 @F file (bool @B @(A Par1)) diff --git a/examples/Examples/Eq.hs b/examples/Examples/Eq.hs index 78beef5c7..7873186da 100644 --- a/examples/Examples/Eq.hs +++ b/examples/Examples/Eq.hs @@ -7,6 +7,7 @@ import Prelude hiding (Bool, Eq (. import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool (..)) @@ -23,4 +24,4 @@ exampleEq = do putStrLn "\nExample: (==) operation\n" - compileIO @(Zp BLS12_381_Scalar) file (eq @(ArithmeticCircuit (Zp BLS12_381_Scalar))) + compileIO @2 @(Zp BLS12_381_Scalar) file (eq @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 2))) diff --git a/examples/Examples/FFA.hs b/examples/Examples/FFA.hs index 43cad7173..7145e1a7a 100644 --- a/examples/Examples/FFA.hs +++ b/examples/Examples/FFA.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE AllowAmbiguousTypes, TypeOperators #-} module Examples.FFA (examplesFFA) where @@ -12,6 +12,7 @@ import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.Basic.Number import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler (ArithmeticCircuit, compileIO) import ZkFold.Symbolic.Data.FFA (FFA) @@ -33,9 +34,9 @@ exampleFFAmul = makeExample @p "*" "mul" (*) type Binary a = a -> a -> a -makeExample :: forall p. KnownNat p => String -> String -> Binary (FFA p (ArithmeticCircuit (Zp BLS12_381_Scalar))) -> IO () +makeExample :: forall p. KnownNat p => String -> String -> Binary (FFA p (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 14))) -> IO () makeExample shortName name op = do let p = show $ value @p putStrLn $ "\nExample: (" ++ shortName ++ ") operation on FFA " ++ p let file = "compiled_scripts/ffa_" ++ p ++ "_" ++ name ++ ".json" - compileIO @(Zp BLS12_381_Scalar) file op + compileIO @14 @(Zp BLS12_381_Scalar) file op diff --git a/examples/Examples/Fibonacci.hs b/examples/Examples/Fibonacci.hs index 5ecdbd9e4..662e8ca5e 100644 --- a/examples/Examples/Fibonacci.hs +++ b/examples/Examples/Fibonacci.hs @@ -7,6 +7,7 @@ import Prelude hiding (Bool, Eq (. import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool (..)) @@ -29,4 +30,4 @@ exampleFibonacci = do putStrLn "\nExample: Fibonacci index function\n" - compileIO @(Zp BLS12_381_Scalar) file (fibonacciIndex @(ArithmeticCircuit (Zp BLS12_381_Scalar)) nMax) + compileIO @1 @(Zp BLS12_381_Scalar) file (fibonacciIndex @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 1)) nMax) diff --git a/examples/Examples/LEQ.hs b/examples/Examples/LEQ.hs index 94ec0e436..c91852404 100644 --- a/examples/Examples/LEQ.hs +++ b/examples/Examples/LEQ.hs @@ -7,6 +7,7 @@ import Prelude hiding (Bool, Eq (. import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool) @@ -23,4 +24,4 @@ exampleLEQ = do putStrLn "\nExample: (<=) operation\n" - compileIO @(Zp BLS12_381_Scalar) file (leq @(ArithmeticCircuit (Zp BLS12_381_Scalar))) + compileIO @2 @(Zp BLS12_381_Scalar) file (leq @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 2))) diff --git a/examples/Examples/MiMCHash.hs b/examples/Examples/MiMCHash.hs index 99653415d..894ce6f8f 100644 --- a/examples/Examples/MiMCHash.hs +++ b/examples/Examples/MiMCHash.hs @@ -8,6 +8,7 @@ import Prelude hiding (Eq (..), import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Algorithms.Hash.MiMC (mimcHash2) import ZkFold.Symbolic.Algorithms.Hash.MiMC.Constants (mimcConstants) import ZkFold.Symbolic.Compiler @@ -21,4 +22,4 @@ exampleMiMC = do putStrLn "\nExample: MiMC hash function\n" - compileIO @F file (mimcHash2 @F @(FieldElement (ArithmeticCircuit F)) mimcConstants zero) + compileIO @2 @F file (mimcHash2 @F @(FieldElement (ArithmeticCircuit F (Vector 2))) mimcConstants zero) diff --git a/examples/Examples/ReverseList.hs b/examples/Examples/ReverseList.hs index 931d4751e..4b58de22b 100644 --- a/examples/Examples/ReverseList.hs +++ b/examples/Examples/ReverseList.hs @@ -22,4 +22,4 @@ exampleReverseList = do putStrLn "\nExample: Reverse List function\n" - compileIO @(Zp BLS12_381_Scalar) file (reverseList @(ArithmeticCircuit (Zp BLS12_381_Scalar) Par1) @32) + compileIO @32 @(Zp BLS12_381_Scalar) file (reverseList @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 32) Par1) @32) diff --git a/examples/Examples/UInt.hs b/examples/Examples/UInt.hs index 8f776149b..7a67e96c0 100644 --- a/examples/Examples/UInt.hs +++ b/examples/Examples/UInt.hs @@ -22,6 +22,7 @@ import Text.Show (show) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler (ArithmeticCircuit, compileIO) import ZkFold.Symbolic.Data.Combinators import ZkFold.Symbolic.Data.UInt @@ -29,7 +30,7 @@ import ZkFold.Symbolic.Data.UInt exampleUIntAdd :: forall n r . KnownNat n - => r ~ NumberOfRegisters (Zp BLS12_381_Scalar) n Auto + => r ~ Num n => KnownNat r => KnownNat (r + r) => IO () @@ -38,7 +39,7 @@ exampleUIntAdd = makeExample @n "+" "add" (+) exampleUIntMul :: forall n r . KnownNat n - => r ~ NumberOfRegisters (Zp BLS12_381_Scalar) n Auto + => r ~ Num n => KnownNat r => KnownNat (r + r) => IO () @@ -47,7 +48,7 @@ exampleUIntMul = makeExample @n "*" "mul" (*) exampleUIntStrictAdd :: forall n r . KnownNat n - => r ~ NumberOfRegisters (Zp BLS12_381_Scalar) n Auto + => r ~ Num n => KnownNat r => KnownNat (r + r) => IO () @@ -56,7 +57,7 @@ exampleUIntStrictAdd = makeExample @n "strictAdd" "strict_add" strictAdd exampleUIntStrictMul :: forall n r . KnownNat n - => r ~ NumberOfRegisters (Zp BLS12_381_Scalar) n Auto + => r ~ Num n => KnownNat r => KnownNat (r + r) => IO () @@ -64,7 +65,9 @@ exampleUIntStrictMul = makeExample @n "strictMul" "strict_mul" strictMul type Binary a = a -> a -> a -type UBinary n = Binary (UInt n Auto (ArithmeticCircuit (Zp BLS12_381_Scalar))) +type Num n = NumberOfRegisters (Zp BLS12_381_Scalar) n Auto + +type UBinary n = Binary (UInt n Auto (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector (Num n + Num n)))) makeExample :: forall n r @@ -77,4 +80,4 @@ makeExample shortName name op = do let n = show $ natVal (Proxy @n) putStrLn $ "\nExample: (" ++ shortName ++ ") operation on UInt" ++ n let file = "compiled_scripts/uint" ++ n ++ "_" ++ name ++ ".json" - compileIO @(Zp BLS12_381_Scalar) file op + compileIO @(Num n + Num n) @(Zp BLS12_381_Scalar) file op diff --git a/src/ZkFold/Symbolic/Compiler.hs b/src/ZkFold/Symbolic/Compiler.hs index 02ad4b119..41aa9d237 100644 --- a/src/ZkFold/Symbolic/Compiler.hs +++ b/src/ZkFold/Symbolic/Compiler.hs @@ -52,13 +52,14 @@ solder f = pieces f (restore @c @(Support c f) $ const inputC) -- | Compiles function `f` into an arithmetic circuit with all outputs equal to 1. compileForceOne :: - forall a c f y . - ( c ~ ArithmeticCircuit a (Vector (TypeSize c (Support c f))) + forall n a c f y . + ( n ~ TypeSize c (Support c f) + , c ~ ArithmeticCircuit a (Vector n) , Arithmetic a , SymbolicData c f , SymbolicData c (Support c f) , Support c (Support c f) ~ () - , KnownNat (TypeSize c (Support c f)) + , KnownNat n , SymbolicData c y , Support c y ~ () , TypeSize c f ~ TypeSize c y @@ -67,14 +68,15 @@ compileForceOne = restore @c . const . optimize . forceOne . solder @a -- | Compiles function `f` into an arithmetic circuit. compile :: - forall a c f y . + forall n a c f y . ( Eq a , MultiplicativeMonoid a - , c ~ ArithmeticCircuit a (Vector (TypeSize c (Support c f))) + , n ~ TypeSize c (Support c f) + , c ~ ArithmeticCircuit a (Vector n) , SymbolicData c f , SymbolicData c (Support c f) , Support c (Support c f) ~ () - , KnownNat (TypeSize c (Support c f)) + , KnownNat n , SymbolicData c y , Support c y ~ () , TypeSize c f ~ TypeSize c y @@ -83,15 +85,16 @@ compile = restore @c . const . optimize . solder @a -- | Compiles a function `f` into an arithmetic circuit. Writes the result to a file. compileIO :: - forall a c f . + forall n a c f . ( Eq a , MultiplicativeMonoid a - , c ~ ArithmeticCircuit a (Vector (TypeSize c (Support c f))) + , n ~ TypeSize c (Support c f) + , c ~ ArithmeticCircuit a (Vector n) , ToJSON a , SymbolicData c f , SymbolicData c (Support c f) , Support c (Support c f) ~ () - , KnownNat (TypeSize c (Support c f)) + , KnownNat n ) => FilePath -> f -> IO () compileIO scriptFile f = do let ac = optimize (solder @a f) :: c (Vector (TypeSize c f)) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index d2bb650f3..5387e388e 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -91,9 +91,6 @@ deriving instance NFData (Rep i) => NFData (Var i) witnessGenerator :: ArithmeticCircuit a i o -> i a -> Map Natural a witnessGenerator circuit inputs = fmap ($ inputs) (acWitness circuit) --- let srcs = acWitness circuit --- witness = ($ witness) <$> (srcs `union` fmap const inputs) --- in witness ------------------------------ Symbolic compiler context ---------------------------- diff --git a/src/ZkFold/Symbolic/Data/FieldElement.hs b/src/ZkFold/Symbolic/Data/FieldElement.hs index 25e8dfe22..0517e4f27 100644 --- a/src/ZkFold/Symbolic/Data/FieldElement.hs +++ b/src/ZkFold/Symbolic/Data/FieldElement.hs @@ -4,14 +4,10 @@ module ZkFold.Symbolic.Data.FieldElement where -import Data.Bool (bool) -import Data.Foldable (foldlM, foldr) +import Data.Foldable (foldr) import Data.Function (($), (.)) import Data.Functor (fmap, (<$>)) -import Data.Ord (Ordering (..), compare) -import Data.Traversable (for) -import Data.Tuple (fst, snd) -import Data.Zip (zip) +import Data.Tuple (snd) import GHC.Generics (Par1 (..)) import Prelude (Integer) import qualified Prelude as Haskell @@ -23,9 +19,8 @@ import ZkFold.Base.Data.Par1 () import ZkFold.Base.Data.Vector (Vector, fromVector, unsafeToVector) import ZkFold.Symbolic.Class import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (expansion, horner, runInvert) -import ZkFold.Symbolic.Data.Bool (Bool (Bool)) +import ZkFold.Symbolic.Data.Bool (Bool) import ZkFold.Symbolic.Data.Class -import ZkFold.Symbolic.Data.DiscreteField import ZkFold.Symbolic.Data.Eq (Eq) import ZkFold.Symbolic.Data.Ord import ZkFold.Symbolic.MonadCircuit (newAssigned) @@ -102,30 +97,3 @@ instance Symbolic c => BinaryExpansion (FieldElement c) where fromBinary bits = FieldElement $ symbolicF bits (Par1 . foldr (\x y -> x + y + y) zero) $ fmap Par1 . horner . fromVector - -instance Symbolic c => DiscreteField' (FieldElement c) where - equal x y = let Bool c = isZero (x - y) in FieldElement c - -instance Symbolic c => TrichotomyField (FieldElement c) where - trichotomy (FieldElement x) (FieldElement y) = - FieldElement $ symbolic2F x y - (\u v -> Par1 $ case compare u v of { LT -> negate one; EQ -> zero; GT -> one }) - $ \(Par1 i) (Par1 j) -> do - is <- expansion (numberOfBits @(BaseField c)) i - js <- expansion (numberOfBits @(BaseField c)) j - -- zip pairs of bits in {0,1} to orderings in {-1,0,1} - delta <- for (zip is js) $ \(bi, bj) -> newAssigned (\w -> w bi - w bj) - -- least significant bit first, - -- reverse lexicographical ordering - let reverseLexicographical v u = do - is0 <- newAssigned (\p -> one - p u * p u) - v' <- newAssigned (\p -> p is0 * p v) - newAssigned (\p -> p v' + p u) - Par1 <$> case delta of - [] -> newAssigned zero - (d:ds) -> foldlM reverseLexicographical d ds - -instance Symbolic c => DiscreteField (Bool c) (FieldElement c) where - isZero (FieldElement x) = - Bool $ symbolicF x (Par1 . bool zero one . (Haskell.== Par1 zero)) - $ fmap fst . runInvert diff --git a/tests/Tests/ArithmeticCircuit.hs b/tests/Tests/ArithmeticCircuit.hs index 52d74ea5b..35cdb2568 100644 --- a/tests/Tests/ArithmeticCircuit.hs +++ b/tests/Tests/ArithmeticCircuit.hs @@ -5,7 +5,7 @@ module Tests.ArithmeticCircuit (exec1, it, specArithmeticCircuit) where import Data.Bool (bool) -import GHC.Generics (Par1) +import GHC.Generics (Par1, U1) import Prelude (IO, Show, String, id, ($)) import qualified Prelude as Haskell import qualified Test.Hspec @@ -19,7 +19,7 @@ import qualified ZkFold.Base.Data.Vector as V import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (embed) import ZkFold.Symbolic.Data.Bool -import ZkFold.Symbolic.Data.DiscreteField +-- import ZkFold.Symbolic.Data.DiscreteField import ZkFold.Symbolic.Data.Eq import ZkFold.Symbolic.Data.FieldElement import ZkFold.Symbolic.MonadCircuit (Arithmetic) @@ -47,20 +47,20 @@ specArithmeticCircuit' = hspec $ do it "has one" $ correctHom0 @a one it "inverts nonzero correctly" $ correctHom1 @a finv it "inverts zero correctly" $ correctHom0 @a (finv zero) - it "checks isZero(nonzero)" $ \(x :: a) -> - let Bool (r :: ArithmeticCircuit a Par1) = isZero $ FieldElement (embed x) - in checkClosedCircuit r .&&. exec1 r === bool zero one (x Haskell.== zero) - it "checks isZero(0)" $ - let Bool (r :: ArithmeticCircuit a Par1) = isZero (zero :: FieldElement (ArithmeticCircuit a)) - in withMaxSuccess 1 $ checkClosedCircuit r .&&. exec1 r === one + -- it "checks isZero(nonzero)" $ \(x :: a) -> + -- let Bool (r :: ArithmeticCircuit a U1 Par1) = isZero $ FieldElement (embed x) + -- in checkClosedCircuit r .&&. exec1 r === bool zero one (x Haskell.== zero) + -- it "checks isZero(0)" $ + -- let Bool (r :: ArithmeticCircuit a U1 Par1) = isZero (zero :: FieldElement (ArithmeticCircuit a U1)) + -- in withMaxSuccess 1 $ checkClosedCircuit r .&&. exec1 r === one it "computes binary expansion" $ \(x :: a) -> let rs = binaryExpansion $ FieldElement (embed x) in checkClosedCircuit rs .&&. V.fromVector (exec rs) === padBits (numberOfBits @a) (binaryExpansion x) it "internalizes equality" $ \(x :: a) (y :: a) -> - let Bool (r :: ArithmeticCircuit a Par1) = embed x == embed y + let Bool (r :: ArithmeticCircuit a U1 Par1) = (embed x :: ArithmeticCircuit a U1 Par1) == embed y in checkClosedCircuit r .&&. exec1 r === bool zero one (x Haskell.== y) it "internal equality is reflexive" $ \(x :: a) -> - let Bool (r :: ArithmeticCircuit a Par1) = embed x == embed x + let Bool (r :: ArithmeticCircuit a U1 Par1) = (embed x :: ArithmeticCircuit a U1 Par1) == embed x in checkClosedCircuit r .&&. exec1 r === one specArithmeticCircuit :: IO () diff --git a/tests/Tests/Arithmetization.hs b/tests/Tests/Arithmetization.hs index 69c1ffeaa..d7cb37773 100644 --- a/tests/Tests/Arithmetization.hs +++ b/tests/Tests/Arithmetization.hs @@ -3,6 +3,7 @@ module Tests.Arithmetization (specArithmetization) where +import Data.Functor.Rep (Representable (..)) import GHC.Generics (Par1) import Prelude import Test.Hspec @@ -12,25 +13,26 @@ import Tests.Arithmetization.Test2 (specArithmetiza import Tests.Arithmetization.Test3 (specArithmetization3) import Tests.Arithmetization.Test4 (specArithmetization4) -import ZkFold.Base.Algebra.Basic.Class (FromConstant, MultiplicativeMonoid) +import ZkFold.Base.Algebra.Basic.Class (FromConstant, Scale) import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map (ArithmeticCircuitTest (..)) import ZkFold.Symbolic.MonadCircuit (Arithmetic) -propCircuitInvariance :: (MultiplicativeMonoid a, Eq a) => ArithmeticCircuitTest a Par1 -> Bool +propCircuitInvariance :: (Arithmetic a, Scale a a, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuitTest a i Par1 -> Bool propCircuitInvariance act@(ArithmeticCircuitTest ac wi) = let ArithmeticCircuitTest ac' wi' = mapVarArithmeticCircuit act v = ac `eval` wi v' = ac' `eval` wi' in v == v' -specArithmetization' :: forall a . (FromConstant a a, Arithmetic a, Arbitrary a, Show a, Show (ArithmeticCircuitTest a Par1)) => IO () +specArithmetization' :: forall a i . (FromConstant a a, Scale a a, Arithmetic a, Arbitrary a, Show a, Show (ArithmeticCircuitTest a i Par1), Arbitrary (Rep i), Ord (Rep i), Representable i, Traversable i) => IO () specArithmetization' = hspec $ do describe "Arithmetization specification" $ do describe "Variable mapping" $ do - it "does not change the circuit" $ property $ propCircuitInvariance @a + it "does not change the circuit" $ property $ propCircuitInvariance @a @i specArithmetization1 @a specArithmetization2 specArithmetization3 @@ -38,4 +40,4 @@ specArithmetization' = hspec $ do specArithmetization :: IO () specArithmetization = do - specArithmetization' @(Zp BLS12_381_Scalar) + specArithmetization' @(Zp BLS12_381_Scalar) @(Vector 2) diff --git a/tests/Tests/Arithmetization/Test1.hs b/tests/Tests/Arithmetization/Test1.hs index 812f245cc..11ad65cc1 100644 --- a/tests/Tests/Arithmetization/Test1.hs +++ b/tests/Tests/Arithmetization/Test1.hs @@ -11,6 +11,7 @@ import Test.Hspec import Test.QuickCheck import ZkFold.Base.Algebra.Basic.Class +import ZkFold.Base.Data.Vector (Vector, unsafeToVector) import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool (..)) @@ -29,13 +30,13 @@ testFunc x y = g3 = c 2 // x in (g3 == y :: Bool c) ? g1 $ g2 -testResult :: forall a . (FromConstant a a, Arithmetic a) => ArithmeticCircuit a Par1 -> a -> a -> Haskell.Bool -testResult r x y = fromConstant (unPar1 $ acValue $ applyArgs r [x, y]) Haskell.== +testResult :: forall a . (FromConstant a a, Arithmetic a) => ArithmeticCircuit a (Vector 2) Par1 -> a -> a -> Haskell.Bool +testResult r x y = fromConstant (unPar1 $ eval r (unsafeToVector [x, y])) Haskell.== testFunc @(Interpreter a) (fromConstant x) (fromConstant y) specArithmetization1 :: forall a . (FromConstant a a, Arithmetic a, Arbitrary a, Show a) => Spec specArithmetization1 = do describe "Arithmetization test 1" $ do it "should pass" $ do - let ac = compile @a (testFunc @(ArithmeticCircuit a)) :: ArithmeticCircuit a Par1 + let ac = compile @2 @a (testFunc @(ArithmeticCircuit a (Vector 2))) :: ArithmeticCircuit a (Vector 2) Par1 property $ \x y -> testResult ac x y diff --git a/tests/Tests/Arithmetization/Test2.hs b/tests/Tests/Arithmetization/Test2.hs index e83d5c711..3d0242e00 100644 --- a/tests/Tests/Arithmetization/Test2.hs +++ b/tests/Tests/Arithmetization/Test2.hs @@ -17,6 +17,7 @@ import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool (..), BoolType (..)) import ZkFold.Symbolic.Data.Eq (Eq (..)) import ZkFold.Symbolic.Data.FieldElement (FieldElement) +import ZkFold.Base.Data.Vector (Vector, unsafeToVector) import ZkFold.Symbolic.MonadCircuit (Arithmetic) -- A true statement. @@ -25,8 +26,8 @@ tautology x y = (x /= y) || (x == y) testTautology :: forall a . Arithmetic a => a -> a -> Haskell.Bool testTautology x y = - let Bool ac = compile @a (tautology @(ArithmeticCircuit a)) - b = unPar1 $ acValue (applyArgs ac [x, y]) + let Bool (ac :: ArithmeticCircuit a (Vector 2) Par1) = compile @2 @a (tautology @(ArithmeticCircuit a (Vector 2))) + b = unPar1 (eval ac (unsafeToVector [x, y])) in b Haskell.== one specArithmetization2 :: Spec diff --git a/tests/Tests/Arithmetization/Test3.hs b/tests/Tests/Arithmetization/Test3.hs index e1726dc6e..ac5da8775 100644 --- a/tests/Tests/Arithmetization/Test3.hs +++ b/tests/Tests/Arithmetization/Test3.hs @@ -9,6 +9,7 @@ import Test.Hspec import ZkFold.Base.Algebra.Basic.Class (fromConstant) import ZkFold.Base.Algebra.Basic.Field (Zp) +import ZkFold.Base.Data.Vector (Vector, unsafeToVector) import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool (..)) @@ -16,7 +17,7 @@ import ZkFold.Symbolic.Data.FieldElement (FieldElement) import ZkFold.Symbolic.Data.Ord ((<=)) import ZkFold.Symbolic.Interpreter (Interpreter (Interpreter)) -type R = ArithmeticCircuit (Zp 97) +type R = ArithmeticCircuit (Zp 97) (Vector 2) -- A comparison test testFunc :: Symbolic c => FieldElement c -> FieldElement c -> Bool c @@ -26,5 +27,5 @@ specArithmetization3 :: Spec specArithmetization3 = do describe "Arithmetization test 3" $ do it "should pass" $ do - let Bool r = compile @(Zp 97) (testFunc @R) :: Bool R - Bool (Interpreter $ acValue (applyArgs r [3, 5])) `shouldBe` testFunc (fromConstant (3 :: Natural)) (fromConstant (5 :: Natural)) + let Bool r = compile @2 @(Zp 97) (testFunc @R) :: Bool R + Bool (Interpreter (eval r (unsafeToVector [3, 5]))) `shouldBe` testFunc (fromConstant (3 :: Natural)) (fromConstant (5 :: Natural)) diff --git a/tests/Tests/Arithmetization/Test4.hs b/tests/Tests/Arithmetization/Test4.hs index bce0f8193..6a21b31b7 100644 --- a/tests/Tests/Arithmetization/Test4.hs +++ b/tests/Tests/Arithmetization/Test4.hs @@ -2,7 +2,6 @@ module Tests.Arithmetization.Test4 (specArithmetization4) where -import Data.Map (fromList) import GHC.Generics (Par1 (unPar1)) import GHC.Num (Natural) import Prelude hiding (Bool, Eq (..), Num (..), Ord (..), (&&)) @@ -20,8 +19,9 @@ import ZkFold.Base.Protocol.ARK.Plonk (Plonk (..), PlonkI import ZkFold.Base.Protocol.ARK.Plonk.Internal (getParams) import ZkFold.Base.Protocol.NonInteractiveProof (NonInteractiveProof (..)) import ZkFold.Symbolic.Class -import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), acValue, applyArgs, compile, - compileForceOne) +import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), compile, + compileForceOne, eval) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Var (..)) import ZkFold.Symbolic.Data.Bool (Bool (..)) import ZkFold.Symbolic.Data.Eq (Eq (..)) import ZkFold.Symbolic.Data.FieldElement (FieldElement) @@ -36,23 +36,25 @@ lockedByTxId targetValue inputValue = inputValue == fromConstant targetValue testSameValue :: F -> Haskell.Bool testSameValue targetValue = - let Bool ac = compile @F (lockedByTxId @F @(ArithmeticCircuit F) targetValue) :: Bool (ArithmeticCircuit F) - b = unPar1 $ acValue (applyArgs ac [targetValue]) + let Bool ac = compile @1 @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) + b = unPar1 (eval ac (V.singleton targetValue)) in b Haskell.== one testDifferentValue :: F -> F -> Haskell.Bool testDifferentValue targetValue otherValue = - let Bool ac = compile @F (lockedByTxId @F @(ArithmeticCircuit F) targetValue) :: Bool (ArithmeticCircuit F) - b = unPar1 $ acValue (applyArgs ac [otherValue]) + let Bool ac = compile @1 @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) + b = unPar1 (eval ac (V.singleton otherValue)) in b Haskell.== zero testOnlyOutputZKP :: F -> PlonkProverSecret C -> F -> Haskell.Bool testOnlyOutputZKP x ps targetValue = - let Bool ac = compile @F (lockedByTxId @F @(ArithmeticCircuit F) targetValue) :: Bool (ArithmeticCircuit F) + let Bool ac = compile @1 @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) (omega, k1, k2) = getParams 32 - witnessInputs = fromList [(1, targetValue), (unPar1 $ acOutput ac, 1)] - indexOutputBool = V.singleton $ unPar1 $ acOutput ac + witnessInputs = V.singleton targetValue + varF (NewVar ix) = ix + 1 + varF (InVar ix) = fromIntegral ix + indexOutputBool = V.singleton $ varF $ unPar1 $ acOutput ac plonk = Plonk @32 omega k1 k2 indexOutputBool ac x setupP = setupProve @(PlonkBS N) plonk setupV = setupVerify @(PlonkBS N) plonk @@ -66,10 +68,10 @@ testOnlyOutputZKP x ps targetValue = testSafeOneInputZKP :: F -> PlonkProverSecret C -> F -> Haskell.Bool testSafeOneInputZKP x ps targetValue = - let Bool ac = compileForceOne @F (lockedByTxId @F @(ArithmeticCircuit F) targetValue) :: Bool (ArithmeticCircuit F) + let Bool ac = compileForceOne @1 @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) (omega, k1, k2) = getParams 32 - witnessInputs = fromList [(1, targetValue), (unPar1 $ acOutput ac, 1)] + witnessInputs = V.singleton targetValue indexTargetValue = V.singleton (1 :: Natural) plonk = Plonk @32 omega k1 k2 indexTargetValue ac x setupP = setupProve @(PlonkBS N) plonk @@ -83,10 +85,10 @@ testSafeOneInputZKP x ps targetValue = testAttackSafeOneInputZKP :: F -> PlonkProverSecret C -> F -> Haskell.Bool testAttackSafeOneInputZKP x ps targetValue = - let Bool ac = compileForceOne @F (lockedByTxId @F @(ArithmeticCircuit F) targetValue) :: Bool (ArithmeticCircuit F) + let Bool ac = compileForceOne @1 @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) (omega, k1, k2) = getParams 32 - witnessInputs = fromList [(1, targetValue + 1), (unPar1 $ acOutput ac, 0)] + witnessInputs = V.singleton (targetValue + 1) indexTargetValue = V.singleton (1 :: Natural) plonk = Plonk @32 omega k1 k2 indexTargetValue ac x setupP = setupProve @(PlonkBS N) plonk diff --git a/tests/Tests/Blake2b.hs b/tests/Tests/Blake2b.hs index 435b83d6a..1e7ac76f4 100644 --- a/tests/Tests/Blake2b.hs +++ b/tests/Tests/Blake2b.hs @@ -41,8 +41,8 @@ blake2bSimple = blake2bAC :: Spec blake2bAC = - let bs = compile @(Zp BLS12_381_Scalar) (blake2b_512 @8 @(ArithmeticCircuit (Zp BLS12_381_Scalar))) :: ByteString 512 (ArithmeticCircuit (Zp BLS12_381_Scalar)) - ac = pieces @(ArithmeticCircuit (Zp BLS12_381_Scalar)) bs () + let bs = compile @64 @(Zp BLS12_381_Scalar) (blake2b_512 @8 @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 64))) :: ByteString 512 (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 64)) + ac = pieces @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 64)) bs () in it "simple test with cardano-crypto " $ acSizeN ac == 564239 specBlake2b :: IO () diff --git a/tests/Tests/ByteString.hs b/tests/Tests/ByteString.hs index aa49e0619..60c9a372c 100644 --- a/tests/Tests/ByteString.hs +++ b/tests/Tests/ByteString.hs @@ -9,6 +9,7 @@ import Control.Monad (return) import Data.Function (($)) import Data.Functor ((<$>)) import Data.List ((++)) +import GHC.Generics (U1) import Prelude (show, type (~), (<>)) import qualified Prelude as Haskell import System.IO (IO) @@ -31,22 +32,22 @@ import ZkFold.Symbolic.Interpreter (Interpreter (Inter toss :: Natural -> Gen Natural toss x = chooseNatural (0, x) -eval :: forall a n . ByteString n (ArithmeticCircuit a) -> ByteString n (Interpreter a) +eval :: forall a n . ByteString n (ArithmeticCircuit a U1) -> ByteString n (Interpreter a) eval (ByteString bits) = ByteString $ Interpreter (exec bits) type Binary a = a -> a -> a type UBinary n b = Binary (ByteString n b) -isHom :: (KnownNat n, PrimeField (Zp p)) => UBinary n (Interpreter (Zp p)) -> UBinary n (ArithmeticCircuit (Zp p)) -> Natural -> Natural -> Property +isHom :: (KnownNat n, PrimeField (Zp p)) => UBinary n (Interpreter (Zp p)) -> UBinary n (ArithmeticCircuit (Zp p) U1) -> Natural -> Natural -> Property isHom f g x y = eval (fromConstant x `g` fromConstant y) === fromConstant x `f` fromConstant y isRightNeutral :: (KnownNat n, PrimeField (Zp p)) => UBinary n (Interpreter (Zp p)) - -> UBinary n (ArithmeticCircuit (Zp p)) + -> UBinary n (ArithmeticCircuit (Zp p) U1) -> ByteString n (Interpreter (Zp p)) - -> ByteString n (ArithmeticCircuit (Zp p)) + -> ByteString n (ArithmeticCircuit (Zp p) U1) -> Natural -> Property isRightNeutral f g n1 n2 x = eval (fromConstant x `g` n2) === fromConstant x `f` n1 @@ -54,9 +55,9 @@ isRightNeutral f g n1 n2 x = eval (fromConstant x `g` n2) === fromConstant x `f` isLeftNeutral :: (KnownNat n, PrimeField (Zp p)) => UBinary n (Interpreter (Zp p)) - -> UBinary n (ArithmeticCircuit (Zp p)) + -> UBinary n (ArithmeticCircuit (Zp p) U1) -> ByteString n (Interpreter (Zp p)) - -> ByteString n (ArithmeticCircuit (Zp p)) + -> ByteString n (ArithmeticCircuit (Zp p) U1) -> Natural -> Property isLeftNeutral f g n1 n2 x = eval (n2 `g` fromConstant x) === n1 `f` fromConstant x @@ -66,14 +67,14 @@ testWords . KnownNat n => PrimeField (Zp p) => KnownNat wordSize - => ToWords (ByteString n (ArithmeticCircuit (Zp p))) (ByteString wordSize (ArithmeticCircuit (Zp p))) + => ToWords (ByteString n (ArithmeticCircuit (Zp p) U1)) (ByteString wordSize (ArithmeticCircuit (Zp p) U1)) => ToWords (ByteString n (Interpreter (Zp p))) (ByteString wordSize (Interpreter (Zp p))) => Spec testWords = it ("divides a bytestring of length " <> show (value @n) <> " into words of length " <> show (value @wordSize)) $ do x <- toss m - let arithBS = fromConstant x :: ByteString n (ArithmeticCircuit (Zp p)) + let arithBS = fromConstant x :: ByteString n (ArithmeticCircuit (Zp p) U1) zpBS = fromConstant x :: ByteString n (Interpreter (Zp p)) - return (Haskell.fmap eval (toWords arithBS :: [ByteString wordSize (ArithmeticCircuit (Zp p))]) === toWords zpBS) + return (Haskell.fmap eval (toWords arithBS :: [ByteString wordSize (ArithmeticCircuit (Zp p) U1)]) === toWords zpBS) where n = Haskell.toInteger $ value @n m = 2 Haskell.^ n -! 1 @@ -83,14 +84,14 @@ testTruncate . KnownNat n => PrimeField (Zp p) => KnownNat m - => Truncate (ByteString n (ArithmeticCircuit (Zp p))) (ByteString m (ArithmeticCircuit (Zp p))) + => Truncate (ByteString n (ArithmeticCircuit (Zp p) U1)) (ByteString m (ArithmeticCircuit (Zp p) U1)) => Truncate (ByteString n (Interpreter (Zp p))) (ByteString m (Interpreter (Zp p))) => Spec testTruncate = it ("truncates a bytestring of length " <> show (value @n) <> " to length " <> show (value @m)) $ do x <- toss m - let arithBS = fromConstant x :: ByteString n (ArithmeticCircuit (Zp p)) + let arithBS = fromConstant x :: ByteString n (ArithmeticCircuit (Zp p) U1) zpBS = fromConstant x :: ByteString n (Interpreter (Zp p)) - return (eval (truncate arithBS :: ByteString m (ArithmeticCircuit (Zp p))) === truncate zpBS) + return (eval (truncate arithBS :: ByteString m (ArithmeticCircuit (Zp p) U1)) === truncate zpBS) where n = Haskell.toInteger $ value @n m = 2 Haskell.^ n -! 1 @@ -100,14 +101,14 @@ testGrow . KnownNat n => PrimeField (Zp p) => KnownNat m - => Extend (ByteString n (ArithmeticCircuit (Zp p))) (ByteString m (ArithmeticCircuit (Zp p))) + => Extend (ByteString n (ArithmeticCircuit (Zp p) U1)) (ByteString m (ArithmeticCircuit (Zp p) U1)) => Extend (ByteString n (Interpreter (Zp p))) (ByteString m (Interpreter (Zp p))) => Spec testGrow = it ("extends a bytestring of length " <> show (value @n) <> " to length " <> show (value @m)) $ do x <- toss m - let arithBS = fromConstant x :: ByteString n (ArithmeticCircuit (Zp p)) + let arithBS = fromConstant x :: ByteString n (ArithmeticCircuit (Zp p) U1) zpBS = fromConstant x :: ByteString n (Interpreter (Zp p)) - return (eval (extend arithBS :: ByteString m (ArithmeticCircuit (Zp p))) === extend zpBS) + return (eval (extend arithBS :: ByteString m (ArithmeticCircuit (Zp p) U1)) === extend zpBS) where n = Haskell.toInteger $ value @n m = 2 Haskell.^ n -! 1 @@ -162,10 +163,10 @@ specByteString' = hspec $ do x <- toss m y <- toss m - let acX :: ByteString n (ArithmeticCircuit (Zp p)) = fromConstant x - acY :: ByteString n (ArithmeticCircuit (Zp p)) = fromConstant y + let acX :: ByteString n (ArithmeticCircuit (Zp p) U1) = fromConstant x + acY :: ByteString n (ArithmeticCircuit (Zp p) U1) = fromConstant y - acSum :: ByteString n (ArithmeticCircuit (Zp p)) = from $ from acX + (from acY :: UInt n Auto (ArithmeticCircuit (Zp p))) + acSum :: ByteString n (ArithmeticCircuit (Zp p) U1) = from $ from acX + (from acY :: UInt n Auto (ArithmeticCircuit (Zp p) U1)) zpSum :: ByteString n (Interpreter (Zp p)) = fromConstant $ x + y @@ -201,9 +202,9 @@ specByteString' = hspec $ do x <- toss m y <- toss m z <- toss m - let acs = fromConstant @Natural @(ByteString n (ArithmeticCircuit (Zp p))) <$> [x, y, z] + let acs = fromConstant @Natural @(ByteString n (ArithmeticCircuit (Zp p) U1)) <$> [x, y, z] zps = fromConstant @Natural @(ByteString n (Interpreter (Zp p))) <$> [x, y, z] - let ac = concat acs :: ByteString (3 * n) (ArithmeticCircuit (Zp p)) + let ac = concat acs :: ByteString (3 * n) (ArithmeticCircuit (Zp p) U1) zp = concat zps return $ eval @(Zp p) @(3 * n) ac === zp testTruncate @n @1 @p diff --git a/tests/Tests/FFA.hs b/tests/Tests/FFA.hs index 46ccc6553..9bf41947b 100644 --- a/tests/Tests/FFA.hs +++ b/tests/Tests/FFA.hs @@ -6,6 +6,7 @@ module Tests.FFA (specFFA) where import Data.Function (($)) import Data.List ((++)) +import GHC.Generics (U1) import System.IO (IO) import Test.Hspec (describe, hspec) import Test.QuickCheck (Property, withMaxSuccess, (===)) @@ -47,7 +48,7 @@ specFFA' = hspec $ do execAcFFA @p @q (negate $ fromConstant x) === execZpFFA @p @q (negate $ fromConstant x) it "multiplies correctly" $ withMaxSuccess 1 $ isHom @p @q (*) (*) -execAcFFA :: (PrimeField (Zp p), PrimeField (Zp q)) => FFA q (ArithmeticCircuit (Zp p)) -> Zp q +execAcFFA :: (PrimeField (Zp p), PrimeField (Zp q)) => FFA q (ArithmeticCircuit (Zp p) U1) -> Zp q execAcFFA (FFA v) = execZpFFA $ FFA $ Interpreter (exec v) execZpFFA :: (PrimeField (Zp p), PrimeField (Zp q)) => FFA q (Interpreter (Zp p)) -> Zp q @@ -56,5 +57,5 @@ execZpFFA = toConstant type Binary a = a -> a -> a type Predicate a = a -> a -> Property -isHom :: (PrimeField (Zp p), PrimeField (Zp q)) => Binary (FFA q (Interpreter (Zp p))) -> Binary (FFA q (ArithmeticCircuit (Zp p))) -> Predicate (Zp q) +isHom :: (PrimeField (Zp p), PrimeField (Zp q)) => Binary (FFA q (Interpreter (Zp p))) -> Binary (FFA q (ArithmeticCircuit (Zp p) U1)) -> Predicate (Zp q) isHom f g x y = execAcFFA (fromConstant x `g` fromConstant y) === execZpFFA (fromConstant x `f` fromConstant y) diff --git a/tests/Tests/NonInteractiveProof/Internal.hs b/tests/Tests/NonInteractiveProof/Internal.hs index 44e42c886..2979cf79a 100644 --- a/tests/Tests/NonInteractiveProof/Internal.hs +++ b/tests/Tests/NonInteractiveProof/Internal.hs @@ -18,8 +18,6 @@ import ZkFold.Base.Protocol.ARK.Plonk (Plonk (Plonk), getParams) import ZkFold.Base.Protocol.Commitment.KZG (KZG) import ZkFold.Base.Protocol.NonInteractiveProof (NonInteractiveProof (..)) -import ZkFold.Prelude (length) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit (acInput, witnessGenerator) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map (ArithmeticCircuitTest (..)) data NonInteractiveProofTestData a = TestData a (Witness a) @@ -36,10 +34,10 @@ instance (KZG c1 c2 d ~ kzg, NonInteractiveProof kzg, Arbitrary kzg, Arbitrary ( instance forall n . (KnownNat n) => Arbitrary (NonInteractiveProofTestData (PlonkBS n)) where arbitrary = do - ArithmeticCircuitTest ac wi <- arbitrary :: Gen (ArithmeticCircuitTest (ScalarField BLS12_381_G1) Par1) - let inputLen = length . acInput $ ac + ArithmeticCircuitTest ac wi <- arbitrary :: Gen (ArithmeticCircuitTest (ScalarField BLS12_381_G1) (Vector n) Par1) + let inputLen = value @n vecPubInp <- genSubset (value @n) inputLen let (omega, k1, k2) = getParams $ value @PlonkSizeBS pl <- Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary secret <- arbitrary - return $ TestData pl (PlonkWitnessInput $ witnessGenerator ac wi, secret) + return $ TestData pl (PlonkWitnessInput wi, secret) diff --git a/tests/Tests/NonInteractiveProof/Plonk.hs b/tests/Tests/NonInteractiveProof/Plonk.hs index b0435a90e..7da087772 100644 --- a/tests/Tests/NonInteractiveProof/Plonk.hs +++ b/tests/Tests/NonInteractiveProof/Plonk.hs @@ -4,8 +4,8 @@ module Tests.NonInteractiveProof.Plonk (PlonkBS, specPlonk) where import Data.ByteString (ByteString) +import Data.Functor.Rep (Representable (..)) import Data.List (transpose) -import Data.Map ((!)) import Data.Maybe (fromJust) import qualified Data.Vector as V import GHC.IsList (IsList (..)) @@ -43,7 +43,7 @@ propPlonkConstraintSatisfaction (TestData (Plonk _ _ _ iPub ac _) w) = (PlonkWitnessInput wInput, _) = w (w1', w2', w3') = wmap pr wInput - wPub = toPolyVec @_ @PlonkPolyLengthBS $ fmap (negate . (wInput !)) $ fromList @(V.Vector Natural) $ fromVector iPub + wPub = toPolyVec @_ @PlonkPolyLengthBS $ fmap (negate . index wInput . fromIntegral) $ fromList @(V.Vector Natural) $ fromVector iPub qm' = V.toList $ fromPolyVec $ qM pr ql' = V.toList $ fromPolyVec $ qL pr @@ -67,7 +67,7 @@ propPlonkPolyIdentity (TestData plonk w) = PlonkProverSecret b1 b2 b3 b4 b5 b6 _ _ _ _ _ = ps (w1, w2, w3) = wmap wInput - wPub = fmap (negate . (wInput !)) iPub' + wPub = fmap (negate . index wInput . fromIntegral) iPub' pubPoly = polyVecInLagrangeBasis @(ScalarField BLS12_381_G1) @PlonkPolyLengthBS @PlonkPolyExtendedLengthBS omega' $ toPolyVec @(ScalarField BLS12_381_G1) @PlonkPolyLengthBS wPub diff --git a/tests/Tests/SHA2.hs b/tests/Tests/SHA2.hs index 6e12c2d62..af0712429 100644 --- a/tests/Tests/SHA2.hs +++ b/tests/Tests/SHA2.hs @@ -10,6 +10,7 @@ import Data.Functor ((<$>)) import Data.List (isPrefixOf, isSuffixOf, take, (++)) import Data.List.Split (splitOn) import Data.Proxy (Proxy (..)) +import GHC.Generics (U1) import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) import Prelude (String, fmap, otherwise, pure, read, (<>), (==)) import qualified Prelude as Haskell @@ -133,13 +134,13 @@ specSHA2Natural = do toss :: Natural -> Gen Natural toss x = chooseNatural (0, x) -eval :: forall a n . ByteString n (ArithmeticCircuit a) -> Vector n a +eval :: forall a n . ByteString n (ArithmeticCircuit a U1) -> Vector n a eval (ByteString bits) = exec bits specSHA2bs :: forall (n :: Natural) (algorithm :: Symbol) . KnownSymbol algorithm - => SHA2 algorithm (ArithmeticCircuit (Zp BLS12_381_Scalar)) n + => SHA2 algorithm (ArithmeticCircuit (Zp BLS12_381_Scalar) U1) n => SHA2N algorithm (Interpreter (Zp BLS12_381_Scalar)) => Spec specSHA2bs = do @@ -147,7 +148,7 @@ specSHA2bs = do m = 2 ^ n -! 1 it ("calculates " <> symbolVal (Proxy @algorithm) <> " of a " <> Haskell.show n <> "-bit bytestring") $ withMaxSuccess 2 $ do x <- toss m - let hashAC = sha2 @algorithm @(ArithmeticCircuit (Zp BLS12_381_Scalar)) @n $ fromConstant x + let hashAC = sha2 @algorithm @(ArithmeticCircuit (Zp BLS12_381_Scalar) U1) @n $ fromConstant x ByteString (Interpreter hashZP) = sha2Natural @algorithm @(Interpreter (Zp BLS12_381_Scalar)) n x pure $ eval @(Zp BLS12_381_Scalar) @(ResultSize algorithm) hashAC === hashZP @@ -158,10 +159,10 @@ specSHA2' :: forall (algorithm :: Symbol) . KnownSymbol algorithm => SHA2N algorithm (Interpreter (Zp BLS12_381_Scalar)) - => SHA2 algorithm (ArithmeticCircuit (Zp BLS12_381_Scalar)) 1 - => SHA2 algorithm (ArithmeticCircuit (Zp BLS12_381_Scalar)) 63 - => SHA2 algorithm (ArithmeticCircuit (Zp BLS12_381_Scalar)) 64 - => SHA2 algorithm (ArithmeticCircuit (Zp BLS12_381_Scalar)) 1900 + => SHA2 algorithm (ArithmeticCircuit (Zp BLS12_381_Scalar) U1) 1 + => SHA2 algorithm (ArithmeticCircuit (Zp BLS12_381_Scalar) U1) 63 + => SHA2 algorithm (ArithmeticCircuit (Zp BLS12_381_Scalar) U1) 64 + => SHA2 algorithm (ArithmeticCircuit (Zp BLS12_381_Scalar) U1) 1900 => IO () specSHA2' = hspec $ do specSHA2bs @1 @algorithm diff --git a/tests/Tests/UInt.hs b/tests/Tests/UInt.hs index 4e4b3f16d..ff80727af 100644 --- a/tests/Tests/UInt.hs +++ b/tests/Tests/UInt.hs @@ -11,7 +11,7 @@ import Control.Monad (return, when) import Data.Function (($)) import Data.Functor ((<$>)) import Data.List ((++)) -import GHC.Generics (Par1 (Par1)) +import GHC.Generics (Par1 (Par1), U1) import Prelude (show, type (~)) import qualified Prelude as P import System.IO (IO) @@ -37,13 +37,13 @@ import ZkFold.Symbolic.Interpreter (Interpreter (Inter toss :: Natural -> Gen Natural toss x = chooseNatural (0, x) -evalBool :: forall a . Bool (ArithmeticCircuit a) -> a +evalBool :: forall a . Bool (ArithmeticCircuit a U1) -> a evalBool (Bool ac) = exec1 ac evalBoolVec :: forall a . Bool (Interpreter a) -> a evalBoolVec (Bool (Interpreter (Par1 v))) = v -execAcUint :: forall a n r . UInt n r (ArithmeticCircuit a)-> Vector (NumberOfRegisters a n r) a +execAcUint :: forall a n r . UInt n r (ArithmeticCircuit a U1) -> Vector (NumberOfRegisters a n r) a execAcUint (UInt v) = exec v execZpUint :: forall a n r . UInt n r (Interpreter a) -> Vector (NumberOfRegisters a n r) a @@ -53,7 +53,7 @@ type Binary a = a -> a -> a type UBinary n b r = Binary (UInt n b r) -isHom :: (KnownNat n, PrimeField (Zp p), KnownRegisterSize r) => UBinary n r (Interpreter (Zp p)) -> UBinary n r (ArithmeticCircuit (Zp p)) -> Natural -> Natural -> Property +isHom :: (KnownNat n, PrimeField (Zp p), KnownRegisterSize r) => UBinary n r (Interpreter (Zp p)) -> UBinary n r (ArithmeticCircuit (Zp p) U1) -> Natural -> Natural -> Property isHom f g x y = execAcUint (fromConstant x `g` fromConstant y) === execZpUint (fromConstant x `f` fromConstant y) specUInt' @@ -101,7 +101,7 @@ specUInt' = hspec $ do when (n <= 128) $ it "performs divMod correctly" $ withMaxSuccess 10 $ do num <- toss m d <- toss m - let (acQ, acR) = (fromConstant num :: UInt n rs (ArithmeticCircuit (Zp p))) `divMod` (fromConstant d) + let (acQ, acR) = (fromConstant num :: UInt n rs (ArithmeticCircuit (Zp p) U1)) `divMod` (fromConstant d) let (zpQ, zpR) = (fromConstant num :: UInt n rs (Interpreter (Zp p))) `divMod` (fromConstant d) return $ (execAcUint acQ, execAcUint acR) === (execZpUint zpQ, execZpUint zpR) @@ -136,19 +136,19 @@ specUInt' = hspec $ do it "extends correctly" $ do x <- toss m - let acUint = fromConstant x :: UInt n rs (ArithmeticCircuit (Zp p)) + let acUint = fromConstant x :: UInt n rs (ArithmeticCircuit (Zp p) U1) zpUint = fromConstant x :: UInt (2 * n) rs (Interpreter (Zp p)) - return $ execAcUint @(Zp p) (extend acUint :: UInt (2 * n) rs (ArithmeticCircuit (Zp p))) === execZpUint zpUint + return $ execAcUint @(Zp p) (extend acUint :: UInt (2 * n) rs (ArithmeticCircuit (Zp p) U1)) === execZpUint zpUint it "shrinks correctly" $ do x <- toss (m * m) - let acUint = fromConstant x :: UInt (2 * n) rs (ArithmeticCircuit (Zp p)) + let acUint = fromConstant x :: UInt (2 * n) rs (ArithmeticCircuit (Zp p) U1) zpUint = fromConstant x :: UInt n rs (Interpreter (Zp p)) - return $ execAcUint @(Zp p) (shrink acUint :: UInt n rs (ArithmeticCircuit (Zp p))) === execZpUint zpUint + return $ execAcUint @(Zp p) (shrink acUint :: UInt n rs (ArithmeticCircuit (Zp p) U1)) === execZpUint zpUint it "checks equality" $ do x <- toss m - let acUint = fromConstant x :: UInt n rs (ArithmeticCircuit (Zp p)) + let acUint = fromConstant x :: UInt n rs (ArithmeticCircuit (Zp p) U1) return $ evalBool @(Zp p) (acUint == acUint) === one it "checks inequality" $ do @@ -156,8 +156,8 @@ specUInt' = hspec $ do y' <- toss m let y = if y' P.== x then x + 1 else y' - let acUint1 = fromConstant x :: UInt n rs (ArithmeticCircuit (Zp p)) - acUint2 = fromConstant y :: UInt n rs (ArithmeticCircuit (Zp p)) + let acUint1 = fromConstant x :: UInt n rs (ArithmeticCircuit (Zp p) U1) + acUint2 = fromConstant y :: UInt n rs (ArithmeticCircuit (Zp p) U1) return $ evalBool @(Zp p) (acUint1 == acUint2) === zero @@ -166,8 +166,8 @@ specUInt' = hspec $ do y <- toss m let x' = fromConstant x :: UInt n rs (Interpreter (Zp p)) y' = fromConstant y :: UInt n rs (Interpreter (Zp p)) - x'' = fromConstant x :: UInt n rs (ArithmeticCircuit (Zp p)) - y'' = fromConstant y :: UInt n rs (ArithmeticCircuit (Zp p)) + x'' = fromConstant x :: UInt n rs (ArithmeticCircuit (Zp p) U1) + y'' = fromConstant y :: UInt n rs (ArithmeticCircuit (Zp p) U1) gt' = evalBoolVec $ x' > y' gt'' = evalBool @(Zp p) (x'' > y'') return $ gt' === gt'' diff --git a/zkfold-base.cabal b/zkfold-base.cabal index 1846eec75..072d31119 100644 --- a/zkfold-base.cabal +++ b/zkfold-base.cabal @@ -261,6 +261,7 @@ test-suite zkfold-base-test Tests.Univariate.PolyVec build-depends: base >= 4.9 && < 5, + adjunctions < 4.5, binary < 0.11, bytestring , blake2 , From 14d003026ca282f5fcdefaa4318d3958d10ce6a7 Mon Sep 17 00:00:00 2001 From: echatav Date: Fri, 9 Aug 2024 17:36:55 +0000 Subject: [PATCH 05/48] stylish-haskell auto-commit --- examples/Examples/ByteString.hs | 2 +- examples/Examples/Eq.hs | 2 +- examples/Examples/FFA.hs | 3 +- examples/Examples/Fibonacci.hs | 2 +- examples/Examples/LEQ.hs | 2 +- examples/Examples/MiMCHash.hs | 2 +- examples/Examples/UInt.hs | 2 +- src/ZkFold/Base/Data/Vector.hs | 2 +- src/ZkFold/Base/Protocol/ARK/Plonk.hs | 2 +- .../Base/Protocol/ARK/Plonk/Internal.hs | 2 +- .../Base/Protocol/ARK/Plonk/Relation.hs | 24 +++++----- src/ZkFold/Base/Protocol/ARK/Protostar.hs | 4 +- src/ZkFold/Symbolic/Cardano/Types/Output.hs | 2 +- .../Symbolic/Compiler/ArithmeticCircuit.hs | 14 +++--- .../Compiler/ArithmeticCircuit/Combinators.hs | 3 +- .../Compiler/ArithmeticCircuit/Instance.hs | 3 +- .../Compiler/ArithmeticCircuit/Internal.hs | 14 +++--- .../Compiler/ArithmeticCircuit/Map.hs | 7 +-- src/ZkFold/Symbolic/Data/ByteString.hs | 4 +- tests/Tests/ArithmeticCircuit.hs | 1 - tests/Tests/Arithmetization.hs | 2 +- tests/Tests/Arithmetization/Test2.hs | 2 +- tests/Tests/Arithmetization/Test4.hs | 44 +++++++++---------- 23 files changed, 72 insertions(+), 73 deletions(-) diff --git a/examples/Examples/ByteString.hs b/examples/Examples/ByteString.hs index 06deb6c20..630d81091 100644 --- a/examples/Examples/ByteString.hs +++ b/examples/Examples/ByteString.hs @@ -18,7 +18,7 @@ import Text.Show (show) import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) -import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler (ArithmeticCircuit, compileIO) import ZkFold.Symbolic.Data.Bool import ZkFold.Symbolic.Data.ByteString diff --git a/examples/Examples/Eq.hs b/examples/Examples/Eq.hs index 7873186da..848a434da 100644 --- a/examples/Examples/Eq.hs +++ b/examples/Examples/Eq.hs @@ -7,7 +7,7 @@ import Prelude hiding (Bool, Eq (. import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) -import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool (..)) diff --git a/examples/Examples/FFA.hs b/examples/Examples/FFA.hs index 7145e1a7a..e475e9d19 100644 --- a/examples/Examples/FFA.hs +++ b/examples/Examples/FFA.hs @@ -1,4 +1,5 @@ -{-# LANGUAGE AllowAmbiguousTypes, TypeOperators #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE TypeOperators #-} module Examples.FFA (examplesFFA) where diff --git a/examples/Examples/Fibonacci.hs b/examples/Examples/Fibonacci.hs index 662e8ca5e..7cbc96e60 100644 --- a/examples/Examples/Fibonacci.hs +++ b/examples/Examples/Fibonacci.hs @@ -7,7 +7,7 @@ import Prelude hiding (Bool, Eq (. import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) -import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool (..)) diff --git a/examples/Examples/LEQ.hs b/examples/Examples/LEQ.hs index c91852404..3535a532d 100644 --- a/examples/Examples/LEQ.hs +++ b/examples/Examples/LEQ.hs @@ -7,7 +7,7 @@ import Prelude hiding (Bool, Eq (. import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) -import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool) diff --git a/examples/Examples/MiMCHash.hs b/examples/Examples/MiMCHash.hs index 894ce6f8f..3f8b5ceb1 100644 --- a/examples/Examples/MiMCHash.hs +++ b/examples/Examples/MiMCHash.hs @@ -8,7 +8,7 @@ import Prelude hiding (Eq (..), import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) -import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Algorithms.Hash.MiMC (mimcHash2) import ZkFold.Symbolic.Algorithms.Hash.MiMC.Constants (mimcConstants) import ZkFold.Symbolic.Compiler diff --git a/examples/Examples/UInt.hs b/examples/Examples/UInt.hs index 7a67e96c0..679f060ce 100644 --- a/examples/Examples/UInt.hs +++ b/examples/Examples/UInt.hs @@ -22,7 +22,7 @@ import Text.Show (show) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) -import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler (ArithmeticCircuit, compileIO) import ZkFold.Symbolic.Data.Combinators import ZkFold.Symbolic.Data.UInt diff --git a/src/ZkFold/Base/Data/Vector.hs b/src/ZkFold/Base/Data/Vector.hs index a7c17fccd..030b9a84f 100644 --- a/src/ZkFold/Base/Data/Vector.hs +++ b/src/ZkFold/Base/Data/Vector.hs @@ -10,7 +10,7 @@ import Control.Parallel.Strategies (parMap, rpar) import Data.Aeson (ToJSON (..)) import Data.Bifunctor (first) import Data.Distributive (Distributive (..)) -import Data.Functor.Rep (Representable (..), distributeRep, collectRep) +import Data.Functor.Rep (Representable (..), collectRep, distributeRep) import qualified Data.List as List import Data.List.Split (chunksOf) import Data.These (These (..)) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index 3223dbe4a..3e064bd26 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -11,8 +11,8 @@ module ZkFold.Base.Protocol.ARK.Plonk ( plonkVerifierInput ) where -import Data.Maybe (fromJust) import Data.Functor.Rep (Representable (..)) +import Data.Maybe (fromJust) import qualified Data.Vector as V import GHC.Generics (Par1) import GHC.IsList (IsList (..)) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs index b65dc9f4d..5992b0b13 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs @@ -16,8 +16,8 @@ import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), Point) import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr) -import ZkFold.Prelude (take) import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Prelude (take) log2 :: (Integral a, Integral b) => a -> b log2 = ceiling @Double . logBase 2 . fromIntegral diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs index a13649246..966d0cbd6 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs @@ -3,20 +3,20 @@ module ZkFold.Base.Protocol.ARK.Plonk.Relation where -import Data.Map (elems, (!)) -import GHC.Generics (Par1) -import GHC.IsList (IsList (..)) -import Prelude hiding (Num (..), drop, length, replicate, sum, take, - (!!), (/), (^)) +import Data.Map (elems, (!)) +import GHC.Generics (Par1) +import GHC.IsList (IsList (..)) +import Prelude hiding (Num (..), drop, length, replicate, sum, + take, (!!), (/), (^)) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number -import ZkFold.Base.Algebra.Basic.Permutations (Permutation, fromCycles, mkIndexPartition) -import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial, var) -import ZkFold.Base.Algebra.Polynomials.Univariate (PolyVec, toPolyVec) -import ZkFold.Base.Data.Vector (Vector, fromVector) -import ZkFold.Base.Protocol.ARK.Plonk.Constraint (PlonkConstraint (..), toPlonkConstraint) -import ZkFold.Prelude (replicate) +import ZkFold.Base.Algebra.Basic.Permutations (Permutation, fromCycles, mkIndexPartition) +import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial, var) +import ZkFold.Base.Algebra.Polynomials.Univariate (PolyVec, toPolyVec) +import ZkFold.Base.Data.Vector (Vector, fromVector) +import ZkFold.Base.Protocol.ARK.Plonk.Constraint (PlonkConstraint (..), toPlonkConstraint) +import ZkFold.Prelude (replicate) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal @@ -44,7 +44,7 @@ toPlonkRelation xPub ac0 = let ac = desugarRanges ac0 varF (NewVar ix) = if ix == 0 then one else var (ix + value @l) - varF (InVar ix) = var (fromIntegral ix) + varF (InVar ix) = var (fromIntegral ix) evalX0 = evalPolynomial evalMonomial varF pubInputConstraints = map var (fromVector xPub) diff --git a/src/ZkFold/Base/Protocol/ARK/Protostar.hs b/src/ZkFold/Base/Protocol/ARK/Protostar.hs index 458fb81df..f0741ea56 100644 --- a/src/ZkFold/Base/Protocol/ARK/Protostar.hs +++ b/src/ZkFold/Base/Protocol/ARK/Protostar.hs @@ -11,10 +11,10 @@ import Prelude (($), (==)) import qualified Prelude as P import ZkFold.Base.Algebra.Basic.Number +import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial, var) import ZkFold.Base.Data.Vector (Vector) import ZkFold.Base.Protocol.ARK.Protostar.SpecialSound import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal -import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial, var) {-- @@ -66,7 +66,7 @@ instance (Arithmetic a, KnownNat n) => SpecialSoundProtocol a (RecursiveCircuit algebraicMap rc _ _ _ = let varF (NewVar ix) = var (ix P.+ value @n) - varF (InVar ix) = var (P.fromIntegral ix) + varF (InVar ix) = var (P.fromIntegral ix) in [ evalPolynomial evalMonomial varF poly | poly <- M.elems $ acSystem (circuit rc) diff --git a/src/ZkFold/Symbolic/Cardano/Types/Output.hs b/src/ZkFold/Symbolic/Cardano/Types/Output.hs index 1075092ef..4fbf3fee3 100644 --- a/src/ZkFold/Symbolic/Cardano/Types/Output.hs +++ b/src/ZkFold/Symbolic/Cardano/Types/Output.hs @@ -11,7 +11,7 @@ module ZkFold.Symbolic.Cardano.Types.Output ( txoDatumHash ) where -import Data.Functor.Rep (Representable (..)) +import Data.Functor.Rep (Representable (..)) import Prelude hiding (Bool, Eq, length, splitAt, (*), (+)) import qualified Prelude as Haskell diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index 86e8b5fb1..b6889ac97 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -1,5 +1,5 @@ {-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeOperators #-} module ZkFold.Symbolic.Compiler.ArithmeticCircuit ( ArithmeticCircuit, @@ -40,8 +40,8 @@ import GHC.Generics (U1 (..) import Numeric.Natural (Natural) import Prelude hiding (Num (..), drop, length, product, splitAt, sum, take, (!!), (^)) -import Test.QuickCheck (Arbitrary, Property, conjoin, property, - withMaxSuccess, (===), arbitrary) +import Test.QuickCheck (Arbitrary, Property, arbitrary, conjoin, + property, withMaxSuccess, (===)) import Text.Pretty.Simple (pPrint) import ZkFold.Base.Algebra.Basic.Class @@ -50,8 +50,8 @@ import ZkFold.Prelude (length) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (desugarRange) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance () import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint, - eval, eval1, exec, exec1, - witnessGenerator, Var (..), acInput) + Var (..), acInput, eval, eval1, exec, exec1, + witnessGenerator) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map --------------------------------- High-level functions -------------------------------- @@ -126,7 +126,7 @@ checkClosedCircuit c = withMaxSuccess 1 $ conjoin [ testPoly p | p <- elems (acS w = witnessGenerator c U1 testPoly p = evalPolynomial evalMonomial varF p === zero varF (NewVar v) = w ! v - varF (InVar v) = absurd v + varF (InVar v) = absurd v checkCircuit :: Arbitrary (i a) @@ -142,6 +142,6 @@ checkCircuit c = conjoin [ property (testPoly p) | p <- elems (acSystem c) ] ins <- arbitrary let w = witnessGenerator c ins varF (NewVar v) = w ! v - varF (InVar v) = index ins v + varF (InVar v) = index ins v return $ evalPolynomial evalMonomial varF p === zero diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs index 928b2ecaf..dc0cc8694 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs @@ -32,7 +32,6 @@ import Data.Map (elem import Data.Traversable (for) import qualified Data.Zip as Z import GHC.Generics (Par1) --- import GHC.IsList (IsList (..)) import Prelude hiding (Bool, Eq (..), drop, length, negate, splitAt, take, (!!), (*), (+), (-), (^)) @@ -42,7 +41,7 @@ import ZkFold.Base.Algebra.Polynomials.Multivariate (vari import qualified ZkFold.Base.Data.Vector as V import ZkFold.Base.Data.Vector (Vector (..)) import ZkFold.Prelude (drop, length, take, (!!)) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (ArithmeticCircuit (..), acInput, Var (..)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (ArithmeticCircuit (..), Var (..), acInput) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint import ZkFold.Symbolic.MonadCircuit diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs index 124897b16..845513a7c 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs @@ -15,8 +15,7 @@ import Prelude (Show, m (<$>)) import qualified Prelude as Haskell import System.Random (mkStdGen) -import Test.QuickCheck (Arbitrary (arbitrary), Gen, - elements) +import Test.QuickCheck (Arbitrary (arbitrary), Gen, elements) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index 5387e388e..a0ed643d0 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -1,5 +1,5 @@ {-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -27,7 +27,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal ( import Control.DeepSeq (NFData, force) import Control.Monad.State (MonadState (..), State, gets, modify, runState) -import Data.Aeson (ToJSON, ToJSONKey, FromJSON, FromJSONKey) +import Data.Aeson (FromJSON, FromJSONKey, ToJSON, ToJSONKey) import Data.Foldable (fold) import Data.Functor.Rep (Representable (..), fmapRep) import Data.Map.Strict hiding (drop, foldl, foldr, map, null, splitAt, take) @@ -140,7 +140,7 @@ instance (Arithmetic a, Ord (Rep i), Representable i, o ~ U1) => MonadCircuit (V newConstrained new witness = do let ws = sources @a witness varF (NewVar v) = NewVar (v + 1) - varF (InVar v) = InVar v + varF (InVar v) = InVar v -- | We need a throwaway variable to feed into `new` which definitely would not be present in a witness x = maximum (S.mapMonotonic varF ws <> S.singleton (NewVar 0)) -- | `s` is meant to be a set of variables used in a witness not present in a constraint. @@ -202,7 +202,7 @@ toVar srcs c = force $ fromZp ex r = toZp 903489679376934896793395274328947923579382759823 :: VarField g = toZp 89175291725091202781479751781509570912743212325 :: VarField varF (NewVar w) = w - varF (InVar _) = 0 + varF (InVar _) = 0 v = (+ r) . fromConstant . varF x = g ^ fromZp (evalPolynomial evalMonomial v $ mapCoeffs toField c) ex = foldr (\p y -> x ^ (varF p) + y) x srcs @@ -240,7 +240,7 @@ assignment i f = zoom #acWitness . modify $ insert i f eval1 :: Representable i => ArithmeticCircuit a i Par1 -> i a -> a eval1 ctx i = case unPar1 (acOutput ctx) of NewVar k -> witnessGenerator ctx i ! k - InVar j -> index i j + InVar j -> index i j -- | Evaluates the arithmetic circuit using the supplied input map. eval :: (Representable i, Functor o) => ArithmeticCircuit a i o -> i a -> o a @@ -265,9 +265,9 @@ apply xs ac = ac , acOutput = U1 } where - varF (InVar (Left v)) = fromConstant (index xs v) + varF (InVar (Left v)) = fromConstant (index xs v) varF (InVar (Right v)) = var (InVar v) - varF (NewVar v) = var (NewVar v) + varF (NewVar v) = var (NewVar v) witF f j = f (xs :*: j) -- let inputs = acInput diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs index dabb71d17..a7dec2fc0 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs @@ -6,11 +6,11 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map ( ArithmeticCircuitTest(..) ) where -import Data.Traversable (for) import Data.Functor.Rep (Representable (..)) import Data.Map hiding (drop, foldl, foldr, fromList, map, null, splitAt, take, toList) import qualified Data.Map as Map +import Data.Traversable (for) import GHC.Generics (Par1) import GHC.IsList (IsList (..)) import Prelude hiding (Num (..), drop, length, product, @@ -20,7 +20,8 @@ import Test.QuickCheck (Arbitra import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Polynomials.Multivariate import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (getAllVars) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), acInput, Var (..)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Var (..), + acInput) -- This module contains functions for mapping variables in arithmetic circuits. @@ -48,7 +49,7 @@ mapVarArithmeticCircuit (ArithmeticCircuitTest ac wi) = let vars = [v | NewVar v <- getAllVars ac] forward = Map.fromAscList $ zip vars [0..] backward = Map.fromAscList $ zip [0..] vars - varF (InVar v) = InVar v + varF (InVar v) = InVar v varF (NewVar v) = NewVar (forward ! v) mappedCircuit = ac { diff --git a/src/ZkFold/Symbolic/Data/ByteString.hs b/src/ZkFold/Symbolic/Data/ByteString.hs index 8d0b4d8d1..4c066283b 100644 --- a/src/ZkFold/Symbolic/Data/ByteString.hs +++ b/src/ZkFold/Symbolic/Data/ByteString.hs @@ -22,7 +22,7 @@ import Control.DeepSeq (NFDa import Control.Monad (replicateM) import Data.Bits as B import qualified Data.ByteString as Bytes -import Data.Functor.Rep (Representable (..)) +import Data.Functor.Rep (Representable (..)) import Data.Kind (Type) import Data.List (foldl, reverse, unfoldr) import Data.Maybe (Maybe (..)) @@ -46,13 +46,13 @@ import ZkFold.Base.Data.Vector (Vect import ZkFold.Prelude (replicateA, (!!)) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (embedV) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Var) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint import ZkFold.Symbolic.Data.Bool (Bool (..), BoolType (..)) import ZkFold.Symbolic.Data.Class (SymbolicData) import ZkFold.Symbolic.Data.Combinators import ZkFold.Symbolic.Interpreter (Interpreter (..)) import ZkFold.Symbolic.MonadCircuit (Arithmetic, newAssigned) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Var) -- | A ByteString which stores @n@ bits and uses elements of @a@ as registers, one element per register. -- Bit layout is Big-endian. diff --git a/tests/Tests/ArithmeticCircuit.hs b/tests/Tests/ArithmeticCircuit.hs index 35cdb2568..2db4688cc 100644 --- a/tests/Tests/ArithmeticCircuit.hs +++ b/tests/Tests/ArithmeticCircuit.hs @@ -19,7 +19,6 @@ import qualified ZkFold.Base.Data.Vector as V import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (embed) import ZkFold.Symbolic.Data.Bool --- import ZkFold.Symbolic.Data.DiscreteField import ZkFold.Symbolic.Data.Eq import ZkFold.Symbolic.Data.FieldElement import ZkFold.Symbolic.MonadCircuit (Arithmetic) diff --git a/tests/Tests/Arithmetization.hs b/tests/Tests/Arithmetization.hs index d7cb37773..23912a757 100644 --- a/tests/Tests/Arithmetization.hs +++ b/tests/Tests/Arithmetization.hs @@ -16,7 +16,7 @@ import Tests.Arithmetization.Test4 (specArithmetiza import ZkFold.Base.Algebra.Basic.Class (FromConstant, Scale) import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 -import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map (ArithmeticCircuitTest (..)) import ZkFold.Symbolic.MonadCircuit (Arithmetic) diff --git a/tests/Tests/Arithmetization/Test2.hs b/tests/Tests/Arithmetization/Test2.hs index 3d0242e00..7bfaa6687 100644 --- a/tests/Tests/Arithmetization/Test2.hs +++ b/tests/Tests/Arithmetization/Test2.hs @@ -12,12 +12,12 @@ import Test.QuickCheck (property) import ZkFold.Base.Algebra.Basic.Class (one) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (Fr) +import ZkFold.Base.Data.Vector (Vector, unsafeToVector) import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool (..), BoolType (..)) import ZkFold.Symbolic.Data.Eq (Eq (..)) import ZkFold.Symbolic.Data.FieldElement (FieldElement) -import ZkFold.Base.Data.Vector (Vector, unsafeToVector) import ZkFold.Symbolic.MonadCircuit (Arithmetic) -- A true statement. diff --git a/tests/Tests/Arithmetization/Test4.hs b/tests/Tests/Arithmetization/Test4.hs index 6a21b31b7..e29b089e6 100644 --- a/tests/Tests/Arithmetization/Test4.hs +++ b/tests/Tests/Arithmetization/Test4.hs @@ -2,29 +2,29 @@ module Tests.Arithmetization.Test4 (specArithmetization4) where -import GHC.Generics (Par1 (unPar1)) -import GHC.Num (Natural) -import Prelude hiding (Bool, Eq (..), Num (..), Ord (..), (&&)) -import qualified Prelude as Haskell -import Test.Hspec (Spec, describe, it) -import Test.QuickCheck (Testable (..), withMaxSuccess, (==>)) -import Tests.NonInteractiveProof.Plonk (PlonkBS) - -import ZkFold.Base.Algebra.Basic.Class (FromConstant (..), one, zero, (+)) -import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1) -import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..)) -import qualified ZkFold.Base.Data.Vector as V -import ZkFold.Base.Protocol.ARK.Plonk (Plonk (..), PlonkInput (..), PlonkProverSecret, - PlonkWitnessInput (..), plonkVerifierInput) -import ZkFold.Base.Protocol.ARK.Plonk.Internal (getParams) -import ZkFold.Base.Protocol.NonInteractiveProof (NonInteractiveProof (..)) +import GHC.Generics (Par1 (unPar1)) +import GHC.Num (Natural) +import Prelude hiding (Bool, Eq (..), Num (..), Ord (..), (&&)) +import qualified Prelude as Haskell +import Test.Hspec (Spec, describe, it) +import Test.QuickCheck (Testable (..), withMaxSuccess, (==>)) +import Tests.NonInteractiveProof.Plonk (PlonkBS) + +import ZkFold.Base.Algebra.Basic.Class (FromConstant (..), one, zero, (+)) +import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1) +import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..)) +import qualified ZkFold.Base.Data.Vector as V +import ZkFold.Base.Protocol.ARK.Plonk (Plonk (..), PlonkInput (..), PlonkProverSecret, + PlonkWitnessInput (..), plonkVerifierInput) +import ZkFold.Base.Protocol.ARK.Plonk.Internal (getParams) +import ZkFold.Base.Protocol.NonInteractiveProof (NonInteractiveProof (..)) import ZkFold.Symbolic.Class -import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), compile, - compileForceOne, eval) +import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), compile, compileForceOne, + eval) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Var (..)) -import ZkFold.Symbolic.Data.Bool (Bool (..)) -import ZkFold.Symbolic.Data.Eq (Eq (..)) -import ZkFold.Symbolic.Data.FieldElement (FieldElement) +import ZkFold.Symbolic.Data.Bool (Bool (..)) +import ZkFold.Symbolic.Data.Eq (Eq (..)) +import ZkFold.Symbolic.Data.FieldElement (FieldElement) type N = 1 @@ -53,7 +53,7 @@ testOnlyOutputZKP x ps targetValue = (omega, k1, k2) = getParams 32 witnessInputs = V.singleton targetValue varF (NewVar ix) = ix + 1 - varF (InVar ix) = fromIntegral ix + varF (InVar ix) = fromIntegral ix indexOutputBool = V.singleton $ varF $ unPar1 $ acOutput ac plonk = Plonk @32 omega k1 k2 indexOutputBool ac x setupP = setupProve @(PlonkBS N) plonk From 84958def52f153af22f048e06539aa438712a345 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Wed, 14 Aug 2024 10:32:52 -0700 Subject: [PATCH 06/48] change type Rep (Vector size) to Zp size --- src/ZkFold/Base/Data/Vector.hs | 7 ++++--- src/ZkFold/Base/Protocol/ARK/Plonk.hs | 2 +- src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs | 2 +- src/ZkFold/Base/Protocol/ARK/Protostar.hs | 3 ++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/ZkFold/Base/Data/Vector.hs b/src/ZkFold/Base/Data/Vector.hs index 030b9a84f..8480b48ed 100644 --- a/src/ZkFold/Base/Data/Vector.hs +++ b/src/ZkFold/Base/Data/Vector.hs @@ -23,6 +23,7 @@ import System.Random (Random (..)) import Test.QuickCheck (Arbitrary (..)) import ZkFold.Base.Algebra.Basic.Class +import ZkFold.Base.Algebra.Basic.Field import ZkFold.Base.Algebra.Basic.Number import ZkFold.Base.Data.ByteString (Binary (..)) import qualified ZkFold.Prelude as ZP @@ -32,9 +33,9 @@ newtype Vector (size :: Natural) a = Vector [a] deriving (Show, Eq, Functor, Foldable, Traversable, Generic, NFData) instance KnownNat size => Representable (Vector size) where - type Rep (Vector size) = Int - index (Vector v) ix = v Prelude.!! ix - tabulate f = Vector [f ix | ix <- [0 .. fromIntegral (value @size) Prelude.- 1]] + type Rep (Vector size) = Zp size + index (Vector v) ix = v Prelude.!! (fromIntegral (fromZp ix)) + tabulate f = Vector [f (toZp ix) | ix <- [0 .. fromIntegral (value @size) Prelude.- 1]] instance KnownNat size => Distributive (Vector size) where distribute = distributeRep diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index 3e064bd26..3fe6aae3d 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -47,7 +47,7 @@ data Plonk (n :: Natural) (l :: Natural) curve1 curve2 transcript = Plonk { ac :: ArithmeticCircuit (ScalarField curve1) (Vector l) Par1, x :: ScalarField curve1 } -instance (Show (ScalarField c1), Arithmetic (ScalarField c1)) => Show (Plonk n l c1 c2 t) where +instance (Show (ScalarField c1), Arithmetic (ScalarField c1), KnownNat l) => Show (Plonk n l c1 c2 t) where show (Plonk omega k1 k2 iPub ac x) = "Plonk: " ++ show omega ++ " " ++ show k1 ++ " " ++ show k2 ++ " " ++ show iPub ++ " " ++ show ac ++ " " ++ show x diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs index 966d0cbd6..c6dadbee9 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs @@ -44,7 +44,7 @@ toPlonkRelation xPub ac0 = let ac = desugarRanges ac0 varF (NewVar ix) = if ix == 0 then one else var (ix + value @l) - varF (InVar ix) = var (fromIntegral ix) + varF (InVar ix) = var (toConstant ix) evalX0 = evalPolynomial evalMonomial varF pubInputConstraints = map var (fromVector xPub) diff --git a/src/ZkFold/Base/Protocol/ARK/Protostar.hs b/src/ZkFold/Base/Protocol/ARK/Protostar.hs index f0741ea56..34f51ddc8 100644 --- a/src/ZkFold/Base/Protocol/ARK/Protostar.hs +++ b/src/ZkFold/Base/Protocol/ARK/Protostar.hs @@ -10,6 +10,7 @@ import GHC.Generics (Generic) import Prelude (($), (==)) import qualified Prelude as P +import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial, var) import ZkFold.Base.Data.Vector (Vector) @@ -66,7 +67,7 @@ instance (Arithmetic a, KnownNat n) => SpecialSoundProtocol a (RecursiveCircuit algebraicMap rc _ _ _ = let varF (NewVar ix) = var (ix P.+ value @n) - varF (InVar ix) = var (P.fromIntegral ix) + varF (InVar ix) = var (toConstant ix) in [ evalPolynomial evalMonomial varF poly | poly <- M.elems $ acSystem (circuit rc) From 1d770e234fa86af6bee244bb6a2932c79472c6b5 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Wed, 14 Aug 2024 10:37:26 -0700 Subject: [PATCH 07/48] Update Internal.hs --- src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index a0ed643d0..4c35d291c 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -163,11 +163,8 @@ sources = runSources . ($ Sources @a . S.singleton) instance (Eq a, o ~ U1) => Semigroup (ArithmeticCircuit a i o) where c1 <> c2 = ArithmeticCircuit - { - acSystem = acSystem c1 `union` acSystem c2 - , acRange = acRange c1 `union` acRange c2 - -- NOTE: is it possible that we get a wrong argument order when doing `apply` because of this concatenation? - -- We need a way to ensure the correct order no matter how `(<>)` is used. + { acSystem = acSystem c1 `union` acSystem c2 + , acRange = acRange c1 `union` acRange c2 , acWitness = acWitness c1 `union` acWitness c2 , acVarOrder = acVarOrder c1 `union` acVarOrder c2 , acRNG = mkStdGen $ fst (uniform (acRNG c1)) Haskell.* fst (uniform (acRNG c2)) From 8fb1c27bf2419af1ae1b22f21ce3ecb790689f52 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Wed, 14 Aug 2024 10:38:13 -0700 Subject: [PATCH 08/48] Update ArithmeticCircuit.hs --- src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index b6889ac97..82cbe01cd 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -56,10 +56,6 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map --------------------------------- High-level functions -------------------------------- --- TODO: make this work for different input types. --- applyArgs :: ArithmeticCircuit a (i :*: j) o -> i a -> ArithmeticCircuit a j o --- applyArgs r args = (apply args r{acOutput = U1}) {acOutput = fmap _ (acOutput r)} - -- | Optimizes the constraint system. -- -- TODO: Implement nontrivial optimizations. From b6a919f557b58ae852f511e302df0ab649713720 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 20 Aug 2024 14:27:40 -0700 Subject: [PATCH 09/48] test bug? --- tests/Tests/Arithmetization/Test4.hs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/Tests/Arithmetization/Test4.hs b/tests/Tests/Arithmetization/Test4.hs index e29b089e6..60c6737c2 100644 --- a/tests/Tests/Arithmetization/Test4.hs +++ b/tests/Tests/Arithmetization/Test4.hs @@ -52,9 +52,10 @@ testOnlyOutputZKP x ps targetValue = (omega, k1, k2) = getParams 32 witnessInputs = V.singleton targetValue - varF (NewVar ix) = ix + 1 - varF (InVar ix) = fromIntegral ix - indexOutputBool = V.singleton $ varF $ unPar1 $ acOutput ac + -- varF (NewVar ix) = ix + 1 + -- varF (InVar ix) = fromIntegral ix + -- indexOutputBool = V.singleton $ varF $ unPar1 $ acOutput ac + indexOutputBool = V.singleton (1 :: Natural) plonk = Plonk @32 omega k1 k2 indexOutputBool ac x setupP = setupProve @(PlonkBS N) plonk setupV = setupVerify @(PlonkBS N) plonk From 553a14ce5b1f7452b73a7b6e4a232335e03d698f Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 20 Aug 2024 14:31:52 -0700 Subject: [PATCH 10/48] Update Test4.hs --- tests/Tests/Arithmetization/Test4.hs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/Tests/Arithmetization/Test4.hs b/tests/Tests/Arithmetization/Test4.hs index 60c6737c2..df34fb3df 100644 --- a/tests/Tests/Arithmetization/Test4.hs +++ b/tests/Tests/Arithmetization/Test4.hs @@ -52,10 +52,9 @@ testOnlyOutputZKP x ps targetValue = (omega, k1, k2) = getParams 32 witnessInputs = V.singleton targetValue - -- varF (NewVar ix) = ix + 1 - -- varF (InVar ix) = fromIntegral ix - -- indexOutputBool = V.singleton $ varF $ unPar1 $ acOutput ac - indexOutputBool = V.singleton (1 :: Natural) + indexOutputBool = V.singleton $ case unPar1 $ acOutput ac of + NewVar ix -> ix + 1 + InVar _ -> 1 plonk = Plonk @32 omega k1 k2 indexOutputBool ac x setupP = setupProve @(PlonkBS N) plonk setupV = setupVerify @(PlonkBS N) plonk From 610056c4ab6d900830d115abe3cc3673560d6cf9 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 20 Aug 2024 14:47:15 -0700 Subject: [PATCH 11/48] fixes --- src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs | 2 +- src/ZkFold/Symbolic/Data/Combinators.hs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs index 990dc0559..153120698 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs @@ -17,7 +17,7 @@ import ZkFold.Base.Algebra.Basic.Number import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), Point) import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr) import ZkFold.Base.Data.Vector (Vector) -import ZkFold.Prelude (take) +import ZkFold.Prelude (log2ceiling, take) log2 :: (Integral a, Integral b) => a -> b log2 = ceiling @Double . logBase 2 . fromIntegral diff --git a/src/ZkFold/Symbolic/Data/Combinators.hs b/src/ZkFold/Symbolic/Data/Combinators.hs index f19643e81..fc8c2ac1b 100644 --- a/src/ZkFold/Symbolic/Data/Combinators.hs +++ b/src/ZkFold/Symbolic/Data/Combinators.hs @@ -46,8 +46,8 @@ class Shrink a b where -- | Convert an @ArithmeticCircuit@ to bits and return their corresponding variables. -- toBits - :: forall i v a m - . MonadBlueprint i v a m + :: forall v a m + . MonadCircuit v a m => [v] -> Natural -> Natural @@ -68,7 +68,7 @@ fromBits :: forall a . Natural -> Natural - -> (forall i v m. MonadBlueprint i v a m => [v] -> m [v]) + -> (forall v m. MonadCircuit v a m => [v] -> m [v]) fromBits hiBits loBits bits = do let (bitsHighNew, bitsLowNew) = splitAt (Haskell.fromIntegral hiBits) bits let lowVarsNew = chunksOf (Haskell.fromIntegral loBits) bitsLowNew From a3e07cce139255a7c74c74f2159b186e4105a59f Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 20 Aug 2024 14:47:17 -0700 Subject: [PATCH 12/48] Update UInt.hs --- src/ZkFold/Symbolic/Data/UInt.hs | 50 ++++++++++++++++---------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/src/ZkFold/Symbolic/Data/UInt.hs b/src/ZkFold/Symbolic/Data/UInt.hs index 5984442de..93097290f 100644 --- a/src/ZkFold/Symbolic/Data/UInt.hs +++ b/src/ZkFold/Symbolic/Data/UInt.hs @@ -203,7 +203,7 @@ instance ) => Extend (UInt n r (ArithmeticCircuit a i)) (UInt k r (ArithmeticCircuit a i)) where extend (UInt ac) = UInt (circuitF solve) where - solve :: forall v m. MonadBlueprint i v a m => m (Vector to v) + solve :: forall v m. MonadCircuit v a m => m (Vector to v) solve = do regs <- V.fromVector <$> runCircuit ac bsBits <- toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) @@ -225,7 +225,7 @@ instance ) => Shrink (UInt n r (ArithmeticCircuit a i)) (UInt k r (ArithmeticCircuit a i)) where shrink (UInt ac) = UInt (circuitF solve) where - solve :: forall v m. MonadBlueprint i v a m => m (Vector to v) + solve :: forall v m. MonadCircuit v a m => m (Vector to v) solve = do regs <- V.fromVector <$> runCircuit ac bsBits <- toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) @@ -286,7 +286,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, KnownNat (NumberOfRegis instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => AdditiveSemigroup (UInt n r (ArithmeticCircuit a i)) where UInt x + UInt y = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i v a m => m [v] + solve :: MonadCircuit v a m => m [v] solve = do j <- newAssigned (Haskell.const zero) let xs = V.fromVector xv @@ -328,7 +328,7 @@ instance t :: a t = (one + one) ^ registerSize @a @n @r - one - solve :: MonadBlueprint i v a m => m [v] + solve :: MonadCircuit v a m => m [v] solve = do is <- runCircuit x js <- runCircuit y @@ -339,13 +339,13 @@ instance (ris, rjs) = Haskell.unzip $ Haskell.init rest in solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadBlueprint i v a m => v -> v -> m [v] + solve1 :: MonadCircuit v a m => v -> v -> m [v] solve1 i j = do z0 <- newAssigned (\v -> v i - v j + fromConstant (2 ^ registerSize @a @n @r :: Natural)) (z, _) <- splitExpansion (highRegisterSize @a @n @r) 1 z0 return [z] - solveN :: MonadBlueprint i v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] + solveN :: MonadCircuit v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] solveN (i, j) (is, js) (i', j') = do s <- newAssigned (\v -> v i - v j + fromConstant (t + one)) (k, b0) <- splitExpansion (registerSize @a @n @r) 1 s @@ -355,7 +355,7 @@ instance (s', _) <- splitExpansion (highRegisterSize @a @n @r) 1 s'0 return (k : zs <> [s']) - fullSub :: MonadBlueprint i v a m => v -> v -> v -> m (v, v) + fullSub :: MonadCircuit v a m => v -> v -> v -> m (v, v) fullSub xk yk b = do d <- newAssigned (\v -> v xk - v yk) s <- newAssigned (\v -> v d + v b + fromConstant t) @@ -363,7 +363,7 @@ instance negate (UInt x) = UInt $ circuitF (V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i v a m => m [v] + solve :: MonadCircuit v a m => m [v] solve = do j <- newAssigned (Haskell.const zero) @@ -377,7 +377,7 @@ instance (zs, _) <- flip runStateT j $ traverse StateT (Haskell.zipWith negateN ns xs) return zs - negateN :: MonadBlueprint i v a m => Natural -> v -> v -> m (v, v) + negateN :: MonadCircuit v a m => Natural -> v -> v -> m (v, v) negateN n i b = do r <- newAssigned (\v -> fromConstant n - v i + v b) splitExpansion (registerSize @a @n @r) 1 r @@ -386,7 +386,7 @@ instance instance (Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters a n rs, Haskell.Ord (Rep i), Representable i) => MultiplicativeSemigroup (UInt n rs (ArithmeticCircuit a i)) where UInt x * UInt y = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i v a m => m [v] + solve :: MonadCircuit v a m => m [v] solve = do is <- runCircuit x js <- runCircuit y @@ -397,13 +397,13 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters (ris, rjs) = Haskell.unzip $ Haskell.init rest in solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadBlueprint i v a m => v -> v -> m [v] + solve1 :: MonadCircuit v a m => v -> v -> m [v] solve1 i j = do z0 <- newAssigned (\v -> v i - v j + fromConstant (2 ^ registerSize @(BaseField c) @n @r :: Natural)) (z, _) <- splitExpansion (highRegisterSize @(BaseField c) @n @r) 1 z0 return [z] - solveN :: MonadBlueprint i v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] + solveN :: MonadCircuit v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] solveN (i, j) (is, js) (i', j') = do let cs = fromList $ zip [0..] (i : is ++ [i']) ds = fromList $ zip [0..] (j : js ++ [j']) @@ -476,7 +476,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => Iso (ByteString n (ArithmeticCircuit a i)) (UInt n r (ArithmeticCircuit a i)) where from (ByteString bits) = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: forall v m. MonadBlueprint i v a m => m [v] + solve :: forall v m. MonadCircuit v a m => m [v] solve = do bsBits <- V.fromVector <$> runCircuit bits Haskell.reverse <$> fromBits (highRegisterSize @a @n @r) (registerSize @a @n @r) bsBits @@ -484,7 +484,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => Iso (UInt n r (ArithmeticCircuit a i)) (ByteString n (ArithmeticCircuit a i)) where from (UInt ac) = ByteString $ circuitF $ Vector <$> solve where - solve :: forall v m. MonadBlueprint i v a m => m [v] + solve :: forall v m. MonadCircuit v a m => m [v] solve = do regs <- V.fromVector <$> runCircuit ac toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) @@ -517,7 +517,7 @@ instance (Finite (Zp p), Prime p, KnownNat n, KnownRegisterSize r, Haskell.Ord ( instance (Arithmetic a, KnownNat n, KnownRegisterSize r, NumberOfBits a <= n, Haskell.Ord (Rep i), Representable i) => StrictConv (ArithmeticCircuit a i Par1) (UInt n r (ArithmeticCircuit a i)) where strictConv a = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i v a m => m [v] + solve :: MonadCircuit v a m => m [v] solve = do i <- unPar1 <$> runCircuit a let len = Haskell.min (getNatural @n) (numberOfBits @a) @@ -538,7 +538,7 @@ instance (Finite (Zp p), KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => StrictNum (UInt n r (ArithmeticCircuit a i)) where strictAdd (UInt x) (UInt y) = UInt (circuitF $ V.unsafeToVector <$> solve) where - solve :: MonadBlueprint i v a m => m [v] + solve :: MonadCircuit v a m => m [v] solve = do j <- newAssigned (Haskell.const zero) let xs = V.fromVector xv @@ -559,7 +559,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re t :: BaseField c t = (one + one) ^ registerSize @(BaseField c) @n @r - one - solve :: MonadBlueprint i v a m => m [v] + solve :: MonadCircuit v a m => m [v] solve = do is <- runCircuit x js <- runCircuit y @@ -570,13 +570,13 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re (ris, rjs) = Haskell.unzip $ Haskell.init rest in V.unsafeToVector <$> solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadBlueprint i v a m => v -> v -> m [v] + solve1 :: MonadCircuit v a m => v -> v -> m [v] solve1 i j = do z <- newAssigned (\v -> v i - v j) _ <- expansion (highRegisterSize @(BaseField c) @n @r) z return [z] - solveN :: MonadBlueprint i v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] + solveN :: MonadCircuit v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] solveN (i, j) (is, js) (i', j') = do s <- newAssigned (\v -> v i - v j + fromConstant (t + one)) (k, b0) <- splitExpansion (registerSize @(BaseField c) @n @r) 1 s @@ -587,7 +587,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re return (k : zs <> [s']) - fullSub :: MonadBlueprint i v a m => v -> v -> v -> m (v, v) + fullSub :: MonadCircuit v a m => v -> v -> v -> m (v, v) fullSub xk yk b = do k <- newAssigned (\v -> v xk - v yk) s <- newAssigned (\v -> v k + v b + fromConstant t) @@ -595,7 +595,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re strictMul (UInt x) (UInt y) = UInt $ symbolic2F x y (\u v -> naturalToVector @c @n @r $ vectorToNatural u (registerSize @(BaseField c) @n @r) * vectorToNatural v (registerSize @(BaseField c) @n @r)) solve where - solve :: MonadBlueprint i v a m => m [v] + solve :: MonadCircuit v a m => m [v] solve = do is <- runCircuit x js <- runCircuit y @@ -606,13 +606,13 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re (ris, rjs) = Haskell.unzip $ Haskell.init rest in V.unsafeToVector <$> solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadBlueprint i v a m => v -> v -> m [v] + solve1 :: MonadCircuit v a m => v -> v -> m [v] solve1 i j = do z <- newAssigned $ \v -> v i * v j _ <- expansion (highRegisterSize @(BaseField c) @n @r) z return [z] - solveN :: MonadBlueprint i v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] + solveN :: MonadCircuit v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] solveN (i, j) (is, js) (i', j') = do let cs = fromList $ zip [0..] (i : is ++ [i']) ds = fromList $ zip [0..] (j : js ++ [j']) @@ -643,10 +643,10 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re -------------------------------------------------------------------------------- -fullAdder :: (Arithmetic a, MonadBlueprint i v a m) => Natural -> v -> v -> v -> m (v, v) +fullAdder :: (Arithmetic a, MonadCircuit v a m) => Natural -> v -> v -> v -> m (v, v) fullAdder r xk yk c = fullAdded xk yk c >>= splitExpansion r 1 -fullAdded :: MonadBlueprint i v a m => v -> v -> v -> m v +fullAdded :: MonadCircuit v a m => v -> v -> v -> m v fullAdded i j c = do k <- newAssigned (\v -> v i + v j) newAssigned (\v -> v k + v c) From 8a24cfd8aa959cee86e8a14be2287cfcb17ad5fd Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 20 Aug 2024 14:49:47 -0700 Subject: [PATCH 13/48] Update UInt.hs --- src/ZkFold/Symbolic/Data/UInt.hs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ZkFold/Symbolic/Data/UInt.hs b/src/ZkFold/Symbolic/Data/UInt.hs index 93097290f..6c245f4b3 100644 --- a/src/ZkFold/Symbolic/Data/UInt.hs +++ b/src/ZkFold/Symbolic/Data/UInt.hs @@ -282,8 +282,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, KnownNat (NumberOfRegis min x y = bool @(Bool (ArithmeticCircuit a i)) x y $ x > y - -instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => AdditiveSemigroup (UInt n r (ArithmeticCircuit a i)) where +instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => AdditiveSemigroup (UInt n r c) where UInt x + UInt y = UInt (circuitF $ V.unsafeToVector <$> solve) where solve :: MonadCircuit v a m => m [v] @@ -383,7 +382,7 @@ instance splitExpansion (registerSize @a @n @r) 1 r -instance (Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters a n rs, Haskell.Ord (Rep i), Representable i) => MultiplicativeSemigroup (UInt n rs (ArithmeticCircuit a i)) where +instance (Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters a n rs) => MultiplicativeSemigroup (UInt n rs c) where UInt x * UInt y = UInt (circuitF $ V.unsafeToVector <$> solve) where solve :: MonadCircuit v a m => m [v] From e164c3d53ee7144018fc6b31664fc462b18141ee Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 20 Aug 2024 14:52:51 -0700 Subject: [PATCH 14/48] Update UInt.hs --- src/ZkFold/Symbolic/Data/UInt.hs | 414 ++++++++++++------------------- 1 file changed, 155 insertions(+), 259 deletions(-) diff --git a/src/ZkFold/Symbolic/Data/UInt.hs b/src/ZkFold/Symbolic/Data/UInt.hs index 6c245f4b3..f91ffc726 100644 --- a/src/ZkFold/Symbolic/Data/UInt.hs +++ b/src/ZkFold/Symbolic/Data/UInt.hs @@ -16,22 +16,21 @@ module ZkFold.Symbolic.Data.UInt ( ) where import Control.DeepSeq -import Control.Monad.State (StateT (..)) -import Data.Foldable (foldr, foldrM, for_) -import Data.Functor ((<$>)) -import Data.Functor.Rep (Representable (..)) -import Data.Kind (Type) -import Data.List (unfoldr, zip) -import Data.Map (fromList, (!)) -import Data.Traversable (for, traverse) -import Data.Tuple (swap) -import qualified Data.Zip as Z -import GHC.Generics (Generic, Par1 (..)) -import GHC.Natural (naturalFromInteger) -import Prelude (Integer, error, flip, otherwise, return, - type (~), ($), (++), (.), (<>), (>>=)) -import qualified Prelude as Haskell -import Test.QuickCheck (Arbitrary (..), chooseInteger) +import Control.Monad.State (StateT (..)) +import Data.Foldable (foldr, foldrM, for_) +import Data.Functor ((<$>)) +import Data.Kind (Type) +import Data.List (unfoldr, zip) +import Data.Map (fromList, (!)) +import Data.Traversable (for, traverse) +import Data.Tuple (swap) +import qualified Data.Zip as Z +import GHC.Generics (Generic, Par1 (..)) +import GHC.Natural (naturalFromInteger) +import Prelude (Integer, error, flip, otherwise, return, + type (~), ($), (++), (.), (<>), (>>=)) +import qualified Prelude as Haskell +import Test.QuickCheck (Arbitrary (..), chooseInteger) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Field (Zp, toZp) @@ -68,33 +67,9 @@ instance (Symbolic c, KnownNat n, KnownRegisterSize r) => FromConstant Natural ( instance (Symbolic c, KnownNat n, KnownRegisterSize r) => FromConstant Integer (UInt n r c) where fromConstant = fromConstant . naturalFromInteger . (`Haskell.mod` (2 ^ getNatural @n)) -instance - ( FromConstant Natural a - , Arithmetic a - , KnownNat n - , KnownRegisterSize r - , Haskell.Ord (Rep i) - , Representable i - ) => FromConstant Natural (UInt n r (ArithmeticCircuit a i)) where - fromConstant c = - let (lo, hi, _) = cast @a @n @r . (`Haskell.mod` (2 ^ getNatural @n)) $ c - in UInt $ embedV $ Vector $ fromConstant <$> (lo <> [hi]) - -instance - ( FromConstant Natural a - , Arithmetic a - , KnownNat n - , KnownRegisterSize r - , Haskell.Ord (Rep i) - , Representable i - ) => FromConstant Integer (UInt n r (ArithmeticCircuit a i)) where - fromConstant = fromConstant . naturalFromInteger . (`Haskell.mod` (2 ^ getNatural @n)) - -instance (FromConstant Natural (UInt n r b), KnownNat n, MultiplicativeSemigroup (UInt n r b)) => Scale Natural (UInt n r b) +instance (Symbolic c, KnownNat n, KnownRegisterSize r, FromConstant a (UInt n r c), MultiplicativeMonoid a) => Scale a (UInt n r c) -instance (FromConstant Integer (UInt n r b), KnownNat n, MultiplicativeSemigroup (UInt n r b)) => Scale Integer (UInt n r b) - -instance MultiplicativeMonoid (UInt n r b) => Exponent (UInt n r b) Natural where +instance MultiplicativeMonoid (UInt n r c) => Exponent (UInt n r c) Natural where (^) = natPow cast :: forall a n r . (Arithmetic a, KnownNat n, KnownRegisterSize r) => Natural -> ([Natural], Natural, [Natural]) @@ -196,17 +171,12 @@ instance , KnownNat k , KnownRegisterSize r , n <= k - , from ~ NumberOfRegisters a n r - , to ~ NumberOfRegisters a k r - , Haskell.Ord (Rep i) - , Representable i - ) => Extend (UInt n r (ArithmeticCircuit a i)) (UInt k r (ArithmeticCircuit a i)) where - extend (UInt ac) = UInt (circuitF solve) + ) => Extend (UInt n r c) (UInt k r c) where + extend (UInt x) = UInt $ symbolicF x (\l -> naturalToVector @c @k @r (vectorToNatural l (registerSize @(BaseField c) @n @r))) solve where - solve :: forall v m. MonadCircuit v a m => m (Vector to v) - solve = do - regs <- V.fromVector <$> runCircuit ac - bsBits <- toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) + solve :: MonadCircuit i (BaseField c) m => Vector (NumberOfRegisters (BaseField c) n r) i -> m (Vector (NumberOfRegisters (BaseField c) k r) i) + solve xv = do + let regs = V.fromVector xv zeros <- replicateA (value @k -! (value @n)) (newAssigned (Haskell.const zero)) bsBits <- toBits (Haskell.reverse regs) (highRegisterSize @(BaseField c) @n @r)(registerSize @(BaseField c) @n @r) extended <- fromBits (highRegisterSize @(BaseField c) @k @r) (registerSize @(BaseField c) @k @r) (zeros <> bsBits) @@ -218,18 +188,16 @@ instance , KnownNat k , KnownRegisterSize r , k <= n - , from ~ NumberOfRegisters a n r - , to ~ NumberOfRegisters a k r - , Haskell.Ord (Rep i) - , Representable i - ) => Shrink (UInt n r (ArithmeticCircuit a i)) (UInt k r (ArithmeticCircuit a i)) where - shrink (UInt ac) = UInt (circuitF solve) + , from ~ NumberOfRegisters (BaseField c) n r + , to ~ NumberOfRegisters (BaseField c) k r + ) => Shrink (UInt n r c) (UInt k r c) where + shrink (UInt ac) = UInt $ symbolicF ac (V.unsafeToVector . V.fromVector ) solve where - solve :: forall v m. MonadCircuit v a m => m (Vector to v) - solve = do - regs <- V.fromVector <$> runCircuit ac - bsBits <- toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) - shrinked <- fromBits (highRegisterSize @a @k @r) (registerSize @a @k @r) (drop (value @n -! (value @k)) bsBits) + solve :: MonadCircuit i (BaseField c) m => Vector from i -> m (Vector to i) + solve xv = do + let regs = V.fromVector xv + bsBits <- toBits (Haskell.reverse regs) (highRegisterSize @(BaseField c) @n @r) (registerSize @(BaseField c) @n @r) + shrinked <- fromBits (highRegisterSize @(BaseField c) @k @r) (registerSize @(BaseField c) @k @r) (drop (value @n -! (value @k)) bsBits) return $ V.unsafeToVector $ Haskell.reverse shrinked instance @@ -263,7 +231,23 @@ instance let rs = force $ addBit (r' + r') (value @n -! i -! 1) in bool @(Bool c) (q', rs) (q' + fromConstant ((2 :: Natural) ^ i), rs - d) (rs >= d) -instance (Arithmetic a, KnownNat n, KnownRegisterSize r, KnownNat (NumberOfRegisters a n r), Haskell.Ord (Rep i), Representable i) => Ord (Bool (ArithmeticCircuit a i)) (UInt n r (ArithmeticCircuit a i)) where +instance (Symbolic (ArithmeticCircuit a i), KnownNat n, KnownRegisterSize r) => Iso (ByteString n (ArithmeticCircuit a i)) (UInt n r (ArithmeticCircuit a i)) where + from (ByteString bits) = UInt $ symbolicF bits (\v -> naturalToVector @(ArithmeticCircuit a i) @n @r $ vectorToNatural v (registerSize @a @n @r)) solve + where + solve :: MonadCircuit v a m => Vector n v -> m (Vector (NumberOfRegisters a n r) v) + solve xv = do + let bsBits = V.fromVector xv + V.unsafeToVector . Haskell.reverse <$> fromBits (highRegisterSize @a @n @r) (registerSize @a @n @r) bsBits + +instance (Symbolic (ArithmeticCircuit a i), KnownNat n, KnownRegisterSize r) => Iso (UInt n r (ArithmeticCircuit a i)) (ByteString n (ArithmeticCircuit a i)) where + from (UInt ac) = ByteString $ symbolicF ac (\v -> V.unsafeToVector $ fromConstant <$> toBsBits (vectorToNatural v (registerSize @a @n @r)) (value @n)) solve + where + solve :: MonadCircuit v a m => Vector (NumberOfRegisters a n r) v -> m (Vector n v) + solve xv = do + let regs = V.fromVector xv + V.unsafeToVector <$> toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) + +instance (Symbolic (ArithmeticCircuit a i), KnownNat n, KnownRegisterSize r) => Ord (Bool (ArithmeticCircuit a i)) (UInt n r (ArithmeticCircuit a i)) where x <= y = y >= x x < y = y > x @@ -276,17 +260,18 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, KnownNat (NumberOfRegis u1 > u2 = let ByteString rs1 = from u1 :: ByteString n (ArithmeticCircuit a i) ByteString rs2 = from u2 :: ByteString n (ArithmeticCircuit a i) - in bitwiseGT rs1 rs2 + in bitwiseGT rs1 rs2 max x y = bool @(Bool (ArithmeticCircuit a i)) x y $ x < y min x y = bool @(Bool (ArithmeticCircuit a i)) x y $ x > y -instance (Arithmetic a, KnownNat n, KnownRegisterSize r) => AdditiveSemigroup (UInt n r c) where - UInt x + UInt y = UInt (circuitF $ V.unsafeToVector <$> solve) + +instance (Symbolic c, KnownNat n, KnownRegisterSize r) => AdditiveSemigroup (UInt n r c) where + UInt xc + UInt yc = UInt $ symbolic2F xc yc (\u v -> naturalToVector @c @n @r $ vectorToNatural u (registerSize @(BaseField c) @n @r) + vectorToNatural v (registerSize @(BaseField c) @n @r)) solve where - solve :: MonadCircuit v a m => m [v] - solve = do + solve :: MonadCircuit i (BaseField c) m => Vector (NumberOfRegisters (BaseField c) n r) i -> Vector (NumberOfRegisters (BaseField c) n r) i -> m (Vector (NumberOfRegisters (BaseField c) n r) i) + solve xv yv = do j <- newAssigned (Haskell.const zero) let xs = V.fromVector xv ys = V.fromVector yv @@ -304,105 +289,89 @@ instance (Symbolic c, KnownNat n, KnownRegisterSize r) => AdditiveMonoid (UInt n zero = fromConstant (0:: Natural) instance - ( Arithmetic a + (Symbolic c , KnownNat n - , KnownNat (NumberOfRegisters a n r) , KnownRegisterSize r - , Haskell.Ord (Rep i) - , Representable i - ) => AdditiveMonoid (UInt n r (ArithmeticCircuit a i)) where - zero = UInt $ embedV (pure zero) + ) => AdditiveGroup (UInt n r c) where -instance - ( Arithmetic a - , KnownNat n - , KnownRegisterSize r - , KnownNat (NumberOfRegisters a n r) - , Haskell.Ord (Rep i) - , Representable i - ) => AdditiveGroup (UInt n r (ArithmeticCircuit a i)) where + UInt x - UInt y = UInt $ symbolic2F x y (\u v -> naturalToVector @c @n @r $ vectorToNatural u (registerSize @(BaseField c) @n @r) + 2 ^ (value @n) -! vectorToNatural v (registerSize @(BaseField c) @n @r) ) solve + where + t :: BaseField c + t = (one + one) ^ registerSize @(BaseField c) @n @r - one - UInt x - UInt y = UInt $ circuitF (V.unsafeToVector <$> solve) - where - t :: a - t = (one + one) ^ registerSize @a @n @r - one - - solve :: MonadCircuit v a m => m [v] - solve = do - is <- runCircuit x - js <- runCircuit y - case V.fromVector $ Z.zip is js of - [] -> return [] - [(i, j)] -> solve1 i j - ((i, j) : rest) -> let (z, w) = Haskell.last rest - (ris, rjs) = Haskell.unzip $ Haskell.init rest - in solveN (i, j) (ris, rjs) (z, w) - - solve1 :: MonadCircuit v a m => v -> v -> m [v] - solve1 i j = do - z0 <- newAssigned (\v -> v i - v j + fromConstant (2 ^ registerSize @a @n @r :: Natural)) - (z, _) <- splitExpansion (highRegisterSize @a @n @r) 1 z0 - return [z] - - solveN :: MonadCircuit v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] - solveN (i, j) (is, js) (i', j') = do - s <- newAssigned (\v -> v i - v j + fromConstant (t + one)) - (k, b0) <- splitExpansion (registerSize @a @n @r) 1 s - (zs, b) <- flip runStateT b0 $ traverse StateT (Haskell.zipWith fullSub is js) - d <- newAssigned (\v -> v i' - v j') - s'0 <- newAssigned (\v -> v d + v b + fromConstant t) - (s', _) <- splitExpansion (highRegisterSize @a @n @r) 1 s'0 - return (k : zs <> [s']) - - fullSub :: MonadCircuit v a m => v -> v -> v -> m (v, v) - fullSub xk yk b = do - d <- newAssigned (\v -> v xk - v yk) - s <- newAssigned (\v -> v d + v b + fromConstant t) - splitExpansion (registerSize @a @n @r) 1 s - - negate (UInt x) = UInt $ circuitF (V.unsafeToVector <$> solve) - where - solve :: MonadCircuit v a m => m [v] - solve = do - j <- newAssigned (Haskell.const zero) - - xs <- V.fromVector <$> runCircuit x - let y = 2 ^ registerSize @a @n @r - ys = replicate (numberOfRegisters @a @n @r -! 2) (2 ^ registerSize @a @n @r -! 1) - y' = 2 ^ highRegisterSize @a @n @r -! 1 - ns - | numberOfRegisters @a @n @r Haskell.== 1 = [y' + 1] - | otherwise = (y : ys) <> [y'] - (zs, _) <- flip runStateT j $ traverse StateT (Haskell.zipWith negateN ns xs) - return zs - - negateN :: MonadCircuit v a m => Natural -> v -> v -> m (v, v) - negateN n i b = do - r <- newAssigned (\v -> fromConstant n - v i + v b) - splitExpansion (registerSize @a @n @r) 1 r - - -instance (Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters a n rs) => MultiplicativeSemigroup (UInt n rs c) where - UInt x * UInt y = UInt (circuitF $ V.unsafeToVector <$> solve) + solve :: forall i m. (MonadCircuit i (BaseField c) m) => Vector (NumberOfRegisters (BaseField c) n r) i -> Vector (NumberOfRegisters (BaseField c) n r) i -> m (Vector (NumberOfRegisters (BaseField c) n r) i) + solve xv yv = do + let is = V.fromVector xv + js = V.fromVector yv + case Z.zip is js of + [] -> return $ V.unsafeToVector [] + [(i, j)] -> V.unsafeToVector <$> solve1 i j + ((i, j) : rest) -> let (z, w) = Haskell.last rest + (ris, rjs) = Haskell.unzip $ Haskell.init rest + in V.unsafeToVector <$> solveN (i, j) (ris, rjs) (z, w) + + solve1 :: MonadCircuit i (BaseField c) m => i -> i -> m [i] + solve1 i j = do + z0 <- newAssigned (\v -> v i - v j + fromConstant (2 ^ registerSize @(BaseField c) @n @r :: Natural)) + (z, _) <- splitExpansion (highRegisterSize @(BaseField c) @n @r) 1 z0 + return [z] + + solveN :: MonadCircuit i (BaseField c) m => (i, i) -> ([i], [i]) -> (i, i) -> m [i] + solveN (i, j) (is, js) (i', j') = do + s <- newAssigned (\v -> v i - v j + fromConstant (t + one)) + (k, b0) <- splitExpansion (registerSize @(BaseField c) @n @r) 1 s + (zs, b) <- flip runStateT b0 $ traverse StateT (Haskell.zipWith fullSub is js) + d <- newAssigned (\v -> v i' - v j') + s'0 <- newAssigned (\v -> v d + v b + fromConstant t) + (s', _) <- splitExpansion (highRegisterSize @(BaseField c) @n @r) 1 s'0 + return (k : zs <> [s']) + + fullSub :: MonadCircuit i (BaseField c) m => i -> i -> i -> m (i, i) + fullSub xk yk b = do + d <- newAssigned (\v -> v xk - v yk) + s <- newAssigned (\v -> v d + v b + fromConstant t) + splitExpansion (registerSize @(BaseField c) @n @r) 1 s + + negate :: UInt n r c -> UInt n r c + negate (UInt x) = UInt $ symbolicF x (\v -> naturalToVector @c @n @r $ (2 ^ (value @n) ) -! vectorToNatural v (registerSize @(BaseField c) @n @r)) solve where - solve :: MonadCircuit v a m => m [v] - solve = do - is <- runCircuit x - js <- runCircuit y - case V.fromVector $ Z.zip is js of - [] -> return [] - [(i, j)] -> solve1 i j + solve :: MonadCircuit i (BaseField c) m => Vector (NumberOfRegisters (BaseField c) n r) i -> m (Vector (NumberOfRegisters (BaseField c) n r) i) + solve xv = do + j <- newAssigned (Haskell.const zero) + let xs = V.fromVector xv + y = 2 ^ registerSize @(BaseField c) @n @r + ys = replicate (numberOfRegisters @(BaseField c) @n @r -! 2) (2 ^ registerSize @(BaseField c) @n @r -! 1) + y' = 2 ^ highRegisterSize @(BaseField c) @n @r -! 1 + ns + | numberOfRegisters @(BaseField c) @n @r Haskell.== 1 = [y' + 1] + | otherwise = (y : ys) <> [y'] + (zs, _) <- flip runStateT j $ traverse StateT (Haskell.zipWith negateN ns xs) + return $ V.unsafeToVector zs + + negateN :: MonadCircuit i (BaseField c) m => Natural -> i -> i -> m (i, i) + negateN n i b = do + r <- newAssigned (\v -> fromConstant n - v i + v b) + splitExpansion (registerSize @(BaseField c) @n @r) 1 r + + +instance (Symbolic c, KnownNat n, KnownRegisterSize rs) => MultiplicativeSemigroup (UInt n rs c) where + UInt x * UInt y = UInt $ symbolic2F x y (\u v -> naturalToVector @c @n @rs $ vectorToNatural u (registerSize @(BaseField c) @n @rs) * vectorToNatural v (registerSize @(BaseField c) @n @rs)) solve + where + solve :: forall i m. (MonadCircuit i (BaseField c) m) => Vector (NumberOfRegisters (BaseField c) n rs) i -> Vector (NumberOfRegisters (BaseField c) n rs) i -> m (Vector (NumberOfRegisters (BaseField c) n rs) i) + solve xv yv = do + case V.fromVector $ Z.zip xv yv of + [] -> return $ V.unsafeToVector [] + [(i, j)] -> V.unsafeToVector <$> solve1 i j ((i, j) : rest) -> let (z, w) = Haskell.last rest (ris, rjs) = Haskell.unzip $ Haskell.init rest - in solveN (i, j) (ris, rjs) (z, w) + in V.unsafeToVector <$> solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadCircuit v a m => v -> v -> m [v] + solve1 :: forall i m. (MonadCircuit i (BaseField c) m) => i -> i -> m [i] solve1 i j = do - z0 <- newAssigned (\v -> v i - v j + fromConstant (2 ^ registerSize @(BaseField c) @n @r :: Natural)) - (z, _) <- splitExpansion (highRegisterSize @(BaseField c) @n @r) 1 z0 + (z, _) <- newAssigned (\v -> v i * v j) >>= splitExpansion (highRegisterSize @(BaseField c) @n @rs) (maxOverflow @(BaseField c) @n @rs) return [z] - solveN :: MonadCircuit v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] + solveN :: forall i m. (MonadCircuit i (BaseField c) m) => (i, i) -> ([i], [i]) -> (i, i) -> m [i] solveN (i, j) (is, js) (i', j') = do let cs = fromList $ zip [0..] (i : is ++ [i']) ds = fromList $ zip [0..] (j : js ++ [j']) @@ -430,63 +399,7 @@ instance ( Symbolic c , KnownNat n , KnownRegisterSize r - , (NumberOfRegisters a n r - 1) + 1 ~ NumberOfRegisters a n r - , Haskell.Ord (Rep i) - , Representable i - ) => MultiplicativeMonoid (UInt n r (ArithmeticCircuit a i)) where - - one = UInt $ hliftA2 (\(Par1 h) t -> h V..: t) (embed one :: ArithmeticCircuit a i Par1) (embedV (pure zero) :: ArithmeticCircuit a i (Vector (NumberOfRegisters a n r - 1))) - - -instance - ( Arithmetic a - , KnownNat n - , KnownNat (NumberOfRegisters a n r) - , KnownNat (NumberOfRegisters a n r - 1) - , KnownRegisterSize r - , (NumberOfRegisters a n r - 1) + 1 ~ NumberOfRegisters a n r - , Haskell.Ord (Rep i) - , Representable i - ) => Semiring (UInt n r (ArithmeticCircuit a i)) - -instance - ( Arithmetic a - , KnownNat n - , KnownNat (NumberOfRegisters a n r) - , KnownNat (NumberOfRegisters a n r - 1) - , KnownRegisterSize r - , (NumberOfRegisters a n r - 1) + 1 ~ NumberOfRegisters a n r - , Haskell.Ord (Rep i) - , Representable i - ) => Ring (UInt n r (ArithmeticCircuit a i)) - -deriving via (Structural (UInt n rs (ArithmeticCircuit a i))) - instance (Arithmetic a, r ~ NumberOfRegisters a n rs, 1 <= r, Haskell.Ord (Rep i), Representable i) => - Eq (Bool (ArithmeticCircuit a i)) (UInt n rs (ArithmeticCircuit a i)) - -instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => Arbitrary (UInt n r (ArithmeticCircuit a i)) where - arbitrary = do - lows <- replicateA (numberOfRegisters @a @n @r -! 1) (toss $ registerSize @a @n @r) - hi <- toss (highRegisterSize @a @n @r) - return $ UInt $ embedV (V.unsafeToVector $ lows <> [hi]) - where - toss b = fromConstant <$> chooseInteger (0, 2 ^ b - 1) - -instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => Iso (ByteString n (ArithmeticCircuit a i)) (UInt n r (ArithmeticCircuit a i)) where - from (ByteString bits) = UInt (circuitF $ V.unsafeToVector <$> solve) - where - solve :: forall v m. MonadCircuit v a m => m [v] - solve = do - bsBits <- V.fromVector <$> runCircuit bits - Haskell.reverse <$> fromBits (highRegisterSize @a @n @r) (registerSize @a @n @r) bsBits - -instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => Iso (UInt n r (ArithmeticCircuit a i)) (ByteString n (ArithmeticCircuit a i)) where - from (UInt ac) = ByteString $ circuitF $ Vector <$> solve - where - solve :: forall v m. MonadCircuit v a m => m [v] - solve = do - regs <- V.fromVector <$> runCircuit ac - toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) + ) => Ring (UInt n r c) deriving via (Structural (UInt n rs c)) instance (Symbolic c) => @@ -502,24 +415,16 @@ instance (Symbolic c, KnownNat n, KnownRegisterSize rs) => StrictConv Natural (U (lo, hi, []) -> UInt $ embed $ V.unsafeToVector $ fromConstant <$> (lo <> [hi]) _ -> error "strictConv: overflow" -instance (FromConstant Natural a, Arithmetic a, KnownNat n, KnownRegisterSize rs, r ~ NumberOfRegisters a n rs, Haskell.Ord (Rep i), Representable i) => StrictConv Natural (UInt n rs (ArithmeticCircuit a i)) where - strictConv n = case cast @a @n @rs n of - (lo, hi, []) -> UInt $ embedV $ V.unsafeToVector $ fromConstant <$> (lo <> [hi]) - _ -> error "strictConv: overflow" - -instance (Finite (Zp p), KnownNat n, KnownRegisterSize r) => StrictConv (Zp p) (UInt n r (Interpreter (Zp p))) where - strictConv = strictConv . toConstant @_ @Natural - -instance (Finite (Zp p), Prime p, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => StrictConv (Zp p) (UInt n r (ArithmeticCircuit (Zp p) i)) where +instance (Symbolic c, KnownNat n, KnownRegisterSize r) => StrictConv (Zp p) (UInt n r c) where strictConv = strictConv . toConstant @_ @Natural -instance (Arithmetic a, KnownNat n, KnownRegisterSize r, NumberOfBits a <= n, Haskell.Ord (Rep i), Representable i) => StrictConv (ArithmeticCircuit a i Par1) (UInt n r (ArithmeticCircuit a i)) where - strictConv a = UInt (circuitF $ V.unsafeToVector <$> solve) +instance (Symbolic c, KnownNat n, KnownRegisterSize r, NumberOfBits (BaseField c) <= n) => StrictConv (c Par1) (UInt n r c) where + strictConv a = UInt $ symbolicF a (\p -> V.unsafeToVector [unPar1 p]) solve where - solve :: MonadCircuit v a m => m [v] - solve = do - i <- unPar1 <$> runCircuit a - let len = Haskell.min (getNatural @n) (numberOfBits @a) + solve :: MonadCircuit i (BaseField c) m => Par1 i -> m (Vector (NumberOfRegisters (BaseField c) n r) i) + solve xv = do + let i = unPar1 xv + len = Haskell.min (getNatural @n) (numberOfBits @(BaseField c)) bits <- Haskell.reverse <$> expansion len i V.unsafeToVector <$> fromBits (highRegisterSize @(BaseField c) @n @r) (registerSize @(BaseField c) @n @r) bits @@ -529,16 +434,11 @@ class StrictNum a where strictSub :: a -> a -> a strictMul :: a -> a -> a -instance (Finite (Zp p), KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r (Interpreter (Zp p))) where - strictAdd x y = strictConv $ toConstant x + toConstant @_ @Natural y - strictSub x y = strictConv $ toConstant x -! toConstant y - strictMul x y = strictConv $ toConstant x * toConstant @_ @Natural y - -instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Representable i) => StrictNum (UInt n r (ArithmeticCircuit a i)) where - strictAdd (UInt x) (UInt y) = UInt (circuitF $ V.unsafeToVector <$> solve) +instance (Symbolic c, KnownNat n, KnownRegisterSize r) => StrictNum (UInt n r c) where + strictAdd (UInt x) (UInt y) = UInt $ symbolic2F x y (\u v -> naturalToVector @c @n @r $ vectorToNatural u (registerSize @(BaseField c) @n @r) + vectorToNatural v (registerSize @(BaseField c) @n @r))solve where - solve :: MonadCircuit v a m => m [v] - solve = do + solve :: MonadCircuit i (BaseField c) m => Vector (NumberOfRegisters (BaseField c) n r) i -> Vector (NumberOfRegisters (BaseField c) n r) i -> m (Vector (NumberOfRegisters (BaseField c) n r) i) + solve xv yv = do j <- newAssigned (Haskell.const zero) let xs = V.fromVector xv ys = V.fromVector yv @@ -558,24 +458,22 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re t :: BaseField c t = (one + one) ^ registerSize @(BaseField c) @n @r - one - solve :: MonadCircuit v a m => m [v] - solve = do - is <- runCircuit x - js <- runCircuit y - case V.fromVector $ Z.zip is js of - [] -> return [] - [(i, j)] -> solve1 i j + solve :: MonadCircuit i (BaseField c) m => Vector (NumberOfRegisters (BaseField c) n r) i -> Vector (NumberOfRegisters (BaseField c) n r) i -> m (Vector (NumberOfRegisters (BaseField c) n r) i) + solve xv yv = do + case V.fromVector $ Z.zip xv yv of + [] -> return $ V.unsafeToVector [] + [(i, j)] -> V.unsafeToVector <$> solve1 i j ((i, j) : rest) -> let (z, w) = Haskell.last rest (ris, rjs) = Haskell.unzip $ Haskell.init rest in V.unsafeToVector <$> solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadCircuit v a m => v -> v -> m [v] + solve1 :: MonadCircuit i (BaseField c) m => i -> i -> m [i] solve1 i j = do z <- newAssigned (\v -> v i - v j) _ <- expansion (highRegisterSize @(BaseField c) @n @r) z return [z] - solveN :: MonadCircuit v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] + solveN :: MonadCircuit i (BaseField c) m => (i, i) -> ([i], [i]) -> (i, i) -> m [i] solveN (i, j) (is, js) (i', j') = do s <- newAssigned (\v -> v i - v j + fromConstant (t + one)) (k, b0) <- splitExpansion (registerSize @(BaseField c) @n @r) 1 s @@ -586,7 +484,7 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re return (k : zs <> [s']) - fullSub :: MonadCircuit v a m => v -> v -> v -> m (v, v) + fullSub :: MonadCircuit i (BaseField c) m => i -> i -> i -> m (i, i) fullSub xk yk b = do k <- newAssigned (\v -> v xk - v yk) s <- newAssigned (\v -> v k + v b + fromConstant t) @@ -594,24 +492,22 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re strictMul (UInt x) (UInt y) = UInt $ symbolic2F x y (\u v -> naturalToVector @c @n @r $ vectorToNatural u (registerSize @(BaseField c) @n @r) * vectorToNatural v (registerSize @(BaseField c) @n @r)) solve where - solve :: MonadCircuit v a m => m [v] - solve = do - is <- runCircuit x - js <- runCircuit y - case V.fromVector $ Z.zip is js of - [] -> return [] - [(i, j)] -> solve1 i j + solve :: MonadCircuit i (BaseField c) m => Vector (NumberOfRegisters (BaseField c) n r) i -> Vector (NumberOfRegisters (BaseField c) n r) i -> m (Vector (NumberOfRegisters (BaseField c) n r) i) + solve xv yv = do + case V.fromVector $ Z.zip xv yv of + [] -> return $ V.unsafeToVector [] + [(i, j)] -> V.unsafeToVector <$> solve1 i j ((i, j) : rest) -> let (z, w) = Haskell.last rest (ris, rjs) = Haskell.unzip $ Haskell.init rest in V.unsafeToVector <$> solveN (i, j) (ris, rjs) (z, w) - solve1 :: MonadCircuit v a m => v -> v -> m [v] + solve1 :: MonadCircuit i (BaseField c) m => i -> i -> m [i] solve1 i j = do z <- newAssigned $ \v -> v i * v j _ <- expansion (highRegisterSize @(BaseField c) @n @r) z return [z] - solveN :: MonadCircuit v a m => (v, v) -> ([v], [v]) -> (v, v) -> m [v] + solveN :: MonadCircuit i (BaseField c) m => (i, i) -> ([i], [i]) -> (i, i) -> m [i] solveN (i, j) (is, js) (i', j') = do let cs = fromList $ zip [0..] (i : is ++ [i']) ds = fromList $ zip [0..] (j : js ++ [j']) @@ -642,10 +538,10 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, Haskell.Ord (Rep i), Re -------------------------------------------------------------------------------- -fullAdder :: (Arithmetic a, MonadCircuit v a m) => Natural -> v -> v -> v -> m (v, v) +fullAdder :: (Arithmetic a, MonadCircuit i a m) => Natural -> i -> i -> i -> m (i, i) fullAdder r xk yk c = fullAdded xk yk c >>= splitExpansion r 1 -fullAdded :: MonadCircuit v a m => v -> v -> v -> m v +fullAdded :: MonadCircuit i a m => i -> i -> i -> m i fullAdded i j c = do k <- newAssigned (\v -> v i + v j) newAssigned (\v -> v k + v c) From 9f43853d8f57bb540f75d78f1f138492f0a4bf5d Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 20 Aug 2024 14:56:52 -0700 Subject: [PATCH 15/48] fix merge --- src/ZkFold/Base/Protocol/ARK/Plonk.hs | 13 ++++++------- tests/Tests/ByteString.hs | 2 ++ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index 01922d4f8..38690668c 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -32,9 +32,8 @@ import ZkFold.Base.Protocol.ARK.Plonk.Internal import ZkFold.Base.Protocol.ARK.Plonk.Relation (PlonkRelation (..), toPlonkRelation) import ZkFold.Base.Protocol.Commitment.KZG (com) import ZkFold.Base.Protocol.NonInteractiveProof -import ZkFold.Prelude (length, log2ceiling, (!)) -import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), ArithmeticCircuitTest (..), - witnessGenerator) +import ZkFold.Prelude (log2ceiling) +import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), ArithmeticCircuitTest (..)) import ZkFold.Symbolic.MonadCircuit (Arithmetic) {- @@ -64,15 +63,15 @@ instance (KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (Scalar Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary instance forall n l c1 c2 t . (KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1), - Witness (Plonk n l c1 c2 t) ~ (PlonkWitnessInput c1, PlonkProverSecret c1)) => Arbitrary (NonInteractiveProofTestData (Plonk n l c1 c2 t)) where + Witness (Plonk n l c1 c2 t) ~ (PlonkWitnessInput l c1, PlonkProverSecret c1)) => Arbitrary (NonInteractiveProofTestData (Plonk n l c1 c2 t)) where arbitrary = do - ArithmeticCircuitTest ac wi <- arbitrary :: Gen (ArithmeticCircuitTest (ScalarField c1) Par1) - let inputLen = length . acInput $ ac + ArithmeticCircuitTest ac wi <- arbitrary :: Gen (ArithmeticCircuitTest (ScalarField c1) (Vector l) Par1) + let inputLen = value @l vecPubInp <- genSubset (value @l) inputLen let (omega, k1, k2) = getParams $ value @n pl <- Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary secret <- arbitrary - return $ TestData pl (PlonkWitnessInput (witnessGenerator ac wi), secret) + return $ TestData pl (PlonkWitnessInput wi, secret) plonkPermutation :: forall n l c1 c2 t . (KnownNat n, FiniteField (ScalarField c1)) => Plonk n l c1 c2 t -> PlonkRelation n l (ScalarField c1) -> PlonkPermutation n c1 diff --git a/tests/Tests/ByteString.hs b/tests/Tests/ByteString.hs index 22fd7a18e..3b1912a20 100644 --- a/tests/Tests/ByteString.hs +++ b/tests/Tests/ByteString.hs @@ -66,6 +66,8 @@ testWords :: forall n wordSize p . KnownNat n => KnownNat wordSize + => Prime p + => KnownNat (Log2 (p - 1) + 1) => ToWords (ByteString n (ArithmeticCircuit (Zp p) U1)) (ByteString wordSize (ArithmeticCircuit (Zp p) U1)) => ToWords (ByteString n (Interpreter (Zp p))) (ByteString wordSize (Interpreter (Zp p))) => Spec From 416c40960f8d3f2b53a964329318820af1286b97 Mon Sep 17 00:00:00 2001 From: echatav Date: Tue, 20 Aug 2024 22:00:09 +0000 Subject: [PATCH 16/48] stylish-haskell auto-commit --- tests/Tests/Arithmetization.hs | 10 +++++----- tests/Tests/Arithmetization/Test4.hs | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/Tests/Arithmetization.hs b/tests/Tests/Arithmetization.hs index 223887499..db2dedf7a 100644 --- a/tests/Tests/Arithmetization.hs +++ b/tests/Tests/Arithmetization.hs @@ -3,8 +3,8 @@ module Tests.Arithmetization (specArithmetization) where -import Data.Functor.Rep (Representable (..)) -import GHC.Generics (Par1) +import Data.Functor.Rep (Representable (..)) +import GHC.Generics (Par1) import Prelude import Test.Hspec import Test.QuickCheck @@ -13,10 +13,10 @@ import Tests.Arithmetization.Test2 (specArithmetizatio import Tests.Arithmetization.Test3 (specArithmetization3) import Tests.Arithmetization.Test4 (specArithmetization4) -import ZkFold.Base.Algebra.Basic.Class (FromConstant, Scale, MultiplicativeMonoid) -import ZkFold.Base.Algebra.Basic.Field (Zp) +import ZkFold.Base.Algebra.Basic.Class (FromConstant, MultiplicativeMonoid, Scale) +import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 -import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.MonadCircuit (Arithmetic) diff --git a/tests/Tests/Arithmetization/Test4.hs b/tests/Tests/Arithmetization/Test4.hs index df34fb3df..b6072caff 100644 --- a/tests/Tests/Arithmetization/Test4.hs +++ b/tests/Tests/Arithmetization/Test4.hs @@ -54,7 +54,7 @@ testOnlyOutputZKP x ps targetValue = witnessInputs = V.singleton targetValue indexOutputBool = V.singleton $ case unPar1 $ acOutput ac of NewVar ix -> ix + 1 - InVar _ -> 1 + InVar _ -> 1 plonk = Plonk @32 omega k1 k2 indexOutputBool ac x setupP = setupProve @(PlonkBS N) plonk setupV = setupVerify @(PlonkBS N) plonk From 39b0ceffa50a62e99ad3499a94abdff13a917bf7 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Thu, 22 Aug 2024 09:22:56 -0700 Subject: [PATCH 17/48] merge things --- src/ZkFold/Base/Protocol/ARK/Plonk.hs | 2 +- src/ZkFold/Symbolic/Data/UInt.hs | 27 +++++------------------- tests/Tests/Arithmetization.hs | 2 +- tests/Tests/NonInteractiveProof/Plonk.hs | 2 +- 4 files changed, 8 insertions(+), 25 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index 0fcfca6f4..d4856871f 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -62,7 +62,7 @@ instance (KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (Scalar Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary instance forall n l c1 c2 t core . (KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1), - Witness (Plonk n l c1 c2 t) ~ (PlonkWitnessInput c1, PlonkProverSecret c1), NonInteractiveProof (Plonk n l c1 c2 t) core) => Arbitrary (NonInteractiveProofTestData (Plonk n l c1 c2 t) core) where + Witness (Plonk n l c1 c2 t) ~ (PlonkWitnessInput l c1, PlonkProverSecret c1), NonInteractiveProof (Plonk n l c1 c2 t) core) => Arbitrary (NonInteractiveProofTestData (Plonk n l c1 c2 t) core) where arbitrary = do ArithmeticCircuitTest ac wi <- arbitrary :: Gen (ArithmeticCircuitTest (ScalarField c1) (Vector l) Par1) let inputLen = value @l diff --git a/src/ZkFold/Symbolic/Data/UInt.hs b/src/ZkFold/Symbolic/Data/UInt.hs index 4cbf43c9c..ad7f525ff 100644 --- a/src/ZkFold/Symbolic/Data/UInt.hs +++ b/src/ZkFold/Symbolic/Data/UInt.hs @@ -142,23 +142,6 @@ instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Arbitrary (UInt n r c) return $ UInt $ embed $ V.unsafeToVector (lo <> [hi]) where toss b = fromConstant <$> chooseInteger (0, 2 ^ b - 1) - -instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Iso (ByteString n c) (UInt n r c) where - from (ByteString b) = UInt $ fromCircuitF b solve - where - solve :: forall i m. MonadCircuit i (BaseField c) m => Vector n i -> m (Vector (NumberOfRegisters (BaseField c) n r) i) - solve bits = do - let bsBits = V.fromVector bits - V.unsafeToVector . Haskell.reverse <$> fromBits (highRegisterSize @(BaseField c) @n @r) (registerSize @(BaseField c) @n @r) bsBits - -instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Iso (UInt n r c) (ByteString n c) where - from (UInt v) = ByteString $ fromCircuitF v solve - where - solve :: forall i m. MonadCircuit i (BaseField c) m => Vector (NumberOfRegisters (BaseField c) n r) i -> m (Vector n i) - solve ui = do - let regs = V.fromVector ui - V.unsafeToVector <$> toBits (Haskell.reverse regs) (highRegisterSize @(BaseField c) @n @r) (registerSize @(BaseField c) @n @r) - -- -------------------------------------------------------------------------------- instance @@ -228,20 +211,20 @@ instance in bool @(Bool c) (q', rs) (q' + fromConstant ((2 :: Natural) ^ i), rs - d) (rs >= d) instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Iso (ByteString n c) (UInt n r c) where - from (ByteString bits) = UInt $ symbolicF bits (\v -> naturalToVector @c @n @r $ vectorToNatural v (registerSize @a @n @r)) solve + from (ByteString bits) = UInt $ symbolicF bits (\v -> naturalToVector @c @n @r $ vectorToNatural v (registerSize @(BaseField c) @n @r)) solve where solve :: MonadCircuit v a m => Vector n v -> m (Vector (NumberOfRegisters a n r) v) solve xv = do let bsBits = V.fromVector xv - V.unsafeToVector . Haskell.reverse <$> fromBits (highRegisterSize @a @n @r) (registerSize @a @n @r) bsBits + V.unsafeToVector . Haskell.reverse <$> fromBits (highRegisterSize @(BaseField c) @n @r) (registerSize @(BaseField c) @n @r) bsBits instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Iso (UInt n r c) (ByteString n c) where - from (UInt ac) = ByteString $ symbolicF ac (\v -> V.unsafeToVector $ fromConstant <$> toBsBits (vectorToNatural v (registerSize @a @n @r)) (value @n)) solve + from (UInt ac) = ByteString $ symbolicF ac (\v -> V.unsafeToVector $ fromConstant <$> toBsBits (vectorToNatural v (registerSize @(BaseField c) @n @r)) (value @n)) solve where - solve :: MonadCircuit v a m => Vector (NumberOfRegisters a n r) v -> m (Vector n v) + solve :: MonadCircuit v (BaseField c) m => Vector (NumberOfRegisters (BaseField c) n r) v -> m (Vector n v) solve xv = do let regs = V.fromVector xv - V.unsafeToVector <$> toBits (Haskell.reverse regs) (highRegisterSize @a @n @r) (registerSize @a @n @r) + V.unsafeToVector <$> toBits (Haskell.reverse regs) (highRegisterSize @(BaseField c) @n @r) (registerSize @(BaseField c) @n @r) instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Ord (Bool c) (UInt n r c) where x <= y = y >= x diff --git a/tests/Tests/Arithmetization.hs b/tests/Tests/Arithmetization.hs index db2dedf7a..b08c62b85 100644 --- a/tests/Tests/Arithmetization.hs +++ b/tests/Tests/Arithmetization.hs @@ -13,7 +13,7 @@ import Tests.Arithmetization.Test2 (specArithmetizatio import Tests.Arithmetization.Test3 (specArithmetization3) import Tests.Arithmetization.Test4 (specArithmetization4) -import ZkFold.Base.Algebra.Basic.Class (FromConstant, MultiplicativeMonoid, Scale) +import ZkFold.Base.Algebra.Basic.Class (FromConstant, Scale) import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 import ZkFold.Base.Data.Vector (Vector) diff --git a/tests/Tests/NonInteractiveProof/Plonk.hs b/tests/Tests/NonInteractiveProof/Plonk.hs index 154ecef2a..12a112abd 100644 --- a/tests/Tests/NonInteractiveProof/Plonk.hs +++ b/tests/Tests/NonInteractiveProof/Plonk.hs @@ -57,7 +57,7 @@ propPlonkConstraintSatisfaction (TestData (Plonk _ _ _ iPub ac _) w) = in all ((== zero) . f) $ transpose [ql', qr', qo', qm', qc', toList $ fromPolyVec w1', toList $ fromPolyVec w2', toList $ fromPolyVec w3', toList $ fromPolyVec wPub] -propPlonkPolyIdentity :: forall n core . NonInteractiveProofTestData (PlonkBS n) core -> Bool +propPlonkPolyIdentity :: forall n core . KnownNat n => NonInteractiveProofTestData (PlonkBS n) core -> Bool propPlonkPolyIdentity (TestData plonk w) = let zH = polyVecZero @(ScalarField BLS12_381_G1) @PlonkPolyLengthBS @PlonkPolyExtendedLengthBS From 4a66d631c5b90091bd17b3da57d4137492998709 Mon Sep 17 00:00:00 2001 From: echatav Date: Thu, 22 Aug 2024 16:26:09 +0000 Subject: [PATCH 18/48] stylish-haskell auto-commit --- tests/Tests/Arithmetization/Test4.hs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/Tests/Arithmetization/Test4.hs b/tests/Tests/Arithmetization/Test4.hs index 835a97ab9..4c7f82ba4 100644 --- a/tests/Tests/Arithmetization/Test4.hs +++ b/tests/Tests/Arithmetization/Test4.hs @@ -11,14 +11,15 @@ import Test.Hspec (Spec, desc import Test.QuickCheck (Testable (..), withMaxSuccess, (==>)) import Tests.NonInteractiveProof.Plonk (PlonkBS) -import ZkFold.Base.Algebra.Basic.Class (FromConstant (..), one, zero, (+)) -import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1) -import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..)) -import qualified ZkFold.Base.Data.Vector as V -import ZkFold.Base.Protocol.ARK.Plonk (Plonk (..), PlonkInput (..), PlonkProverSecret, - PlonkWitnessInput (..), plonkVerifierInput) -import ZkFold.Base.Protocol.ARK.Plonk.Internal (getParams) -import ZkFold.Base.Protocol.NonInteractiveProof (CoreFunction, HaskellCore, NonInteractiveProof (..)) +import ZkFold.Base.Algebra.Basic.Class (FromConstant (..), one, zero, (+)) +import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1) +import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..)) +import qualified ZkFold.Base.Data.Vector as V +import ZkFold.Base.Protocol.ARK.Plonk (Plonk (..), PlonkInput (..), PlonkProverSecret, + PlonkWitnessInput (..), plonkVerifierInput) +import ZkFold.Base.Protocol.ARK.Plonk.Internal (getParams) +import ZkFold.Base.Protocol.NonInteractiveProof (CoreFunction, HaskellCore, + NonInteractiveProof (..)) import ZkFold.Symbolic.Class import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), compile, compileForceOne, eval) From f9057769195fd9366e9bbc5ab71faf1fdb92ab2a Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Thu, 22 Aug 2024 09:45:38 -0700 Subject: [PATCH 19/48] input variable generation --- src/ZkFold/Base/Algebra/Basic/Class.hs | 4 ++++ src/ZkFold/Symbolic/Cardano/Types/Address.hs | 4 +++- src/ZkFold/Symbolic/Cardano/Types/Output.hs | 3 +++ .../Symbolic/Compiler/ArithmeticCircuit.hs | 2 +- .../Compiler/ArithmeticCircuit/Combinators.hs | 16 ++++++++-------- .../Compiler/ArithmeticCircuit/Instance.hs | 4 ++-- .../Compiler/ArithmeticCircuit/Internal.hs | 15 ++++++++------- .../Compiler/ArithmeticCircuit/MonadBlueprint.hs | 11 +++++++---- tests/Tests/Arithmetization.hs | 5 +++-- 9 files changed, 39 insertions(+), 25 deletions(-) diff --git a/src/ZkFold/Base/Algebra/Basic/Class.hs b/src/ZkFold/Base/Algebra/Basic/Class.hs index 87a105106..2851d449e 100644 --- a/src/ZkFold/Base/Algebra/Basic/Class.hs +++ b/src/ZkFold/Base/Algebra/Basic/Class.hs @@ -9,6 +9,7 @@ module ZkFold.Base.Algebra.Basic.Class where import Data.Bool (bool) import Data.Foldable (foldl') import Data.Kind (Type) +import Data.Void (Void, absurd) import GHC.Natural (naturalFromInteger) import Prelude hiding (Num (..), div, divMod, length, mod, negate, product, replicate, sum, (/), (^)) @@ -43,6 +44,9 @@ class ToConstant a b where instance ToConstant a a where toConstant = id +instance ToConstant Void b where + toConstant = absurd + -------------------------------------------------------------------------------- {- | A class of types with a binary associative operation with a multiplicative diff --git a/src/ZkFold/Symbolic/Cardano/Types/Address.hs b/src/ZkFold/Symbolic/Cardano/Types/Address.hs index 82bf1d420..9e4cb2654 100644 --- a/src/ZkFold/Symbolic/Cardano/Types/Address.hs +++ b/src/ZkFold/Symbolic/Cardano/Types/Address.hs @@ -7,6 +7,8 @@ import Data.Functor.Rep (Representable (..)) import Prelude hiding (Bool, Eq, length, splitAt, (*), (+)) import qualified Prelude as Haskell +import ZkFold.Base.Algebra.Basic.Class (ToConstant) +import ZkFold.Base.Algebra.Basic.Number (Natural) import ZkFold.Base.Control.HApplicative (HApplicative) import ZkFold.Symbolic.Cardano.Types.Basic import ZkFold.Symbolic.Data.Class @@ -25,7 +27,7 @@ deriving instance (Haskell.Eq (ByteString 4 context), Haskell.Eq (ByteString 224 deriving instance HApplicative context => SymbolicData context (Address context) deriving via (Structural (Address (CtxCompilation i))) - instance (Ord (Rep i), Representable i) + instance (Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => Eq (Bool (CtxCompilation i)) (Address (CtxCompilation i)) addressType :: Address context -> AddressType context diff --git a/src/ZkFold/Symbolic/Cardano/Types/Output.hs b/src/ZkFold/Symbolic/Cardano/Types/Output.hs index 4fbf3fee3..d27380d4c 100644 --- a/src/ZkFold/Symbolic/Cardano/Types/Output.hs +++ b/src/ZkFold/Symbolic/Cardano/Types/Output.hs @@ -15,6 +15,7 @@ import Data.Functor.Rep (Representable (..)) import Prelude hiding (Bool, Eq, length, splitAt, (*), (+)) import qualified Prelude as Haskell +import ZkFold.Base.Algebra.Basic.Class (ToConstant) import ZkFold.Base.Algebra.Basic.Number import ZkFold.Symbolic.Cardano.Types.Address (Address) import ZkFold.Symbolic.Cardano.Types.Basic @@ -47,7 +48,9 @@ deriving via (Structural (Output tokens datum (CtxCompilation i))) , KnownNat tokens , KnownNat (TypeSize (CtxCompilation i) (Value tokens (CtxCompilation i))) , Ord (Rep i) + , Foldable i , Representable i + , ToConstant (Rep i) Natural ) => Eq (Bool (CtxCompilation i)) (Output tokens datum (CtxCompilation i)) txoAddress :: Output tokens datum context -> Address context diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index df2c0cd30..87cb3f412 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -64,7 +64,7 @@ optimize :: ArithmeticCircuit a i o -> ArithmeticCircuit a i o optimize = id -- | Desugars range constraints into polynomial constraints -desugarRanges :: (Arithmetic a, Ord (Rep i), Representable i) => ArithmeticCircuit a i o -> ArithmeticCircuit a i o +desugarRanges :: (Arithmetic a, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => ArithmeticCircuit a i o -> ArithmeticCircuit a i o desugarRanges c = let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ [(NewVar k, v) | (k,v) <- toList (acRange c)] in r' { acRange = mempty, acOutput = acOutput c } diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs index dc0cc8694..02b1c6e96 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs @@ -45,13 +45,13 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arit import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint import ZkFold.Symbolic.MonadCircuit -boolCheckC :: (Arithmetic a, Traversable f, Ord (Rep i), Representable i) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f +boolCheckC :: (Arithmetic a, Traversable f, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f -- ^ @boolCheckC r@ computes @r (r - 1)@ in one PLONK constraint. boolCheckC r = circuitF $ do is <- runCircuit r for is $ \i -> newAssigned (\x -> let xi = x i in xi * (xi - one)) -foldCircuit :: forall n i a. (Arithmetic a, Ord (Rep i), Representable i) => (forall v m . MonadBlueprint i v a m => v -> v -> m v) -> ArithmeticCircuit a i (Vector n) -> ArithmeticCircuit a i Par1 +foldCircuit :: forall n i a. (Arithmetic a, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => (forall v m . MonadBlueprint i v a m => v -> v -> m v) -> ArithmeticCircuit a i (Vector n) -> ArithmeticCircuit a i Par1 foldCircuit f c = circuit $ do outputs <- runCircuit c let (element, rest) = V.uncons outputs @@ -59,16 +59,16 @@ foldCircuit f c = circuit $ do -- | TODO: Think about circuits with multiple outputs -- -embed :: (Arithmetic a, Ord (Rep i), Representable i) => a -> ArithmeticCircuit a i Par1 +embed :: (Arithmetic a, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => a -> ArithmeticCircuit a i Par1 embed x = circuit $ newAssigned $ const (fromConstant x) -embedV :: (Arithmetic a, Traversable f, Ord (Rep i), Representable i) => f a -> ArithmeticCircuit a i f +embedV :: (Arithmetic a, Traversable f, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => f a -> ArithmeticCircuit a i f embedV v = circuitF $ for v $ \x -> newAssigned $ const (fromConstant x) embedVar :: forall a . a -> (forall i v m . MonadBlueprint i v a m => m v) embedVar x = newAssigned $ const (fromConstant x) -embedAll :: forall a i n . (Arithmetic a, KnownNat n, Ord (Rep i), Representable i) => a -> ArithmeticCircuit a i (Vector n) +embedAll :: forall a i n . (Arithmetic a, KnownNat n, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => a -> ArithmeticCircuit a i (Vector n) embedAll x = circuitF $ Vector <$> replicateM (fromIntegral $ value @n) (newAssigned $ const (fromConstant x)) expansion :: MonadCircuit v a m => Natural -> v -> m [v] @@ -125,15 +125,15 @@ desugarRange i b | c == zero = ($ j) * (one - ($ k)) | otherwise = one + ($ k) * (($ j) - one) -forceOne :: (Arithmetic a, Traversable f, Ord (Rep i), Representable i) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f +forceOne :: (Arithmetic a, Traversable f, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f forceOne r = circuitF $ do is' <- runCircuit r for is' $ \i -> constraint (\x -> x i - one) $> i -isZeroC :: (Arithmetic a, Z.Zip f, Traversable f, Ord (Rep i), Representable i) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f +isZeroC :: (Arithmetic a, Z.Zip f, Traversable f, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f isZeroC r = circuitF $ runCircuit r >>= fmap fst . runInvert -invertC :: (Arithmetic a, Z.Zip f, Traversable f, Ord (Rep i), Representable i) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f +invertC :: (Arithmetic a, Z.Zip f, Traversable f, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => ArithmeticCircuit a i f -> ArithmeticCircuit a i f invertC r = circuitF $ runCircuit r >>= fmap snd . runInvert runInvert :: (MonadCircuit v a m, Z.Zip f, Traversable f) => f v -> m (f v, f v) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs index 845513a7c..e765ec8fa 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs @@ -26,13 +26,13 @@ import ZkFold.Symbolic.Data.FieldElement (FieldEl ------------------------------------- Instances ------------------------------------- -instance (Arithmetic a, Arbitrary a, Arbitrary (Rep i), Haskell.Ord (Rep i), Representable i, Haskell.Foldable i) => Arbitrary (ArithmeticCircuit a i Par1) where +instance (Arithmetic a, Arbitrary a, Arbitrary (Rep i), Haskell.Ord (Rep i), Representable i, Haskell.Foldable i, ToConstant (Rep i) Natural) => Arbitrary (ArithmeticCircuit a i Par1) where arbitrary = do outVar <- InVar <$> arbitrary let ac = mempty {acOutput = Par1 outVar} fromFieldElement <$> arbitrary' (FieldElement ac) 10 -arbitrary' :: forall a i . (Arithmetic a, Arbitrary a, FromConstant a a, Haskell.Ord (Rep i), Representable i, Haskell.Foldable i) => FieldElement (ArithmeticCircuit a i) -> Natural -> Gen (FieldElement (ArithmeticCircuit a i)) +arbitrary' :: forall a i . (Arithmetic a, Arbitrary a, FromConstant a a, Haskell.Ord (Rep i), Representable i, Haskell.Foldable i, ToConstant (Rep i) Natural) => FieldElement (ArithmeticCircuit a i) -> Natural -> Gen (FieldElement (ArithmeticCircuit a i)) arbitrary' ac 0 = return ac arbitrary' ac iter = do let vars = getAllVars (fromFieldElement ac) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index 4c35d291c..8adadf360 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -111,13 +111,13 @@ instance (Eq a, MultiplicativeMonoid a) => Package (ArithmeticCircuit a i) where unpackWith f (behead -> (c, o)) = crown c <$> f o packWith f (unzipDefault . fmap behead -> (cs, os)) = crown (fold cs) (f os) -instance (Arithmetic a, Ord (Rep i), Representable i) => Symbolic (ArithmeticCircuit a i) where +instance (Arithmetic a, Ord (Rep i), Representable i, Foldable i, ToConstant (Rep i) Natural) => Symbolic (ArithmeticCircuit a i) where type BaseField (ArithmeticCircuit a i) = a symbolicF (behead -> (c, o)) _ f = uncurry (set #acOutput) (runState (f o) c) -------------------------------- MonadCircuit instance ------------------------------ -instance (Arithmetic a, Ord (Rep i), Representable i, o ~ U1) => MonadCircuit (Var i) a (State (ArithmeticCircuit a i o)) where +instance (Arithmetic a, Ord (Rep i), Representable i, Foldable i, o ~ U1, ToConstant (Rep i) Natural) => MonadCircuit (Var i) a (State (ArithmeticCircuit a i o)) where newRanged upperBound witness = do let s = sources @a witness b = fromConstant upperBound @@ -193,18 +193,19 @@ toField :: Arithmetic a => a -> VarField toField = toZp . fromConstant . fromBinary @Natural . castBits . binaryExpansion -- TODO: Remove the hardcoded constant. -toVar :: Arithmetic a => [Var i] -> Constraint a i -> Natural +toVar :: forall a i. (Arithmetic a, ToConstant (Rep i) Natural, Representable i, Foldable i) => [Var i] -> Constraint a i -> Natural toVar srcs c = force $ fromZp ex where + l = Haskell.fromIntegral (Haskell.length (tabulate @i (\_ -> error "can't reach"))) r = toZp 903489679376934896793395274328947923579382759823 :: VarField g = toZp 89175291725091202781479751781509570912743212325 :: VarField - varF (NewVar w) = w - varF (InVar _) = 0 + varF (NewVar w) = w + l + varF (InVar inV) = toConstant inV v = (+ r) . fromConstant . varF x = g ^ fromZp (evalPolynomial evalMonomial v $ mapCoeffs toField c) ex = foldr (\p y -> x ^ (varF p) + y) x srcs -newVariableWithSource :: Arithmetic a => [Var i] -> (Var i -> Constraint a i) -> State (ArithmeticCircuit a i U1) Natural +newVariableWithSource :: (Arithmetic a, ToConstant (Rep i) Natural, Representable i, Foldable i) => [Var i] -> (Var i -> Constraint a i) -> State (ArithmeticCircuit a i U1) Natural newVariableWithSource srcs con = toVar srcs . con . NewVar . fst <$> do zoom #acRNG $ get >>= traverse put . uniformR (0, order @VarField -! 1) @@ -222,7 +223,7 @@ type ConstraintMonomial = Mono Natural Natural type Constraint c i = Poly c (Var i) Natural -- | Adds a constraint to the arithmetic circuit. -addConstraint :: Arithmetic a => Constraint a i -> State (ArithmeticCircuit a i U1) () +addConstraint :: (Arithmetic a, Foldable i, Representable i, ToConstant (Rep i) Natural) => Constraint a i -> State (ArithmeticCircuit a i U1) () addConstraint c = zoom #acSystem . modify $ insert (toVar [] c) c rangeConstraint :: Natural -> a -> State (ArithmeticCircuit a i U1) () diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MonadBlueprint.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MonadBlueprint.hs index 1da89284e..4337d93f8 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MonadBlueprint.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MonadBlueprint.hs @@ -13,12 +13,15 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint ( import Control.Applicative (pure) import Control.Monad.State (State, modify, runState) +import Data.Foldable (Foldable) import Data.Functor (Functor, fmap, ($>), (<$>)) import Data.Functor.Rep (Representable (..)) import Data.Monoid (mempty, (<>)) import Data.Ord (Ord) import GHC.Generics (Par1, U1 (..)) +import ZkFold.Base.Algebra.Basic.Class (ToConstant) +import ZkFold.Base.Algebra.Basic.Number (Natural) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal hiding (constraint) import ZkFold.Symbolic.MonadCircuit @@ -28,22 +31,22 @@ class MonadCircuit v a m => MonadBlueprint i v a m | m -> i where -- | Adds the supplied circuit to the blueprint and returns its output variable. runCircuit :: ArithmeticCircuit a i o -> m (o v) -instance (Arithmetic a, Ord (Rep i), Representable i) => MonadBlueprint i (Var i) a (State (ArithmeticCircuit a i U1)) where +instance (Arithmetic a, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => MonadBlueprint i (Var i) a (State (ArithmeticCircuit a i U1)) where runCircuit r = modify (<> r {acOutput = U1}) $> acOutput r -circuit :: (Arithmetic a, Ord (Rep i), Representable i) => (forall v m . MonadBlueprint i v a m => m v) -> ArithmeticCircuit a i Par1 +circuit :: (Arithmetic a, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => (forall v m . MonadBlueprint i v a m => m v) -> ArithmeticCircuit a i Par1 -- ^ Builds a circuit from blueprint. A blueprint is a function which, given an -- arbitrary type of variables @i@ and a monad @m@ supporting the 'MonadBlueprint' -- API, computes the output variable of a future circuit. circuit b = circuitF (pure <$> b) -circuitF :: forall a i o . (Arithmetic a, Ord (Rep i), Representable i) => (forall v m . MonadBlueprint i v a m => m (o v)) -> ArithmeticCircuit a i o +circuitF :: forall a i o . (Arithmetic a, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => (forall v m . MonadBlueprint i v a m => m (o v)) -> ArithmeticCircuit a i o -- TODO: I should really rethink this... circuitF b = let (os, r) = runState b mempty in r { acOutput = os } -- TODO: kept for compatibility with @binaryExpansion@ only. Perhaps remove it in the future? -circuits :: forall a i o . (Arithmetic a, Functor o, Ord (Rep i), Representable i) => (forall v m . MonadBlueprint i v a m => m (o v)) -> o (ArithmeticCircuit a i Par1) +circuits :: forall a i o . (Arithmetic a, Functor o, Ord (Rep i), Foldable i, Representable i, ToConstant (Rep i) Natural) => (forall v m . MonadBlueprint i v a m => m (o v)) -> o (ArithmeticCircuit a i Par1) -- ^ Builds a collection of circuits from one blueprint. A blueprint is a function -- which, given an arbitrary type of variables @i@ and a monad @m@ supporting the -- 'MonadBlueprint' API, computes the collection of output variables of future circuits. diff --git a/tests/Tests/Arithmetization.hs b/tests/Tests/Arithmetization.hs index b08c62b85..c1a2e461f 100644 --- a/tests/Tests/Arithmetization.hs +++ b/tests/Tests/Arithmetization.hs @@ -13,7 +13,8 @@ import Tests.Arithmetization.Test2 (specArithmetizatio import Tests.Arithmetization.Test3 (specArithmetization3) import Tests.Arithmetization.Test4 (specArithmetization4) -import ZkFold.Base.Algebra.Basic.Class (FromConstant, Scale) +import ZkFold.Base.Algebra.Basic.Class (FromConstant, Scale, ToConstant) +import ZkFold.Base.Algebra.Basic.Number (Natural) import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 import ZkFold.Base.Data.Vector (Vector) @@ -27,7 +28,7 @@ propCircuitInvariance act@(ArithmeticCircuitTest ac wi) = v' = ac' `eval` wi' in v == v' -specArithmetization' :: forall a i . (FromConstant a a, Scale a a, Arithmetic a, Arbitrary a, Show a, Show (ArithmeticCircuitTest a i Par1), Arbitrary (Rep i), Ord (Rep i), Representable i, Traversable i) => IO () +specArithmetization' :: forall a i . (FromConstant a a, Scale a a, Arithmetic a, Arbitrary a, Show a, Show (ArithmeticCircuitTest a i Par1), Arbitrary (Rep i), Ord (Rep i), Representable i, Traversable i, ToConstant (Rep i) Natural) => IO () specArithmetization' = hspec $ do describe "Arithmetization specification" $ do describe "Variable mapping" $ do From b8954f57dc847366917227124289a35b63ae542e Mon Sep 17 00:00:00 2001 From: echatav Date: Thu, 22 Aug 2024 16:49:46 +0000 Subject: [PATCH 20/48] stylish-haskell auto-commit --- src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs | 4 ++-- tests/Tests/Arithmetization.hs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index 8adadf360..d08032c72 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -199,8 +199,8 @@ toVar srcs c = force $ fromZp ex l = Haskell.fromIntegral (Haskell.length (tabulate @i (\_ -> error "can't reach"))) r = toZp 903489679376934896793395274328947923579382759823 :: VarField g = toZp 89175291725091202781479751781509570912743212325 :: VarField - varF (NewVar w) = w + l - varF (InVar inV) = toConstant inV + varF (NewVar w) = w + l + varF (InVar inV) = toConstant inV v = (+ r) . fromConstant . varF x = g ^ fromZp (evalPolynomial evalMonomial v $ mapCoeffs toField c) ex = foldr (\p y -> x ^ (varF p) + y) x srcs diff --git a/tests/Tests/Arithmetization.hs b/tests/Tests/Arithmetization.hs index c1a2e461f..cf4bb552e 100644 --- a/tests/Tests/Arithmetization.hs +++ b/tests/Tests/Arithmetization.hs @@ -14,8 +14,8 @@ import Tests.Arithmetization.Test3 (specArithmetizatio import Tests.Arithmetization.Test4 (specArithmetization4) import ZkFold.Base.Algebra.Basic.Class (FromConstant, Scale, ToConstant) -import ZkFold.Base.Algebra.Basic.Number (Natural) import ZkFold.Base.Algebra.Basic.Field (Zp) +import ZkFold.Base.Algebra.Basic.Number (Natural) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler From dcd8244586ec15fa319d5fc85b7216f16942a087 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Thu, 22 Aug 2024 09:55:09 -0700 Subject: [PATCH 21/48] loeb trick --- .../Compiler/ArithmeticCircuit/Internal.hs | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index 8adadf360..c8b6ddffb 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -26,7 +26,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal ( ) where import Control.DeepSeq (NFData, force) -import Control.Monad.State (MonadState (..), State, gets, modify, runState) +import Control.Monad.State (MonadState (..), State, modify, runState) import Data.Aeson (FromJSON, FromJSONKey, ToJSON, ToJSONKey) import Data.Foldable (fold) import Data.Functor.Rep (Representable (..), fmapRep) @@ -60,7 +60,7 @@ data ArithmeticCircuit a i o = ArithmeticCircuit -- ^ The system of polynomial constraints acRange :: Map Natural a, -- ^ The range constraints [0, a] for the selected variables - acWitness :: Map Natural (i a -> a), + acWitness :: Map Natural (i a -> Map Natural a -> a), -- ^ The witness generation functions acVarOrder :: Map (Natural, Natural) Natural, -- ^ The order of variable assignments @@ -90,7 +90,10 @@ deriving instance NFData (Rep i) => NFData (Var i) witnessGenerator :: ArithmeticCircuit a i o -> i a -> Map Natural a witnessGenerator circuit inputs = - fmap ($ inputs) (acWitness circuit) + let + result = fmap (\k -> k inputs result) (acWitness circuit) + in + result ------------------------------ Symbolic compiler context ---------------------------- @@ -127,10 +130,10 @@ instance (Arithmetic a, Ord (Rep i), Representable i, Foldable i, o ~ U1, ToCons p i = b * var i * (var i - b) i <- addVariable =<< newVariableWithSource (S.toList s) p rangeConstraint i upperBound - currentWitness <- gets acWitness - assignment i $ \m -> witness $ \case + -- currentWitness <- gets acWitness + assignment i $ \m currentWitness -> witness $ \case InVar inV -> index m inV - NewVar newV -> (currentWitness ! newV) m + NewVar newV -> currentWitness ! newV return (NewVar i) newConstrained @@ -147,10 +150,9 @@ instance (Arithmetic a, Ord (Rep i), Representable i, Foldable i, o ~ U1, ToCons s = ws `S.difference` sources @a (`new` x) i <- addVariable =<< newVariableWithSource (S.toList s) (new var) constraint (`new` (NewVar i)) - currentWitness <- gets acWitness - assignment i $ \m -> witness $ \case + assignment i $ \m currentWitness -> witness $ \case InVar inV -> index m inV - NewVar newV -> (currentWitness ! newV) m + NewVar newV -> currentWitness ! newV return (NewVar i) constraint p = addConstraint (p var) @@ -231,7 +233,7 @@ rangeConstraint i b = zoom #acRange . modify $ insert i b -- | Adds a new variable assignment to the arithmetic circuit. -- TODO: forbid reassignment of variables -assignment :: Natural -> (i a -> a) -> State (ArithmeticCircuit a i U1) () +assignment :: Natural -> (i a -> Map Natural a -> a) -> State (ArithmeticCircuit a i U1) () assignment i f = zoom #acWitness . modify $ insert i f -- | Evaluates the arithmetic circuit with one output using the supplied input map. From 34900de75ed035e4240f4a2f3a2fd2682263e4a4 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Thu, 22 Aug 2024 09:56:06 -0700 Subject: [PATCH 22/48] Update Internal.hs --- src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs index 153120698..f2f8a938d 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs @@ -19,9 +19,6 @@ import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr) import ZkFold.Base.Data.Vector (Vector) import ZkFold.Prelude (log2ceiling, take) -log2 :: (Integral a, Integral b) => a -> b -log2 = ceiling @Double . logBase 2 . fromIntegral - getParams :: forall a . (Eq a, FiniteField a) => Natural -> (a, a, a) getParams n = findK' $ mkStdGen 0 where From 086cdfbce77ab048a5ae88d4ce9e3dce5794f99b Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Thu, 22 Aug 2024 10:01:00 -0700 Subject: [PATCH 23/48] tests arbitrary --- src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs | 8 +++----- tests/Tests/Arithmetization.hs | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs index a7dec2fc0..f26d54ed5 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs @@ -10,7 +10,6 @@ import Data.Functor.Rep (Represe import Data.Map hiding (drop, foldl, foldr, fromList, map, null, splitAt, take, toList) import qualified Data.Map as Map -import Data.Traversable (for) import GHC.Generics (Par1) import GHC.IsList (IsList (..)) import Prelude hiding (Num (..), drop, length, product, @@ -20,8 +19,7 @@ import Test.QuickCheck (Arbitra import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Polynomials.Multivariate import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (getAllVars) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Var (..), - acInput) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Var (..)) -- This module contains functions for mapping variables in arithmetic circuits. @@ -34,11 +32,11 @@ data ArithmeticCircuitTest a i o = ArithmeticCircuitTest instance (Show (ArithmeticCircuit a i o), Show a, Show (i a)) => Show (ArithmeticCircuitTest a i o) where show (ArithmeticCircuitTest ac wi) = show ac ++ ",\nwitnessInput: " ++ show wi -instance (Arithmetic a, Arbitrary a, Arbitrary (ArithmeticCircuit a i Par1), Traversable i, Representable i) => Arbitrary (ArithmeticCircuitTest a i Par1) where +instance (Arithmetic a, Arbitrary (i a), Arbitrary (ArithmeticCircuit a i Par1), Representable i) => Arbitrary (ArithmeticCircuitTest a i Par1) where arbitrary :: Gen (ArithmeticCircuitTest a i Par1) arbitrary = do ac <- arbitrary - wi <- for acInput $ \_ -> arbitrary + wi <- arbitrary return ArithmeticCircuitTest { arithmeticCircuit = ac , witnessInput = wi diff --git a/tests/Tests/Arithmetization.hs b/tests/Tests/Arithmetization.hs index cf4bb552e..b973a682c 100644 --- a/tests/Tests/Arithmetization.hs +++ b/tests/Tests/Arithmetization.hs @@ -28,7 +28,7 @@ propCircuitInvariance act@(ArithmeticCircuitTest ac wi) = v' = ac' `eval` wi' in v == v' -specArithmetization' :: forall a i . (FromConstant a a, Scale a a, Arithmetic a, Arbitrary a, Show a, Show (ArithmeticCircuitTest a i Par1), Arbitrary (Rep i), Ord (Rep i), Representable i, Traversable i, ToConstant (Rep i) Natural) => IO () +specArithmetization' :: forall a i . (FromConstant a a, Scale a a, Arithmetic a, Arbitrary a, Arbitrary (i a), Show a, Show (ArithmeticCircuitTest a i Par1), Arbitrary (Rep i), Ord (Rep i), Representable i, Traversable i, ToConstant (Rep i) Natural) => IO () specArithmetization' = hspec $ do describe "Arithmetization specification" $ do describe "Variable mapping" $ do From 1cc83e725d2501c16834df850ab1580e87b25fea Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Fri, 23 Aug 2024 10:44:11 -0700 Subject: [PATCH 24/48] Update Compiler.hs --- src/ZkFold/Symbolic/Compiler.hs | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/ZkFold/Symbolic/Compiler.hs b/src/ZkFold/Symbolic/Compiler.hs index 010a29e1f..f4e18bc2e 100644 --- a/src/ZkFold/Symbolic/Compiler.hs +++ b/src/ZkFold/Symbolic/Compiler.hs @@ -38,10 +38,12 @@ import ZkFold.Symbolic.Data.Class -- | Arithmetizes an argument by feeding an appropriate amount of inputs. solder :: - forall a c f . + forall a c f ni . ( Eq a , MultiplicativeMonoid a - , c ~ ArithmeticCircuit a + , KnownNat ni + , ni ~ TypeSize (Support f) + , c ~ ArithmeticCircuit a (Vector ni) , SymbolicData f , Context f ~ c , SymbolicData (Support f) @@ -56,16 +58,16 @@ solder f = pieces f (restore @(Support f) $ const inputC) -- | Compiles function `f` into an arithmetic circuit with all outputs equal to 1. compileForceOne :: - forall n a c f y . - ( n ~ TypeSize c (Support c f) - , c ~ ArithmeticCircuit a (Vector n) + forall a c f y ni . + ( KnownNat ni + , ni ~ TypeSize (Support f) + , c ~ ArithmeticCircuit a (Vector ni) , Arithmetic a , SymbolicData f , Context f ~ c , SymbolicData (Support f) , Context (Support f) ~ c , Support (Support f) ~ Proxy c - , KnownNat (TypeSize (Support f)) , SymbolicData y , Context y ~ c , Support y ~ Proxy c @@ -75,16 +77,17 @@ compileForceOne = restore . const . optimize . forceOne . solder @a -- | Compiles function `f` into an arithmetic circuit. compile :: - forall n a c f y . + forall a c f y ni . ( Eq a , MultiplicativeMonoid a - , c ~ ArithmeticCircuit a + , KnownNat ni + , ni ~ TypeSize (Support f) + , c ~ ArithmeticCircuit a (Vector ni) , SymbolicData f , Context f ~ c , SymbolicData (Support f) , Context (Support f) ~ c , Support (Support f) ~ Proxy c - , KnownNat (TypeSize (Support f)) , SymbolicData y , Context y ~ c , Support y ~ Proxy c @@ -94,18 +97,18 @@ compile = restore . const . optimize . solder @a -- | Compiles a function `f` into an arithmetic circuit. Writes the result to a file. compileIO :: - forall n a c f . + forall a c f ni . ( Eq a , MultiplicativeMonoid a - , n ~ TypeSize c (Support c f) - , c ~ ArithmeticCircuit a (Vector n) + , KnownNat ni + , ni ~ TypeSize (Support f) + , c ~ ArithmeticCircuit a (Vector ni) , ToJSON a , SymbolicData f , Context f ~ c , SymbolicData (Support f) , Context (Support f) ~ c , Support (Support f) ~ Proxy c - , KnownNat (TypeSize (Support f)) ) => FilePath -> f -> IO () compileIO scriptFile f = do let ac = optimize (solder @a f) :: c (Vector (TypeSize f)) From 208a0d837e9a056d9b24d1dc10ead649d366df38 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Fri, 23 Aug 2024 10:53:36 -0700 Subject: [PATCH 25/48] remove type aps --- examples/Examples/BatchTransfer.hs | 2 +- examples/Examples/ByteString.hs | 4 ++-- examples/Examples/Conditional.hs | 2 +- examples/Examples/Eq.hs | 2 +- examples/Examples/FFA.hs | 2 +- examples/Examples/Fibonacci.hs | 2 +- examples/Examples/LEQ.hs | 2 +- examples/Examples/MiMCHash.hs | 2 +- examples/Examples/ReverseList.hs | 2 +- examples/Examples/UInt.hs | 2 +- src/ZkFold/Symbolic/Cardano/Types/Output.hs | 10 +++++++--- src/ZkFold/Symbolic/Compiler.hs | 5 +---- tests/Tests/Arithmetization/Test1.hs | 2 +- tests/Tests/Arithmetization/Test2.hs | 2 +- tests/Tests/Arithmetization/Test3.hs | 2 +- tests/Tests/Arithmetization/Test4.hs | 10 +++++----- tests/Tests/Blake2b.hs | 2 +- 17 files changed, 28 insertions(+), 27 deletions(-) diff --git a/examples/Examples/BatchTransfer.hs b/examples/Examples/BatchTransfer.hs index 984468fc7..f7553f2e2 100644 --- a/examples/Examples/BatchTransfer.hs +++ b/examples/Examples/BatchTransfer.hs @@ -17,4 +17,4 @@ exampleBatchTransfer = do putStrLn "\nExample: Batch Transfer smart contract\n" - compileIO @151810 @F file (batchTransfer @(CtxCompilation (Vector 151810))) + compileIO @F file (batchTransfer @(CtxCompilation (Vector 151810))) diff --git a/examples/Examples/ByteString.hs b/examples/Examples/ByteString.hs index 630d81091..18438b5e2 100644 --- a/examples/Examples/ByteString.hs +++ b/examples/Examples/ByteString.hs @@ -36,7 +36,7 @@ exampleByteStringExtend = do let k = show $ natVal (Proxy @k) putStrLn $ "\nExample: Extending a bytestring of length " ++ n ++ " to length " ++ k let file = "compiled_scripts/bytestring" ++ n ++ "_to_" ++ k ++ ".json" - compileIO @n @(Zp BLS12_381_Scalar) file $ extend @(ByteString n (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector n))) @(ByteString k (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector n))) + compileIO @(Zp BLS12_381_Scalar) file $ extend @(ByteString n (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector n))) @(ByteString k (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector n))) type Binary a = a -> a -> a @@ -47,4 +47,4 @@ makeExample shortName name op = do let n = show $ natVal (Proxy @n) putStrLn $ "\nExample: (" ++ shortName ++ ") operation on ByteString" ++ n let file = "compiled_scripts/bytestring" ++ n ++ "_" ++ name ++ ".json" - compileIO @(n+n) @(Zp BLS12_381_Scalar) file op + compileIO @(Zp BLS12_381_Scalar) file op diff --git a/examples/Examples/Conditional.hs b/examples/Examples/Conditional.hs index 0c746cf1a..2909666df 100644 --- a/examples/Examples/Conditional.hs +++ b/examples/Examples/Conditional.hs @@ -22,4 +22,4 @@ exampleConditional = do putStrLn "\nExample: conditional\n" - compileIO @3 @F file (bool @B @(A Par1)) + compileIO @F file (bool @B @(A Par1)) diff --git a/examples/Examples/Eq.hs b/examples/Examples/Eq.hs index 848a434da..2cfd5dfe7 100644 --- a/examples/Examples/Eq.hs +++ b/examples/Examples/Eq.hs @@ -24,4 +24,4 @@ exampleEq = do putStrLn "\nExample: (==) operation\n" - compileIO @2 @(Zp BLS12_381_Scalar) file (eq @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 2))) + compileIO @(Zp BLS12_381_Scalar) file (eq @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 2))) diff --git a/examples/Examples/FFA.hs b/examples/Examples/FFA.hs index e475e9d19..02ced1734 100644 --- a/examples/Examples/FFA.hs +++ b/examples/Examples/FFA.hs @@ -40,4 +40,4 @@ makeExample shortName name op = do let p = show $ value @p putStrLn $ "\nExample: (" ++ shortName ++ ") operation on FFA " ++ p let file = "compiled_scripts/ffa_" ++ p ++ "_" ++ name ++ ".json" - compileIO @14 @(Zp BLS12_381_Scalar) file op + compileIO @(Zp BLS12_381_Scalar) file op diff --git a/examples/Examples/Fibonacci.hs b/examples/Examples/Fibonacci.hs index 7cbc96e60..cdbcd8a3a 100644 --- a/examples/Examples/Fibonacci.hs +++ b/examples/Examples/Fibonacci.hs @@ -30,4 +30,4 @@ exampleFibonacci = do putStrLn "\nExample: Fibonacci index function\n" - compileIO @1 @(Zp BLS12_381_Scalar) file (fibonacciIndex @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 1)) nMax) + compileIO @(Zp BLS12_381_Scalar) file (fibonacciIndex @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 1)) nMax) diff --git a/examples/Examples/LEQ.hs b/examples/Examples/LEQ.hs index 3535a532d..d2bdde104 100644 --- a/examples/Examples/LEQ.hs +++ b/examples/Examples/LEQ.hs @@ -24,4 +24,4 @@ exampleLEQ = do putStrLn "\nExample: (<=) operation\n" - compileIO @2 @(Zp BLS12_381_Scalar) file (leq @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 2))) + compileIO @(Zp BLS12_381_Scalar) file (leq @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 2))) diff --git a/examples/Examples/MiMCHash.hs b/examples/Examples/MiMCHash.hs index 3f8b5ceb1..f616ca02d 100644 --- a/examples/Examples/MiMCHash.hs +++ b/examples/Examples/MiMCHash.hs @@ -22,4 +22,4 @@ exampleMiMC = do putStrLn "\nExample: MiMC hash function\n" - compileIO @2 @F file (mimcHash2 @F @(FieldElement (ArithmeticCircuit F (Vector 2))) mimcConstants zero) + compileIO @F file (mimcHash2 @F @(FieldElement (ArithmeticCircuit F (Vector 2))) mimcConstants zero) diff --git a/examples/Examples/ReverseList.hs b/examples/Examples/ReverseList.hs index 4b58de22b..0dc10ce6d 100644 --- a/examples/Examples/ReverseList.hs +++ b/examples/Examples/ReverseList.hs @@ -22,4 +22,4 @@ exampleReverseList = do putStrLn "\nExample: Reverse List function\n" - compileIO @32 @(Zp BLS12_381_Scalar) file (reverseList @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 32) Par1) @32) + compileIO @(Zp BLS12_381_Scalar) file (reverseList @(ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 32) Par1) @32) diff --git a/examples/Examples/UInt.hs b/examples/Examples/UInt.hs index 679f060ce..39ea76aa4 100644 --- a/examples/Examples/UInt.hs +++ b/examples/Examples/UInt.hs @@ -80,4 +80,4 @@ makeExample shortName name op = do let n = show $ natVal (Proxy @n) putStrLn $ "\nExample: (" ++ shortName ++ ") operation on UInt" ++ n let file = "compiled_scripts/uint" ++ n ++ "_" ++ name ++ ".json" - compileIO @(Num n + Num n) @(Zp BLS12_381_Scalar) file op + compileIO @(Zp BLS12_381_Scalar) file op diff --git a/src/ZkFold/Symbolic/Cardano/Types/Output.hs b/src/ZkFold/Symbolic/Cardano/Types/Output.hs index fc1e5d1a1..3d5435b9b 100644 --- a/src/ZkFold/Symbolic/Cardano/Types/Output.hs +++ b/src/ZkFold/Symbolic/Cardano/Types/Output.hs @@ -43,11 +43,15 @@ deriving instance deriving via (Structural (Output tokens datum (CtxCompilation i))) instance - ( ts ~ TypeSize (Output tokens datum CtxCompilation) + ( ts ~ TypeSize (Output tokens datum (CtxCompilation i)) , 1 <= ts , KnownNat tokens - , KnownNat (TypeSize (Value tokens CtxCompilation)) - ) => Eq (Bool CtxCompilation) (Output tokens datum CtxCompilation) + , KnownNat (TypeSize (Value tokens (CtxCompilation i))) + , Ord (Rep i) + , Representable i + , Foldable i + , ToConstant (Rep i) Natural + ) => Eq (Bool (CtxCompilation i)) (Output tokens datum (CtxCompilation i)) txoAddress :: Output tokens datum context -> Address context txoAddress (Output (addr, _)) = addr diff --git a/src/ZkFold/Symbolic/Compiler.hs b/src/ZkFold/Symbolic/Compiler.hs index f4e18bc2e..30db65378 100644 --- a/src/ZkFold/Symbolic/Compiler.hs +++ b/src/ZkFold/Symbolic/Compiler.hs @@ -49,14 +49,11 @@ solder :: , SymbolicData (Support f) , Context (Support f) ~ c , Support (Support f) ~ Proxy c - , KnownNat (TypeSize (Support f)) ) => f -> c (Vector (TypeSize f)) solder f = pieces f (restore @(Support f) $ const inputC) where - inputList = [1..(typeSize @(Support f))] - inputC = mempty { acInput = inputList, acOutput = unsafeToVector inputList } + inputC = mempty { acOutput = acInput } --- | Compiles function `f` into an arithmetic circuit with all outputs equal to 1. compileForceOne :: forall a c f y ni . ( KnownNat ni diff --git a/tests/Tests/Arithmetization/Test1.hs b/tests/Tests/Arithmetization/Test1.hs index 11ad65cc1..f99663cd9 100644 --- a/tests/Tests/Arithmetization/Test1.hs +++ b/tests/Tests/Arithmetization/Test1.hs @@ -38,5 +38,5 @@ specArithmetization1 :: forall a . (FromConstant a a, Arithmetic a, Arbitrary a, specArithmetization1 = do describe "Arithmetization test 1" $ do it "should pass" $ do - let ac = compile @2 @a (testFunc @(ArithmeticCircuit a (Vector 2))) :: ArithmeticCircuit a (Vector 2) Par1 + let ac = compile @a (testFunc @(ArithmeticCircuit a (Vector 2))) :: ArithmeticCircuit a (Vector 2) Par1 property $ \x y -> testResult ac x y diff --git a/tests/Tests/Arithmetization/Test2.hs b/tests/Tests/Arithmetization/Test2.hs index 7bfaa6687..13782ae8c 100644 --- a/tests/Tests/Arithmetization/Test2.hs +++ b/tests/Tests/Arithmetization/Test2.hs @@ -26,7 +26,7 @@ tautology x y = (x /= y) || (x == y) testTautology :: forall a . Arithmetic a => a -> a -> Haskell.Bool testTautology x y = - let Bool (ac :: ArithmeticCircuit a (Vector 2) Par1) = compile @2 @a (tautology @(ArithmeticCircuit a (Vector 2))) + let Bool (ac :: ArithmeticCircuit a (Vector 2) Par1) = compile @a (tautology @(ArithmeticCircuit a (Vector 2))) b = unPar1 (eval ac (unsafeToVector [x, y])) in b Haskell.== one diff --git a/tests/Tests/Arithmetization/Test3.hs b/tests/Tests/Arithmetization/Test3.hs index ac5da8775..a693e67c4 100644 --- a/tests/Tests/Arithmetization/Test3.hs +++ b/tests/Tests/Arithmetization/Test3.hs @@ -27,5 +27,5 @@ specArithmetization3 :: Spec specArithmetization3 = do describe "Arithmetization test 3" $ do it "should pass" $ do - let Bool r = compile @2 @(Zp 97) (testFunc @R) :: Bool R + let Bool r = compile @(Zp 97) (testFunc @R) :: Bool R Bool (Interpreter (eval r (unsafeToVector [3, 5]))) `shouldBe` testFunc (fromConstant (3 :: Natural)) (fromConstant (5 :: Natural)) diff --git a/tests/Tests/Arithmetization/Test4.hs b/tests/Tests/Arithmetization/Test4.hs index 4c7f82ba4..2bc4ac0a6 100644 --- a/tests/Tests/Arithmetization/Test4.hs +++ b/tests/Tests/Arithmetization/Test4.hs @@ -38,19 +38,19 @@ lockedByTxId targetValue inputValue = inputValue == fromConstant targetValue testSameValue :: F -> Haskell.Bool testSameValue targetValue = - let Bool ac = compile @1 @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) + let Bool ac = compile @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) b = unPar1 (eval ac (V.singleton targetValue)) in b Haskell.== one testDifferentValue :: F -> F -> Haskell.Bool testDifferentValue targetValue otherValue = - let Bool ac = compile @1 @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) + let Bool ac = compile @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) b = unPar1 (eval ac (V.singleton otherValue)) in b Haskell.== zero testOnlyOutputZKP :: forall core . (CoreFunction C core) => F -> PlonkProverSecret C -> F -> Haskell.Bool testOnlyOutputZKP x ps targetValue = - let Bool ac = compile @1 @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) + let Bool ac = compile @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) (omega, k1, k2) = getParams 32 witnessInputs = V.singleton targetValue @@ -70,7 +70,7 @@ testOnlyOutputZKP x ps targetValue = testSafeOneInputZKP :: forall core . (CoreFunction C core) => F -> PlonkProverSecret C -> F -> Haskell.Bool testSafeOneInputZKP x ps targetValue = - let Bool ac = compileForceOne @1 @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) + let Bool ac = compileForceOne @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) (omega, k1, k2) = getParams 32 witnessInputs = V.singleton targetValue @@ -87,7 +87,7 @@ testSafeOneInputZKP x ps targetValue = testAttackSafeOneInputZKP :: forall core . (CoreFunction C core) => F -> PlonkProverSecret C -> F -> Haskell.Bool testAttackSafeOneInputZKP x ps targetValue = - let Bool ac = compileForceOne @1 @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) + let Bool ac = compileForceOne @F (lockedByTxId @F @(ArithmeticCircuit F (V.Vector 1)) targetValue) :: Bool (ArithmeticCircuit F (V.Vector 1)) (omega, k1, k2) = getParams 32 witnessInputs = V.singleton (targetValue + 1) diff --git a/tests/Tests/Blake2b.hs b/tests/Tests/Blake2b.hs index a842dc1e0..d02a4ab4d 100644 --- a/tests/Tests/Blake2b.hs +++ b/tests/Tests/Blake2b.hs @@ -42,7 +42,7 @@ blake2bSimple = blake2bAC :: Spec blake2bAC = - let bs = compile (blake2b_512 @8) :: ByteString 512 (ArithmeticCircuit (Zp BLS12_381_Scalar)) + let bs = compile blake2b_512 :: ByteString 512 (ArithmeticCircuit (Zp BLS12_381_Scalar) (Vector 8)) ac = pieces bs Proxy in it "simple test with cardano-crypto " $ acSizeN ac == 564239 From a82eafc8ef480a8ba92f88444ec337bc4e5cccff Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Fri, 23 Aug 2024 11:15:16 -0700 Subject: [PATCH 26/48] Update BatchTransfer.hs --- examples/Examples/BatchTransfer.hs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/Examples/BatchTransfer.hs b/examples/Examples/BatchTransfer.hs index f7553f2e2..4d3d9f05c 100644 --- a/examples/Examples/BatchTransfer.hs +++ b/examples/Examples/BatchTransfer.hs @@ -6,7 +6,6 @@ module Examples.BatchTransfer (exampleBatchTransfer) where import Prelude hiding (Eq (..), Num (..), any, not, (!!), (/), (^), (||)) -import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Cardano.Contracts.BatchTransfer (batchTransfer) import ZkFold.Symbolic.Cardano.Types import ZkFold.Symbolic.Compiler (compileIO) @@ -17,4 +16,4 @@ exampleBatchTransfer = do putStrLn "\nExample: Batch Transfer smart contract\n" - compileIO @F file (batchTransfer @(CtxCompilation (Vector 151810))) + compileIO @F file (batchTransfer) From 802e08c30b64d7a2401fbdc8bc705802dffc9bd1 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Fri, 23 Aug 2024 11:18:39 -0700 Subject: [PATCH 27/48] Update BatchTransfer.hs --- examples/Examples/BatchTransfer.hs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/Examples/BatchTransfer.hs b/examples/Examples/BatchTransfer.hs index 4d3d9f05c..89072a160 100644 --- a/examples/Examples/BatchTransfer.hs +++ b/examples/Examples/BatchTransfer.hs @@ -7,7 +7,6 @@ import Prelude hiding (Eq (..) (||)) import ZkFold.Symbolic.Cardano.Contracts.BatchTransfer (batchTransfer) -import ZkFold.Symbolic.Cardano.Types import ZkFold.Symbolic.Compiler (compileIO) exampleBatchTransfer :: IO () @@ -16,4 +15,4 @@ exampleBatchTransfer = do putStrLn "\nExample: Batch Transfer smart contract\n" - compileIO @F file (batchTransfer) + compileIO file batchTransfer From a073cf8da6d0908a5deb75ec5f2388a47b4ce4a3 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Fri, 23 Aug 2024 11:21:40 -0700 Subject: [PATCH 28/48] Update BatchTransfer.hs --- examples/Examples/BatchTransfer.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/Examples/BatchTransfer.hs b/examples/Examples/BatchTransfer.hs index 89072a160..a96af8bee 100644 --- a/examples/Examples/BatchTransfer.hs +++ b/examples/Examples/BatchTransfer.hs @@ -7,6 +7,7 @@ import Prelude hiding (Eq (..) (||)) import ZkFold.Symbolic.Cardano.Contracts.BatchTransfer (batchTransfer) +import ZkFold.Symbolic.Cardano.Types import ZkFold.Symbolic.Compiler (compileIO) exampleBatchTransfer :: IO () @@ -15,4 +16,4 @@ exampleBatchTransfer = do putStrLn "\nExample: Batch Transfer smart contract\n" - compileIO file batchTransfer + compileIO @F file batchTransfer From 4a8b8c5398708ef92a4e4ae47d97c79ed26d7a44 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Fri, 23 Aug 2024 12:21:58 -0700 Subject: [PATCH 29/48] Update Compiler.hs --- src/ZkFold/Symbolic/Compiler.hs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ZkFold/Symbolic/Compiler.hs b/src/ZkFold/Symbolic/Compiler.hs index 30db65378..71eddbac5 100644 --- a/src/ZkFold/Symbolic/Compiler.hs +++ b/src/ZkFold/Symbolic/Compiler.hs @@ -54,6 +54,7 @@ solder f = pieces f (restore @(Support f) $ const inputC) where inputC = mempty { acOutput = acInput } +-- | Compiles function `f` into an arithmetic circuit with all outputs equal to 1. compileForceOne :: forall a c f y ni . ( KnownNat ni From d058683f002f8ce2caceba2b6424499c103b965e Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Mon, 26 Aug 2024 09:39:00 -0700 Subject: [PATCH 30/48] Update Examples.hs --- examples/ZkFold/Symbolic/Examples.hs | 46 ++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/examples/ZkFold/Symbolic/Examples.hs b/examples/ZkFold/Symbolic/Examples.hs index f77e3abfe..9159d95d1 100644 --- a/examples/ZkFold/Symbolic/Examples.hs +++ b/examples/ZkFold/Symbolic/Examples.hs @@ -2,7 +2,6 @@ module ZkFold.Symbolic.Examples (ExampleOutput (..), examples) where -import Control.DeepSeq (NFData) import Data.Function (const, ($), (.)) import Data.Proxy (Proxy) import Data.String (String) @@ -19,7 +18,7 @@ import Examples.ReverseList (exampleReverseList import Examples.UInt import ZkFold.Base.Algebra.Basic.Field (Zp) -import ZkFold.Base.Algebra.Basic.Number (KnownNat, Natural) +import ZkFold.Base.Algebra.Basic.Number (KnownNat) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler (ArithmeticCircuit, compile) @@ -30,20 +29,41 @@ import ZkFold.Symbolic.Data.Combinators (RegisterSize (Auto type C = ArithmeticCircuit (Zp BLS12_381_Scalar) data ExampleOutput where - ExampleOutput :: forall o. NFData (o Natural) => (() -> C o) -> ExampleOutput + ExampleOutput + :: forall i_n o_n. (() -> C (Vector i_n) (Vector o_n)) -> ExampleOutput exampleOutput :: - forall n f. - ( SymbolicData f - , Context f ~ C - , TypeSize f ~ n + forall i_n o_n c f. + ( KnownNat i_n + , i_n ~ TypeSize (Support f) + , SymbolicData f + , c ~ C (Vector i_n) + , Context f ~ c + , TypeSize f ~ o_n , SymbolicData (Support f) - , Context (Support f) ~ C - , Support (Support f) ~ Proxy C - , KnownNat (TypeSize (Support f)) + , Context (Support f) ~ c + , Support (Support f) ~ Proxy c ) => f -> ExampleOutput -exampleOutput = ExampleOutput @(Vector n) . const . compile - +exampleOutput = ExampleOutput @i_n @o_n . const . compile +{- +compile :: + forall a c f y ni . + ( Eq a + , MultiplicativeMonoid a + , KnownNat ni + , ni ~ TypeSize (Support f) + , c ~ ArithmeticCircuit a (Vector ni) + , SymbolicData f + , Context f ~ c + , SymbolicData (Support f) + , Context (Support f) ~ c + , Support (Support f) ~ Proxy c + , SymbolicData y + , Context y ~ c + , Support y ~ Proxy c + , TypeSize f ~ TypeSize y + ) => f -> y +-} examples :: [(String, ExampleOutput)] examples = [ ("Eq", exampleOutput exampleEq) @@ -57,7 +77,7 @@ examples = , ("UInt.StrictAdd.256.Auto", exampleOutput $ exampleUIntStrictAdd @256 @Auto) , ("UInt.StrictMul.512.Auto", exampleOutput $ exampleUIntStrictMul @512 @Auto) , ("UInt.DivMod.32.Auto", exampleOutput $ exampleUIntDivMod @32 @Auto) - , ("Reverse.32.3000", exampleOutput $ exampleReverseList @32 @(ByteString 3000 C)) + , ("Reverse.32.3000", exampleOutput $ exampleReverseList @32 @(ByteString 3000 (C (Vector _)))) , ("Fibonacci.100", exampleOutput $ exampleFibonacci 100) , ("MiMCHash", exampleOutput exampleMiMC) , ("SHA256.32", exampleOutput $ exampleSHA @32) From efa0e4af01dccf594544d70106aa593bfa80a206 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Mon, 26 Aug 2024 09:41:31 -0700 Subject: [PATCH 31/48] Update Examples.hs --- examples/ZkFold/Symbolic/Examples.hs | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/examples/ZkFold/Symbolic/Examples.hs b/examples/ZkFold/Symbolic/Examples.hs index 9159d95d1..3138ee01c 100644 --- a/examples/ZkFold/Symbolic/Examples.hs +++ b/examples/ZkFold/Symbolic/Examples.hs @@ -45,25 +45,7 @@ exampleOutput :: , Support (Support f) ~ Proxy c ) => f -> ExampleOutput exampleOutput = ExampleOutput @i_n @o_n . const . compile -{- -compile :: - forall a c f y ni . - ( Eq a - , MultiplicativeMonoid a - , KnownNat ni - , ni ~ TypeSize (Support f) - , c ~ ArithmeticCircuit a (Vector ni) - , SymbolicData f - , Context f ~ c - , SymbolicData (Support f) - , Context (Support f) ~ c - , Support (Support f) ~ Proxy c - , SymbolicData y - , Context y ~ c - , Support y ~ Proxy c - , TypeSize f ~ TypeSize y - ) => f -> y --} + examples :: [(String, ExampleOutput)] examples = [ ("Eq", exampleOutput exampleEq) From d16651916bdc373874d99c741e7c9d744af2f77e Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Mon, 26 Aug 2024 10:15:07 -0700 Subject: [PATCH 32/48] fix benchmarks --- bench/BenchCompiler.hs | 24 +++++++++++++++--------- examples/ZkFold/Symbolic/Examples.hs | 4 +++- zkfold-base.cabal | 1 + 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/bench/BenchCompiler.hs b/bench/BenchCompiler.hs index 93eb448f9..f15516e62 100644 --- a/bench/BenchCompiler.hs +++ b/bench/BenchCompiler.hs @@ -1,39 +1,45 @@ +{-# LANGUAGE TypeOperators #-} module Main where import Control.DeepSeq (NFData, force) import Control.Monad (return) import Data.ByteString.Lazy (ByteString) -import Data.Function (($)) -import Data.Map (Map, fromAscList) +import Data.Function (($), const) +import Data.Functor.Rep (Representable(..)) import Data.Semigroup ((<>)) import Data.String (String, fromString) -import Numeric.Natural (Natural) +import Data.Type.Equality (type (~)) +import GHC.TypeNats (KnownNat) import System.IO (IO) import Test.Tasty.Bench import Test.Tasty.Golden (goldenVsString) import Text.Show (show) import ZkFold.Base.Algebra.Basic.Class (AdditiveMonoid, zero) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Examples +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Var) -inputMap :: AdditiveMonoid a => ArithmeticCircuit a o -> Map Natural a -inputMap circuit = fromAscList [ (i, zero) | i <- acInput circuit ] +-- inputMap :: AdditiveMonoid a => ArithmeticCircuit a i o -> Map Natural a +-- inputMap circuit = fromAscList [ (i, zero) | i <- acInput circuit ] -metrics :: String -> ArithmeticCircuit a o -> ByteString +metrics :: String -> ArithmeticCircuit a i o -> ByteString metrics name circuit = fromString name <> "\nNumber of constraints: " <> fromString (show $ acSizeN circuit) <> "\nNumber of variables: " <> fromString (show $ acSizeM circuit) <> "\nNumber of range lookups: " <> fromString (show $ acSizeR circuit) + benchmark :: - (NFData a, AdditiveMonoid a, NFData (o Natural)) => - String -> (() -> ArithmeticCircuit a o) -> Benchmark + (NFData a, AdditiveMonoid a, NFData (o (Var i)), NFData (Rep i), i ~ Vector n_i, KnownNat n_i) => + String -> (() -> ArithmeticCircuit a i o) -> Benchmark benchmark name circuit = bgroup name [ bench "compilation" $ nf circuit () , env (return $ force $ circuit ()) $ \c -> - let input = inputMap c + let + input = tabulate (const zero) path = "stats/" <> name in bgroup "on compilation" [ bench "evaluation" $ nf (witnessGenerator c) input diff --git a/examples/ZkFold/Symbolic/Examples.hs b/examples/ZkFold/Symbolic/Examples.hs index 3138ee01c..218bbb7ad 100644 --- a/examples/ZkFold/Symbolic/Examples.hs +++ b/examples/ZkFold/Symbolic/Examples.hs @@ -30,7 +30,9 @@ type C = ArithmeticCircuit (Zp BLS12_381_Scalar) data ExampleOutput where ExampleOutput - :: forall i_n o_n. (() -> C (Vector i_n) (Vector o_n)) -> ExampleOutput + :: forall i_n o_n. KnownNat i_n + => (() -> C (Vector i_n) (Vector o_n)) + -> ExampleOutput exampleOutput :: forall i_n o_n c f. diff --git a/zkfold-base.cabal b/zkfold-base.cabal index 587ae1580..5692c2370 100644 --- a/zkfold-base.cabal +++ b/zkfold-base.cabal @@ -334,6 +334,7 @@ benchmark compiler-benchmark -fprof-cafs -O3 build-depends: + adjunctions, base, bytestring, containers, From 82366a09a0b6992cabdda7e979b3cf3ae975bb2a Mon Sep 17 00:00:00 2001 From: echatav Date: Mon, 26 Aug 2024 17:18:24 +0000 Subject: [PATCH 33/48] stylish-haskell auto-commit --- bench/BenchCompiler.hs | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/bench/BenchCompiler.hs b/bench/BenchCompiler.hs index f15516e62..22260dd5a 100644 --- a/bench/BenchCompiler.hs +++ b/bench/BenchCompiler.hs @@ -1,25 +1,25 @@ {-# LANGUAGE TypeOperators #-} module Main where -import Control.DeepSeq (NFData, force) -import Control.Monad (return) -import Data.ByteString.Lazy (ByteString) -import Data.Function (($), const) -import Data.Functor.Rep (Representable(..)) -import Data.Semigroup ((<>)) -import Data.String (String, fromString) -import Data.Type.Equality (type (~)) -import GHC.TypeNats (KnownNat) -import System.IO (IO) +import Control.DeepSeq (NFData, force) +import Control.Monad (return) +import Data.ByteString.Lazy (ByteString) +import Data.Function (const, ($)) +import Data.Functor.Rep (Representable (..)) +import Data.Semigroup ((<>)) +import Data.String (String, fromString) +import Data.Type.Equality (type (~)) +import GHC.TypeNats (KnownNat) +import System.IO (IO) import Test.Tasty.Bench -import Test.Tasty.Golden (goldenVsString) -import Text.Show (show) +import Test.Tasty.Golden (goldenVsString) +import Text.Show (show) -import ZkFold.Base.Algebra.Basic.Class (AdditiveMonoid, zero) -import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Base.Algebra.Basic.Class (AdditiveMonoid, zero) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler -import ZkFold.Symbolic.Examples import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Var) +import ZkFold.Symbolic.Examples -- inputMap :: AdditiveMonoid a => ArithmeticCircuit a i o -> Map Natural a -- inputMap circuit = fromAscList [ (i, zero) | i <- acInput circuit ] From 9dd46a2bea43a7a2b8ec91e6e3c39c864bf790b5 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Mon, 26 Aug 2024 10:22:01 -0700 Subject: [PATCH 34/48] Update BenchCompiler.hs --- bench/BenchCompiler.hs | 3 --- 1 file changed, 3 deletions(-) diff --git a/bench/BenchCompiler.hs b/bench/BenchCompiler.hs index f15516e62..d2dfae269 100644 --- a/bench/BenchCompiler.hs +++ b/bench/BenchCompiler.hs @@ -21,9 +21,6 @@ import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Examples import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Var) --- inputMap :: AdditiveMonoid a => ArithmeticCircuit a i o -> Map Natural a --- inputMap circuit = fromAscList [ (i, zero) | i <- acInput circuit ] - metrics :: String -> ArithmeticCircuit a i o -> ByteString metrics name circuit = fromString name From b0626aaf5eb5fad1137f19a3d1b2eedfb458bbdd Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 27 Aug 2024 10:33:30 -0700 Subject: [PATCH 35/48] fixes --- .../Symbolic/Compiler/ArithmeticCircuit.hs | 8 +++--- .../Compiler/ArithmeticCircuit/Instance.hs | 5 ++-- .../Compiler/ArithmeticCircuit/Internal.hs | 9 +++---- .../Compiler/ArithmeticCircuit/Map.hs | 11 +++----- tests/Tests/ArithmeticCircuit.hs | 25 +++++++------------ 5 files changed, 24 insertions(+), 34 deletions(-) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index aedf8fc03..1cf0a6eea 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -35,13 +35,15 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit ( import Control.Monad (foldM) import Control.Monad.State (execState) +import Data.Functor.Rep (Representable (..)) import Data.Map hiding (drop, foldl, foldr, map, null, splitAt, take) +import Data.Void (absurd) import GHC.Generics (U1 (..)) import Numeric.Natural (Natural) import Prelude hiding (Num (..), drop, length, product, splitAt, sum, take, (!!), (^)) -import Test.QuickCheck (Arbitrary, Property, conjoin, property, vector, +import Test.QuickCheck (Arbitrary, Property, arbitrary, conjoin, property, withMaxSuccess, (===)) import Text.Pretty.Simple (pPrint) @@ -49,8 +51,8 @@ import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial) import ZkFold.Prelude (length) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance () -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint, - apply, eval, eval1, exec, exec1, witnessGenerator) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint, acInput, + apply, eval, eval1, exec, exec1, witnessGenerator, Var (..)) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map import ZkFold.Symbolic.Data.Combinators (expansion) import ZkFold.Symbolic.MonadCircuit (MonadCircuit (..)) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs index 0eaf82e8b..71b0bb8ab 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs @@ -7,16 +7,15 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance where import Data.Aeson hiding (Bool) +import Data.Functor.Rep (Representable (..)) import Data.Map hiding (drop, foldl, foldl', foldr, map, null, splitAt, take, toList) import GHC.Generics (Par1 (..)) -import GHC.Num (integerToNatural) import Prelude (Show, mempty, pure, return, show, ($), (++), (<$>)) import qualified Prelude as Haskell import System.Random (mkStdGen) -import Test.QuickCheck (Arbitrary (arbitrary), Gen, chooseInteger, - elements) +import Test.QuickCheck (Arbitrary (arbitrary), Gen, elements) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index 1adfd684e..9b23898ef 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -27,10 +27,10 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal ( ) where import Control.DeepSeq (NFData, force) -import Control.Monad.State (MonadState (..), State, gets, modify, runState) +import Control.Monad.State (MonadState (..), State, modify, runState) import Data.Aeson import Data.Containers.ListUtils (nubOrd) -import Data.Foldable (fold) +import Data.Foldable (fold, toList) import Data.Functor.Rep import Data.List (sort) import Data.Map.Strict hiding (drop, foldl, foldr, map, null, splitAt, take, @@ -39,7 +39,6 @@ import qualified Data.Map.Strict as M hiding (toLis import Data.Semialign (unzipDefault) import qualified Data.Set as S import GHC.Generics (Generic, Par1 (..), U1 (..), (:*:) (..)) -import GHC.IsList (IsList (toList)) import Optics import Prelude hiding (Num (..), drop, length, product, splitAt, sum, take, (!!), (^)) @@ -281,8 +280,8 @@ apply xs ac = ac -- let inputs = acInput -- zoom #acWitness . modify . union . fromList $ zip inputs (map const xs) -getAllVars :: MultiplicativeMonoid a => ArithmeticCircuit a i o -> [Natural] -getAllVars ac = nubOrd $ sort $ 0 : acInput ac ++ concatMap (toList . variables) (elems $ acSystem ac) +getAllVars :: (MultiplicativeMonoid a, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuit a i o -> [Var i] +getAllVars ac = nubOrd $ sort $ NewVar 0 : toList acInput ++ concatMap (toList . variables) (elems $ acSystem ac) -- TODO: Add proper symbolic application functions diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs index 8963c6add..345d580d5 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs @@ -6,22 +6,19 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map ( ArithmeticCircuitTest(..) ) where +import Data.Functor.Rep (Representable (..)) import Data.Map hiding (drop, foldl, foldr, fromList, map, null, splitAt, take, toList) import qualified Data.Map as Map import GHC.Generics (Par1) import GHC.IsList (IsList (..)) -import GHC.Natural (naturalToInteger) -import GHC.Num (integerToInt) -import Numeric.Natural (Natural) import Prelude hiding (Num (..), drop, length, product, splitAt, sum, take, (!!), (^)) -import Test.QuickCheck (Arbitrary (arbitrary), Gen, vector) +import Test.QuickCheck (Arbitrary (arbitrary), Gen) -import ZkFold.Base.Algebra.Basic.Class (MultiplicativeMonoid (..)) +import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Polynomials.Multivariate -import ZkFold.Prelude (length) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), getAllVars) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), getAllVars, Var (..)) -- This module contains functions for mapping variables in arithmetic circuits. diff --git a/tests/Tests/ArithmeticCircuit.hs b/tests/Tests/ArithmeticCircuit.hs index 83401d791..7510589dc 100644 --- a/tests/Tests/ArithmeticCircuit.hs +++ b/tests/Tests/ArithmeticCircuit.hs @@ -1,11 +1,10 @@ {-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE IncoherentInstances #-} {-# LANGUAGE TypeApplications #-} module Tests.ArithmeticCircuit (exec1, it, specArithmeticCircuit) where import Data.Bool (bool) -import GHC.Generics (Par1 (..)) +import GHC.Generics (U1 (..)) import Prelude (IO, Show, String, id, ($)) import qualified Prelude as Haskell import qualified Test.Hspec @@ -26,17 +25,17 @@ import ZkFold.Symbolic.Data.FieldElement correctHom0 :: forall a . (Arithmetic a, Scale a a, Show a) => (forall b . Field b => b) -> Property correctHom0 f = let r = fromFieldElement f in withMaxSuccess 1 $ checkClosedCircuit r .&&. exec1 r === f @a -correctHom1 :: forall a . (Arithmetic a, Scale a a, Show a, FromConstant a (FieldElement (ArithmeticCircuit a))) => (forall b . Field b => b -> b) -> a -> Property +correctHom1 :: forall a . (Arithmetic a, Scale a a, Show a, FromConstant a (FieldElement (ArithmeticCircuit a U1))) => (forall b . Field b => b -> b) -> a -> Property correctHom1 f x = let r = fromFieldElement $ f (fromConstant x) in checkClosedCircuit r .&&. exec1 r === f x -correctHom2 :: forall a . (Arithmetic a, Scale a a, Show a, FromConstant a (FieldElement (ArithmeticCircuit a))) => (forall b . Field b => b -> b -> b) -> a -> a -> Property +correctHom2 :: forall a . (Arithmetic a, Scale a a, Show a, FromConstant a (FieldElement (ArithmeticCircuit a U1))) => (forall b . Field b => b -> b -> b) -> a -> a -> Property correctHom2 f x y = let r = fromFieldElement $ f (fromConstant x) (fromConstant y) in checkClosedCircuit r .&&. exec1 r === f x y it :: Testable prop => String -> prop -> Spec it desc prop = Test.Hspec.it desc (property prop) -specArithmeticCircuit' :: forall a . (Arbitrary a, Arithmetic a, Scale a a, Show a, FromConstant a (FieldElement (ArithmeticCircuit a))) => IO () +specArithmeticCircuit' :: forall a . (Arbitrary a, Arithmetic a, Scale a a, Show a, FromConstant a (FieldElement (ArithmeticCircuit a U1))) => IO () specArithmeticCircuit' = hspec $ do describe "ArithmeticCircuit specification" $ do it "embeds constants" $ correctHom1 @a id @@ -54,20 +53,14 @@ specArithmeticCircuit' = hspec $ do -- let Bool (r :: ArithmeticCircuit a U1 Par1) = isZero (zero :: FieldElement (ArithmeticCircuit a U1)) -- in withMaxSuccess 1 $ checkClosedCircuit r .&&. exec1 r === one it "computes binary expansion" $ \(x :: a) -> - let rs = binaryExpansion (fromConstant x :: FieldElement (ArithmeticCircuit a)) + let rs = binaryExpansion (fromConstant x :: FieldElement (ArithmeticCircuit a U1)) in checkClosedCircuit rs .&&. V.fromVector (exec rs) === padBits (numberOfBits @a) (binaryExpansion x) it "internalizes equality" $ \(x :: a) (y :: a) -> - let Bool (r :: ArithmeticCircuit a U1 Par1) = (embed x :: ArithmeticCircuit a U1 Par1) == embed y - in checkClosedCircuit r .&&. exec1 r === bool zero one (x Haskell.== y) + let Bool r = (fromConstant x :: FieldElement (ArithmeticCircuit a U1)) == fromConstant y + in checkClosedCircuit @a r .&&. exec1 r === bool zero one (x Haskell.== y) it "internal equality is reflexive" $ \(x :: a) -> - let Bool (r :: ArithmeticCircuit a U1 Par1) = (embed x :: ArithmeticCircuit a U1 Par1) == embed x --- ======= --- let Bool (r :: ArithmeticCircuit a Par1) = fromFieldElement (fromConstant x :: FieldElement (ArithmeticCircuit a)) == fromFieldElement (fromConstant y) --- in checkClosedCircuit r .&&. exec1 r === bool zero one (x Haskell.== y) --- it "internal equality is reflexive" $ \(x :: a) -> --- let Bool (r :: ArithmeticCircuit a Par1) = fromFieldElement (fromConstant x :: FieldElement (ArithmeticCircuit a)) == fromFieldElement (fromConstant x) --- >>>>>>> main - in checkClosedCircuit r .&&. exec1 r === one + let Bool r = (fromConstant x :: FieldElement (ArithmeticCircuit a U1)) == fromConstant x + in checkClosedCircuit @a r .&&. exec1 r === one specArithmeticCircuit :: IO () specArithmeticCircuit = do From 3b8aeb1462d996d053f6dad765960cac3b3e5e9d Mon Sep 17 00:00:00 2001 From: echatav Date: Tue, 27 Aug 2024 17:37:13 +0000 Subject: [PATCH 36/48] stylish-haskell auto-commit --- .../Symbolic/Compiler/ArithmeticCircuit.hs | 5 ++-- .../Compiler/ArithmeticCircuit/Map.hs | 3 +- src/ZkFold/Symbolic/Data/FieldElement.hs | 28 +++++++++---------- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index 1cf0a6eea..5d422dff0 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -51,8 +51,9 @@ import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial) import ZkFold.Prelude (length) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance () -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint, acInput, - apply, eval, eval1, exec, exec1, witnessGenerator, Var (..)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint, + Var (..), acInput, apply, eval, eval1, exec, + exec1, witnessGenerator) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map import ZkFold.Symbolic.Data.Combinators (expansion) import ZkFold.Symbolic.MonadCircuit (MonadCircuit (..)) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs index 345d580d5..1b9684d9d 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs @@ -18,7 +18,8 @@ import Test.QuickCheck (Arbitrary import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Polynomials.Multivariate -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), getAllVars, Var (..)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Var (..), + getAllVars) -- This module contains functions for mapping variables in arithmetic circuits. diff --git a/src/ZkFold/Symbolic/Data/FieldElement.hs b/src/ZkFold/Symbolic/Data/FieldElement.hs index 7fd282e43..bdb111393 100644 --- a/src/ZkFold/Symbolic/Data/FieldElement.hs +++ b/src/ZkFold/Symbolic/Data/FieldElement.hs @@ -4,26 +4,26 @@ module ZkFold.Symbolic.Data.FieldElement where -import Data.Foldable (foldr) -import Data.Function (($), (.)) -import Data.Functor (fmap, (<$>)) -import Data.Tuple (snd) -import GHC.Generics (Par1 (..)) -import Prelude (Integer) -import qualified Prelude as Haskell +import Data.Foldable (foldr) +import Data.Function (($), (.)) +import Data.Functor (fmap, (<$>)) +import Data.Tuple (snd) +import GHC.Generics (Par1 (..)) +import Prelude (Integer) +import qualified Prelude as Haskell import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number -import ZkFold.Base.Data.HFunctor (HFunctor, hmap) -import ZkFold.Base.Data.Par1 () -import ZkFold.Base.Data.Vector (Vector, fromVector, unsafeToVector) +import ZkFold.Base.Data.HFunctor (HFunctor, hmap) +import ZkFold.Base.Data.Par1 () +import ZkFold.Base.Data.Vector (Vector, fromVector, unsafeToVector) import ZkFold.Symbolic.Class -import ZkFold.Symbolic.Data.Bool (Bool) +import ZkFold.Symbolic.Data.Bool (Bool) import ZkFold.Symbolic.Data.Class -import ZkFold.Symbolic.Data.Combinators (expansion, horner, runInvert) -import ZkFold.Symbolic.Data.Eq (Eq) +import ZkFold.Symbolic.Data.Combinators (expansion, horner, runInvert) +import ZkFold.Symbolic.Data.Eq (Eq) import ZkFold.Symbolic.Data.Ord -import ZkFold.Symbolic.MonadCircuit (newAssigned) +import ZkFold.Symbolic.MonadCircuit (newAssigned) newtype FieldElement c = FieldElement { fromFieldElement :: c Par1 } From 2a899d18b117ff3e13e5d81355b3fafdaa936e28 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 27 Aug 2024 11:00:14 -0700 Subject: [PATCH 37/48] plonk changes --- src/ZkFold/Base/Protocol/ARK/Plonk.hs | 47 ++++++++++--------- .../Base/Protocol/ARK/Plonk/Internal.hs | 4 +- .../Base/Protocol/ARK/Plonk/Relation.hs | 13 ++--- tests/Tests/Arithmetization/Test4.hs | 6 +-- tests/Tests/NonInteractiveProof/Plonk.hs | 6 +-- 5 files changed, 39 insertions(+), 37 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index d4856871f..416902e47 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -41,19 +41,19 @@ import ZkFold.Symbolic.MonadCircuit (Arithmetic) Additionally, we don't want this library to depend on Cardano libraries. -} -data Plonk (n :: Natural) (l :: Natural) curve1 curve2 transcript = Plonk { +data Plonk (i :: Natural) (n :: Natural) (l :: Natural) curve1 curve2 transcript = Plonk { omega :: ScalarField curve1, k1 :: ScalarField curve1, k2 :: ScalarField curve1, iPub :: Vector l Natural, - ac :: ArithmeticCircuit (ScalarField curve1) (Vector l) Par1, + ac :: ArithmeticCircuit (ScalarField curve1) (Vector i) Par1, x :: ScalarField curve1 } -instance (Show (ScalarField c1), Arithmetic (ScalarField c1), KnownNat l) => Show (Plonk n l c1 c2 t) where +instance (Show (ScalarField c1), Arithmetic (ScalarField c1), KnownNat l, KnownNat i) => Show (Plonk i n l c1 c2 t) where show (Plonk omega k1 k2 iPub ac x) = "Plonk: " ++ show omega ++ " " ++ show k1 ++ " " ++ show k2 ++ " " ++ show iPub ++ " " ++ show ac ++ " " ++ show x -instance (KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1)) => Arbitrary (Plonk n l c1 c2 t) where +instance (KnownNat i, KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1)) => Arbitrary (Plonk i n l c1 c2 t) where arbitrary = do ac <- arbitrary let fullInp = value @l @@ -61,10 +61,10 @@ instance (KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (Scalar let (omega, k1, k2) = getParams (value @n) Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary -instance forall n l c1 c2 t core . (KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1), - Witness (Plonk n l c1 c2 t) ~ (PlonkWitnessInput l c1, PlonkProverSecret c1), NonInteractiveProof (Plonk n l c1 c2 t) core) => Arbitrary (NonInteractiveProofTestData (Plonk n l c1 c2 t) core) where +instance forall i n l c1 c2 t core . (KnownNat i, KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1), + Witness (Plonk i n l c1 c2 t) ~ (PlonkWitnessInput i c1, PlonkProverSecret c1), NonInteractiveProof (Plonk i n l c1 c2 t) core) => Arbitrary (NonInteractiveProofTestData (Plonk i n l c1 c2 t) core) where arbitrary = do - ArithmeticCircuitTest ac wi <- arbitrary :: Gen (ArithmeticCircuitTest (ScalarField c1) (Vector l) Par1) + ArithmeticCircuitTest ac wi <- arbitrary :: Gen (ArithmeticCircuitTest (ScalarField c1) (Vector i) Par1) let inputLen = value @l vecPubInp <- genSubset (value @l) inputLen let (omega, k1, k2) = getParams $ value @n @@ -72,8 +72,8 @@ instance forall n l c1 c2 t core . (KnownNat n, KnownNat l, Arithmetic (ScalarFi secret <- arbitrary return $ TestData pl (PlonkWitnessInput wi, secret) -plonkPermutation :: forall n l c1 c2 t . - (KnownNat n, FiniteField (ScalarField c1)) => Plonk n l c1 c2 t -> PlonkRelation n l (ScalarField c1) -> PlonkPermutation n c1 +plonkPermutation :: forall i n l c1 c2 t . + (KnownNat n, FiniteField (ScalarField c1)) => Plonk i n l c1 c2 t -> PlonkRelation n i (ScalarField c1) -> PlonkPermutation n c1 plonkPermutation (Plonk omega k1 k2 _ _ _) PlonkRelation {..} = PlonkPermutation {..} where f i = case (i-!1) `div` value @n of @@ -87,11 +87,11 @@ plonkPermutation (Plonk omega k1 k2 _ _ _) PlonkRelation {..} = PlonkPermutation s2 = toPolyVec $ V.take (fromIntegral $ value @n) $ V.drop (fromIntegral $ value @n) s s3 = toPolyVec $ V.take (fromIntegral $ value @n) $ V.drop (fromIntegral $ 2 * value @n) s -plonkCircuitPolynomials :: forall n l c1 c2 t . +plonkCircuitPolynomials :: forall i n l c1 c2 t . (KnownNat n, KnownNat (PlonkPolyExtendedLength n), Eq (ScalarField c1), Field (ScalarField c1)) - => Plonk n l c1 c2 t + => Plonk i n l c1 c2 t -> PlonkPermutation n c1 - -> PlonkRelation n l (ScalarField c1) + -> PlonkRelation n i (ScalarField c1) -> PlonkCircuitPolynomials n c1 plonkCircuitPolynomials (Plonk omega _ _ _ _ _) @@ -110,12 +110,13 @@ plonkCircuitPolynomials plonkVerifierInput :: Field (ScalarField c) => Vector n (ScalarField c) -> PlonkInput c plonkVerifierInput input = PlonkInput $ fromList $ map negate $ fromVector input -instance forall n l c1 c2 t plonk f g1 core. - ( Plonk n l c1 c2 t ~ plonk +instance forall i n l c1 c2 t plonk f g1 core. + ( Plonk i n l c1 c2 t ~ plonk , ScalarField c1 ~ f , Point c1 ~ g1 , KnownNat n , KnownNat l + , KnownNat i , KnownNat (PlonkPermutationSize n) , KnownNat (PlonkPolyExtendedLength n) , Arithmetic f @@ -126,13 +127,13 @@ instance forall n l c1 c2 t plonk f g1 core. , ToTranscript t (PointCompressed c1) , FromTranscript t (ScalarField c1) , CoreFunction c1 core - ) => NonInteractiveProof (Plonk n l c1 c2 t) core where - type Transcript (Plonk n l c1 c2 t) = t - type SetupProve (Plonk n l c1 c2 t) = (PlonkSetupParamsProve c1 c2, PlonkPermutation n c1, PlonkCircuitPolynomials n c1 , PlonkWitnessMap n l c1) - type SetupVerify (Plonk n l c1 c2 t) = (PlonkSetupParamsVerify c1 c2, PlonkCircuitCommitments c1) - type Witness (Plonk n l c1 c2 t) = (PlonkWitnessInput l c1, PlonkProverSecret c1) - type Input (Plonk n l c1 c2 t) = PlonkInput c1 - type Proof (Plonk n l c1 c2 t) = PlonkProof c1 + ) => NonInteractiveProof (Plonk i n l c1 c2 t) core where + type Transcript (Plonk i n l c1 c2 t) = t + type SetupProve (Plonk i n l c1 c2 t) = (PlonkSetupParamsProve c1 c2, PlonkPermutation n c1, PlonkCircuitPolynomials n c1 , PlonkWitnessMap n i c1) + type SetupVerify (Plonk i n l c1 c2 t) = (PlonkSetupParamsVerify c1 c2, PlonkCircuitCommitments c1) + type Witness (Plonk i n l c1 c2 t) = (PlonkWitnessInput i c1, PlonkProverSecret c1) + type Input (Plonk i n l c1 c2 t) = PlonkInput c1 + type Proof (Plonk i n l c1 c2 t) = PlonkProof c1 setupProve :: plonk -> SetupProve plonk setupProve plonk@(Plonk omega' k1' k2' iPub ac x) = @@ -145,7 +146,7 @@ instance forall n l c1 c2 t plonk f g1 core. h1' = x `mul` gen iPub' = fromList . fromVector $ iPub - pr = fromJust $ toPlonkRelation @n @l @f iPub ac + pr = fromJust $ toPlonkRelation @i @n @l @f iPub ac perm@PlonkPermutation {..} = plonkPermutation plonk pr PlonkCircuitPolynomials {..} = plonkCircuitPolynomials plonk perm pr @@ -163,7 +164,7 @@ instance forall n l c1 c2 t plonk f g1 core. pow'' = log2ceiling $ value @n n'' = fromIntegral $ value @n - pr = fromJust $ toPlonkRelation @n @l @f iPub ac + pr = fromJust $ toPlonkRelation @i @n @l @f iPub ac perm = plonkPermutation plonk pr PlonkCircuitPolynomials {..} = plonkCircuitPolynomials plonk perm pr diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs index f2f8a938d..754345754 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs @@ -141,8 +141,8 @@ instance (Show (BaseField c), EllipticCurve c) => Show (PlonkCircuitCommitments newtype PlonkWitnessMap n l c = PlonkWitnessMap (Vector l (ScalarField c) -> (PolyVec (ScalarField c) n, PolyVec (ScalarField c) n, PolyVec (ScalarField c) n)) -newtype PlonkWitnessInput l c = PlonkWitnessInput (Vector l (ScalarField c)) -instance Show (ScalarField c) => Show (PlonkWitnessInput l c) where +newtype PlonkWitnessInput i c = PlonkWitnessInput (Vector i (ScalarField c)) +instance Show (ScalarField c) => Show (PlonkWitnessInput i c) where show (PlonkWitnessInput m) = "Witness Input: " ++ show m data PlonkProverSecret c = PlonkProverSecret { diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs index c6dadbee9..af9d33a7f 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs @@ -21,25 +21,26 @@ import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal -- Here `n` is the total number of constraints, `l` is the number of public inputs, and `a` is the field type. -data PlonkRelation n l a = PlonkRelation +data PlonkRelation n i a = PlonkRelation { qM :: PolyVec a n , qL :: PolyVec a n , qR :: PolyVec a n , qO :: PolyVec a n , qC :: PolyVec a n , sigma :: Permutation (3 * n) - , wmap :: Vector l a -> (PolyVec a n, PolyVec a n, PolyVec a n) + , wmap :: Vector i a -> (PolyVec a n, PolyVec a n, PolyVec a n) } -toPlonkRelation :: forall n l a . - KnownNat n +toPlonkRelation :: forall i n l a . + KnownNat i + => KnownNat n => KnownNat (3 * n) => KnownNat l => Arithmetic a => Scale a a => Vector l Natural - -> ArithmeticCircuit a (Vector l) Par1 - -> Maybe (PlonkRelation n l a) + -> ArithmeticCircuit a (Vector i) Par1 + -> Maybe (PlonkRelation n i a) toPlonkRelation xPub ac0 = let ac = desugarRanges ac0 diff --git a/tests/Tests/Arithmetization/Test4.hs b/tests/Tests/Arithmetization/Test4.hs index 2bc4ac0a6..a4db1f503 100644 --- a/tests/Tests/Arithmetization/Test4.hs +++ b/tests/Tests/Arithmetization/Test4.hs @@ -57,7 +57,7 @@ testOnlyOutputZKP x ps targetValue = indexOutputBool = V.singleton $ case unPar1 $ acOutput ac of NewVar ix -> ix + 1 InVar _ -> 1 - plonk = Plonk @32 omega k1 k2 indexOutputBool ac x + plonk = Plonk @1 @32 omega k1 k2 indexOutputBool ac x setupP = setupProve @(PlonkBS N) @core plonk setupV = setupVerify @(PlonkBS N) @core plonk witness = (PlonkWitnessInput witnessInputs, ps) @@ -75,7 +75,7 @@ testSafeOneInputZKP x ps targetValue = (omega, k1, k2) = getParams 32 witnessInputs = V.singleton targetValue indexTargetValue = V.singleton (1 :: Natural) - plonk = Plonk @32 omega k1 k2 indexTargetValue ac x + plonk = Plonk @1 @32 omega k1 k2 indexTargetValue ac x setupP = setupProve @(PlonkBS N) @core plonk setupV = setupVerify @(PlonkBS N) @core plonk witness = (PlonkWitnessInput witnessInputs, ps) @@ -92,7 +92,7 @@ testAttackSafeOneInputZKP x ps targetValue = (omega, k1, k2) = getParams 32 witnessInputs = V.singleton (targetValue + 1) indexTargetValue = V.singleton (1 :: Natural) - plonk = Plonk @32 omega k1 k2 indexTargetValue ac x + plonk = Plonk @1 @32 omega k1 k2 indexTargetValue ac x setupP = setupProve @(PlonkBS N) @core plonk setupV = setupVerify @(PlonkBS N) @core plonk witness = (PlonkWitnessInput witnessInputs, ps) diff --git a/tests/Tests/NonInteractiveProof/Plonk.hs b/tests/Tests/NonInteractiveProof/Plonk.hs index 12a112abd..6ec46fc5e 100644 --- a/tests/Tests/NonInteractiveProof/Plonk.hs +++ b/tests/Tests/NonInteractiveProof/Plonk.hs @@ -30,7 +30,7 @@ import ZkFold.Base.Protocol.NonInteractiveProof (HaskellCore, NonIn NonInteractiveProofTestData (..)) type PlonkPolyLengthBS = 32 -type PlonkBS n = Plonk PlonkPolyLengthBS n BLS12_381_G1 BLS12_381_G2 ByteString +type PlonkBS n = Plonk 1 PlonkPolyLengthBS n BLS12_381_G1 BLS12_381_G2 ByteString type PlonkPolyExtendedLengthBS = PlonkPolyExtendedLength PlonkPolyLengthBS propPlonkConstraintConversion :: (Eq a, FiniteField a) => PlonkConstraint a -> Bool @@ -39,7 +39,7 @@ propPlonkConstraintConversion p = propPlonkConstraintSatisfaction :: forall n core . KnownNat n => NonInteractiveProofTestData (PlonkBS n) core -> Bool propPlonkConstraintSatisfaction (TestData (Plonk _ _ _ iPub ac _) w) = - let pr = fromJust $ toPlonkRelation @PlonkPolyLengthBS iPub ac + let pr = fromJust $ toPlonkRelation @1 @PlonkPolyLengthBS iPub ac (PlonkWitnessInput wInput, _) = w (w1', w2', w3') = wmap pr wInput @@ -57,7 +57,7 @@ propPlonkConstraintSatisfaction (TestData (Plonk _ _ _ iPub ac _) w) = in all ((== zero) . f) $ transpose [ql', qr', qo', qm', qc', toList $ fromPolyVec w1', toList $ fromPolyVec w2', toList $ fromPolyVec w3', toList $ fromPolyVec wPub] -propPlonkPolyIdentity :: forall n core . KnownNat n => NonInteractiveProofTestData (PlonkBS n) core -> Bool +propPlonkPolyIdentity :: forall n core . NonInteractiveProofTestData (PlonkBS n) core -> Bool propPlonkPolyIdentity (TestData plonk w) = let zH = polyVecZero @(ScalarField BLS12_381_G1) @PlonkPolyLengthBS @PlonkPolyExtendedLengthBS From 6c4df39a07adf024282fceaa03a4c419ab77053c Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 27 Aug 2024 11:04:42 -0700 Subject: [PATCH 38/48] fixes? --- src/ZkFold/Base/Protocol/ARK/Plonk.hs | 4 ++-- src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index 416902e47..7ca56f5c6 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -56,7 +56,7 @@ instance (Show (ScalarField c1), Arithmetic (ScalarField c1), KnownNat l, KnownN instance (KnownNat i, KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1)) => Arbitrary (Plonk i n l c1 c2 t) where arbitrary = do ac <- arbitrary - let fullInp = value @l + let fullInp = value @i vecPubInp <- genSubset (value @l) fullInp let (omega, k1, k2) = getParams (value @n) Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary @@ -65,7 +65,7 @@ instance forall i n l c1 c2 t core . (KnownNat i, KnownNat n, KnownNat l, Arithm Witness (Plonk i n l c1 c2 t) ~ (PlonkWitnessInput i c1, PlonkProverSecret c1), NonInteractiveProof (Plonk i n l c1 c2 t) core) => Arbitrary (NonInteractiveProofTestData (Plonk i n l c1 c2 t) core) where arbitrary = do ArithmeticCircuitTest ac wi <- arbitrary :: Gen (ArithmeticCircuitTest (ScalarField c1) (Vector i) Par1) - let inputLen = value @l + let inputLen = value @i vecPubInp <- genSubset (value @l) inputLen let (omega, k1, k2) = getParams $ value @n pl <- Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index 5d422dff0..6e5b661ce 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -52,7 +52,7 @@ import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomi import ZkFold.Prelude (length) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance () import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint, - Var (..), acInput, apply, eval, eval1, exec, + Var (..), acInput, eval, eval1, exec, exec1, witnessGenerator) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map import ZkFold.Symbolic.Data.Combinators (expansion) From 2ba627d6b34d8d58c22c269b00854472a090681c Mon Sep 17 00:00:00 2001 From: echatav Date: Tue, 27 Aug 2024 18:07:46 +0000 Subject: [PATCH 39/48] stylish-haskell auto-commit --- src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index 6e5b661ce..b293e61d5 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -52,8 +52,8 @@ import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomi import ZkFold.Prelude (length) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance () import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint, - Var (..), acInput, eval, eval1, exec, - exec1, witnessGenerator) + Var (..), acInput, eval, eval1, exec, exec1, + witnessGenerator) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map import ZkFold.Symbolic.Data.Combinators (expansion) import ZkFold.Symbolic.MonadCircuit (MonadCircuit (..)) From 7523f0c97b6fa20156a21414f79b8c6cb2e43a98 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 27 Aug 2024 11:07:54 -0700 Subject: [PATCH 40/48] Update Internal.hs --- src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs index 754345754..75d88d787 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs @@ -138,8 +138,8 @@ instance (Show (BaseField c), EllipticCurve c) => Show (PlonkCircuitCommitments ++ show cmS2 ++ " " ++ show cmS3 -newtype PlonkWitnessMap n l c = PlonkWitnessMap - (Vector l (ScalarField c) -> (PolyVec (ScalarField c) n, PolyVec (ScalarField c) n, PolyVec (ScalarField c) n)) +newtype PlonkWitnessMap n i c = PlonkWitnessMap + (Vector i (ScalarField c) -> (PolyVec (ScalarField c) n, PolyVec (ScalarField c) n, PolyVec (ScalarField c) n)) newtype PlonkWitnessInput i c = PlonkWitnessInput (Vector i (ScalarField c)) instance Show (ScalarField c) => Show (PlonkWitnessInput i c) where From ae59e2d152a6ed1c771a95befbb087bc78a29980 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 27 Aug 2024 12:24:56 -0700 Subject: [PATCH 41/48] fixes --- src/ZkFold/Base/Protocol/ARK/Plonk.hs | 32 +++++++++++-------- .../Base/Protocol/ARK/Plonk/Constraint.hs | 29 +++++++++-------- .../Base/Protocol/ARK/Plonk/Internal.hs | 14 ++++---- .../Base/Protocol/ARK/Plonk/Relation.hs | 24 ++++++-------- src/ZkFold/Symbolic/Compiler.hs | 2 +- .../Compiler/ArithmeticCircuit/Internal.hs | 9 ++++++ tests/Tests/Arithmetization/Test4.hs | 20 ++++++------ tests/Tests/NonInteractiveProof/Plonk.hs | 25 ++++++++++----- 8 files changed, 89 insertions(+), 66 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index 7ca56f5c6..8859279d3 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -12,8 +12,10 @@ module ZkFold.Base.Protocol.ARK.Plonk ( plonkVerifierInput ) where +import Data.Functor ((<&>)) import Data.Functor.Rep (Representable (..)) import Data.Maybe (fromJust) +import qualified Data.Map as Map import qualified Data.Vector as V import GHC.Generics (Par1) import GHC.IsList (IsList (..)) @@ -33,8 +35,8 @@ import ZkFold.Base.Protocol.ARK.Plonk.Internal import ZkFold.Base.Protocol.ARK.Plonk.Relation (PlonkRelation (..), toPlonkRelation) import ZkFold.Base.Protocol.NonInteractiveProof import ZkFold.Prelude (log2ceiling) -import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), ArithmeticCircuitTest (..)) -import ZkFold.Symbolic.MonadCircuit (Arithmetic) +import ZkFold.Symbolic.Compiler (ArithmeticCircuitTest (..)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal {- NOTE: we need to parametrize the type of transcripts because we use BuiltinByteString on-chain and ByteString off-chain. @@ -45,7 +47,7 @@ data Plonk (i :: Natural) (n :: Natural) (l :: Natural) curve1 curve2 transcript omega :: ScalarField curve1, k1 :: ScalarField curve1, k2 :: ScalarField curve1, - iPub :: Vector l Natural, + iPub :: Vector l (Var (Vector i)), ac :: ArithmeticCircuit (ScalarField curve1) (Vector i) Par1, x :: ScalarField curve1 } @@ -58,8 +60,9 @@ instance (KnownNat i, KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbit ac <- arbitrary let fullInp = value @i vecPubInp <- genSubset (value @l) fullInp - let (omega, k1, k2) = getParams (value @n) - Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary + let vecPubInVars = [InVar (fromConstant ix) | ix <- vecPubInp] + (omega, k1, k2) = getParams (value @n) + Plonk omega k1 k2 (Vector vecPubInVars) ac <$> arbitrary instance forall i n l c1 c2 t core . (KnownNat i, KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1), Witness (Plonk i n l c1 c2 t) ~ (PlonkWitnessInput i c1, PlonkProverSecret c1), NonInteractiveProof (Plonk i n l c1 c2 t) core) => Arbitrary (NonInteractiveProofTestData (Plonk i n l c1 c2 t) core) where @@ -67,10 +70,11 @@ instance forall i n l c1 c2 t core . (KnownNat i, KnownNat n, KnownNat l, Arithm ArithmeticCircuitTest ac wi <- arbitrary :: Gen (ArithmeticCircuitTest (ScalarField c1) (Vector i) Par1) let inputLen = value @i vecPubInp <- genSubset (value @l) inputLen - let (omega, k1, k2) = getParams $ value @n - pl <- Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary + let vecPubInVars = [InVar (fromConstant ix) | ix <- vecPubInp] + (omega, k1, k2) = getParams $ value @n + pl <- Plonk omega k1 k2 (Vector vecPubInVars) ac <$> arbitrary secret <- arbitrary - return $ TestData pl (PlonkWitnessInput wi, secret) + return $ TestData pl (PlonkWitnessInput wi (witnessGenerator ac wi), secret) plonkPermutation :: forall i n l c1 c2 t . (KnownNat n, FiniteField (ScalarField c1)) => Plonk i n l c1 c2 t -> PlonkRelation n i (ScalarField c1) -> PlonkPermutation n c1 @@ -129,7 +133,7 @@ instance forall i n l c1 c2 t plonk f g1 core. , CoreFunction c1 core ) => NonInteractiveProof (Plonk i n l c1 c2 t) core where type Transcript (Plonk i n l c1 c2 t) = t - type SetupProve (Plonk i n l c1 c2 t) = (PlonkSetupParamsProve c1 c2, PlonkPermutation n c1, PlonkCircuitPolynomials n c1 , PlonkWitnessMap n i c1) + type SetupProve (Plonk i n l c1 c2 t) = (PlonkSetupParamsProve i c1 c2, PlonkPermutation n c1, PlonkCircuitPolynomials n c1 , PlonkWitnessMap n i c1) type SetupVerify (Plonk i n l c1 c2 t) = (PlonkSetupParamsVerify c1 c2, PlonkCircuitCommitments c1) type Witness (Plonk i n l c1 c2 t) = (PlonkWitnessInput i c1, PlonkProverSecret c1) type Input (Plonk i n l c1 c2 t) = PlonkInput c1 @@ -137,7 +141,7 @@ instance forall i n l c1 c2 t plonk f g1 core. setupProve :: plonk -> SetupProve plonk setupProve plonk@(Plonk omega' k1' k2' iPub ac x) = - (PlonkSetupParamsProve {..}, PlonkPermutation {..}, PlonkCircuitPolynomials {..}, PlonkWitnessMap $ wmap pr) + (PlonkSetupParamsProve {..}, PlonkPermutation {..}, PlonkCircuitPolynomials {..}, PlonkWitnessMap $ \(PlonkWitnessInput win wnv) -> wmap pr win wnv) where d = value @n + 6 xs = fromList $ map (x^) [0..d-!1] @@ -180,15 +184,17 @@ instance forall i n l c1 c2 t plonk f g1 core. prove :: SetupProve plonk -> Witness plonk -> (Input plonk, Proof plonk) prove (PlonkSetupParamsProve {..}, PlonkPermutation {..}, PlonkCircuitPolynomials {..}, PlonkWitnessMap wmap) - (PlonkWitnessInput wInput, PlonkProverSecret {..}) + (PlonkWitnessInput wInput wNewVars, PlonkProverSecret {..}) = (PlonkInput wPub, PlonkProof {..}) where n = value @n zH = polyVecZero @f @n @(PlonkPolyExtendedLength n) - (w1, w2, w3) = wmap wInput + (w1, w2, w3) = wmap (PlonkWitnessInput wInput wNewVars) - wPub = fmap (negate . index wInput . P.fromIntegral) iPub' + wPub = iPub' <&> negate . \case + InVar j -> index wInput j + NewVar j -> wNewVars Map.! j pubPoly = polyVecInLagrangeBasis omega' $ toPolyVec @f @n wPub diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs index 8bf151d4b..c70c836a6 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs @@ -9,6 +9,7 @@ import Data.List (find, permutation import Data.Map (Map, empty, fromListWith) import Data.Maybe (mapMaybe) import GHC.IsList (IsList (..)) +import GHC.TypeNats (KnownNat) import Numeric.Natural (Natural) import Prelude hiding (Num (..), drop, length, sum, take, (!!), (/), (^)) import Test.QuickCheck (Arbitrary (..)) @@ -16,48 +17,50 @@ import Test.QuickCheck (Arbitrary (..)) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Polynomials.Multivariate (Poly, polynomial, variables) import ZkFold.Prelude (length, take, (!!)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal +import ZkFold.Base.Data.Vector (Vector) -data PlonkConstraint a = PlonkConstraint +data PlonkConstraint i a = PlonkConstraint { qm :: a , ql :: a , qr :: a , qo :: a , qc :: a - , x1 :: Natural - , x2 :: Natural - , x3 :: Natural + , x1 :: Var (Vector i) + , x2 :: Var (Vector i) + , x3 :: Var (Vector i) } deriving (Show, Eq) -instance (Arbitrary a, Finite a, ToConstant a Natural) => Arbitrary (PlonkConstraint a) where +instance (Arbitrary a, Finite a, KnownNat i) => Arbitrary (PlonkConstraint i a) where arbitrary = do qm <- arbitrary ql <- arbitrary qr <- arbitrary qo <- arbitrary qc <- arbitrary - x1 <- toConstant <$> arbitrary @a - x2 <- toConstant <$> arbitrary @a - x3 <- toConstant <$> arbitrary @a + x1 <- InVar <$> arbitrary + x2 <- InVar <$> arbitrary + x3 <- InVar <$> arbitrary let xs = sort [x1, x2, x3] return $ PlonkConstraint qm ql qr qo qc (xs !! 0) (xs !! 1) (xs !! 2) -toPlonkConstraint :: forall a . (Eq a, FiniteField a) => Poly a Natural Natural -> PlonkConstraint a +toPlonkConstraint :: forall a i . (Eq a, FiniteField a, KnownNat i) => Poly a (Var (Vector i)) Natural -> PlonkConstraint i a toPlonkConstraint p = let xs = toList $ variables p - i = zero + i = NewVar zero perms = nubOrd $ map (take 3) $ permutations $ case length xs of 0 -> [i, i, i] 1 -> [i, i, head xs, head xs] 2 -> [i] ++ xs ++ xs _ -> xs ++ xs - getCoef :: Map Natural Natural -> a + getCoef :: Map (Var (Vector i)) Natural -> a getCoef m = case find (\(_, as) -> m == as) (toList p) of Just (c, _) -> c _ -> zero - getCoefs :: [Natural] -> Maybe (PlonkConstraint a) + getCoefs :: [Var (Vector i)] -> Maybe (PlonkConstraint i a) getCoefs [a, b, c] = do let xa = [(a, 1)] xb = [(b, 1)] @@ -75,7 +78,7 @@ toPlonkConstraint p = in head $ mapMaybe getCoefs perms -fromPlonkConstraint :: (Eq a, Field a) => PlonkConstraint a -> Poly a Natural Natural +fromPlonkConstraint :: (Eq a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var (Vector i)) Natural fromPlonkConstraint (PlonkConstraint qm ql qr qo qc a b c) = let xa = [(a, 1)] xb = [(b, 1)] diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs index 75d88d787..01325897d 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs @@ -5,6 +5,7 @@ module ZkFold.Base.Protocol.ARK.Plonk.Internal where import Data.Bifunctor (first) import Data.Bool (bool) +import Data.Map.Strict (Map) import qualified Data.Vector as V import GHC.Generics (Generic) import GHC.IsList (IsList (..)) @@ -18,6 +19,7 @@ import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr) import ZkFold.Base.Data.Vector (Vector) import ZkFold.Prelude (log2ceiling, take) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal getParams :: forall a . (Eq a, FiniteField a) => Natural -> (a, a, a) getParams n = findK' $ mkStdGen 0 @@ -46,17 +48,17 @@ type PlonkPolyExtendedLength n = 4 * n + 6 type PlonkPolyExtended n c = PolyVec (ScalarField c) (PlonkPolyExtendedLength n) -data PlonkSetupParamsProve c1 c2 = PlonkSetupParamsProve { +data PlonkSetupParamsProve i c1 c2 = PlonkSetupParamsProve { omega' :: ScalarField c1, k1' :: ScalarField c1, k2' :: ScalarField c1, gs' :: V.Vector (Point c1), h0' :: Point c2, h1' :: Point c2, - iPub' :: V.Vector Natural + iPub' :: V.Vector (Var (Vector i)) } instance (Show (ScalarField c1), Show (BaseField c1), Show (BaseField c2), - EllipticCurve c1, EllipticCurve c2) => Show (PlonkSetupParamsProve c1 c2) where + EllipticCurve c1, EllipticCurve c2) => Show (PlonkSetupParamsProve i c1 c2) where show (PlonkSetupParamsProve omega' k1' k2' gs' h0' h1' iPub') = "Setup Parameters (Prove): " ++ show omega' ++ " " @@ -139,11 +141,11 @@ instance (Show (BaseField c), EllipticCurve c) => Show (PlonkCircuitCommitments ++ show cmS3 newtype PlonkWitnessMap n i c = PlonkWitnessMap - (Vector i (ScalarField c) -> (PolyVec (ScalarField c) n, PolyVec (ScalarField c) n, PolyVec (ScalarField c) n)) + (PlonkWitnessInput i c -> (PolyVec (ScalarField c) n, PolyVec (ScalarField c) n, PolyVec (ScalarField c) n)) -newtype PlonkWitnessInput i c = PlonkWitnessInput (Vector i (ScalarField c)) +data PlonkWitnessInput i c = PlonkWitnessInput (Vector i (ScalarField c)) (Map Natural (ScalarField c)) instance Show (ScalarField c) => Show (PlonkWitnessInput i c) where - show (PlonkWitnessInput m) = "Witness Input: " ++ show m + show (PlonkWitnessInput v m) = "Witness Input: " ++ show v <> "Witness New Vars: " ++ show m data PlonkProverSecret c = PlonkProverSecret { b1 :: ScalarField c, diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs index af9d33a7f..a33d8c46a 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs @@ -3,7 +3,7 @@ module ZkFold.Base.Protocol.ARK.Plonk.Relation where -import Data.Map (elems, (!)) +import Data.Map (elems, Map) import GHC.Generics (Par1) import GHC.IsList (IsList (..)) import Prelude hiding (Num (..), drop, length, replicate, sum, @@ -12,7 +12,7 @@ import Prelude hiding (Num import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number import ZkFold.Base.Algebra.Basic.Permutations (Permutation, fromCycles, mkIndexPartition) -import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial, var) +import ZkFold.Base.Algebra.Polynomials.Multivariate (var) import ZkFold.Base.Algebra.Polynomials.Univariate (PolyVec, toPolyVec) import ZkFold.Base.Data.Vector (Vector, fromVector) import ZkFold.Base.Protocol.ARK.Plonk.Constraint (PlonkConstraint (..), toPlonkConstraint) @@ -28,7 +28,7 @@ data PlonkRelation n i a = PlonkRelation , qO :: PolyVec a n , qC :: PolyVec a n , sigma :: Permutation (3 * n) - , wmap :: Vector i a -> (PolyVec a n, PolyVec a n, PolyVec a n) + , wmap :: Vector i a -> Map Natural a -> (PolyVec a n, PolyVec a n, PolyVec a n) } toPlonkRelation :: forall i n l a . @@ -37,19 +37,14 @@ toPlonkRelation :: forall i n l a . => KnownNat (3 * n) => KnownNat l => Arithmetic a - => Scale a a - => Vector l Natural + => Vector l (Var (Vector i)) -> ArithmeticCircuit a (Vector i) Par1 -> Maybe (PlonkRelation n i a) toPlonkRelation xPub ac0 = let ac = desugarRanges ac0 - varF (NewVar ix) = if ix == 0 then one else var (ix + value @l) - varF (InVar ix) = var (toConstant ix) - evalX0 = evalPolynomial evalMonomial varF - pubInputConstraints = map var (fromVector xPub) - acConstraints = map evalX0 $ elems (acSystem ac) + acConstraints = elems (acSystem ac) extraConstraints = replicate (value @n -! acSizeN ac -! value @l) zero system = map toPlonkConstraint $ pubInputConstraints ++ acConstraints ++ extraConstraints @@ -66,11 +61,10 @@ toPlonkRelation xPub ac0 = -- TODO: Permutation code is not particularly safe. We rely on the list being of length 3*n. sigma = fromCycles @(3*n) $ mkIndexPartition $ fromList $ a ++ b ++ c - wmap' = witnessGenerator ac - w1 i = toPolyVec $ fromList $ map (wmap' i !) a - w2 i = toPolyVec $ fromList $ map (wmap' i !) b - w3 i = toPolyVec $ fromList $ map (wmap' i !) c - wmap i = (w1 i, w2 i, w3 i) + w1 i = toPolyVec $ fromList $ map (indexW ac i) a + w2 i = toPolyVec $ fromList $ map (indexW ac i) b + w3 i = toPolyVec $ fromList $ map (indexW ac i) c + wmap i _ = (w1 i, w2 i, w3 i) in if (acSizeN ac + value @l) <= value @n then Just $ PlonkRelation {..} diff --git a/src/ZkFold/Symbolic/Compiler.hs b/src/ZkFold/Symbolic/Compiler.hs index 1601a545f..bca7d5caf 100644 --- a/src/ZkFold/Symbolic/Compiler.hs +++ b/src/ZkFold/Symbolic/Compiler.hs @@ -21,7 +21,7 @@ import Prelude (FilePath, IO, Monoi import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number -import ZkFold.Base.Data.Vector (Vector, unsafeToVector) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Prelude (writeFileJSON) import ZkFold.Symbolic.Class (Arithmetic, Symbolic (..)) import ZkFold.Symbolic.Compiler.ArithmeticCircuit diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index 9b23898ef..69143fb2d 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -12,6 +12,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal ( ConstraintMonomial, Constraint, witnessGenerator, + indexW, -- low-level functions constraint, rangeConstraint, @@ -36,6 +37,7 @@ import Data.List (sort) import Data.Map.Strict hiding (drop, foldl, foldr, map, null, splitAt, take, toList) import qualified Data.Map.Strict as M hiding (toList) +import Data.Maybe (fromMaybe) import Data.Semialign (unzipDefault) import qualified Data.Set as S import GHC.Generics (Generic, Par1 (..), U1 (..), (:*:) (..)) @@ -102,6 +104,13 @@ witnessGenerator circuit inputs = in result +indexW :: Representable i => ArithmeticCircuit a i o -> i a -> Var i -> a +indexW circuit inputs = \case + InVar j -> index inputs j + NewVar j -> fromMaybe + (error ("no such NewVar: " <> show j)) + (witnessGenerator circuit inputs M.!? j) + ------------------------------ Symbolic compiler context ---------------------------- crown :: ArithmeticCircuit a i g -> f (Var i) -> ArithmeticCircuit a i f diff --git a/tests/Tests/Arithmetization/Test4.hs b/tests/Tests/Arithmetization/Test4.hs index a4db1f503..b4a5201b1 100644 --- a/tests/Tests/Arithmetization/Test4.hs +++ b/tests/Tests/Arithmetization/Test4.hs @@ -4,7 +4,6 @@ module Tests.Arithmetization.Test4 (specArithmetization4) where import GHC.Generics (Par1 (unPar1)) -import GHC.Num (Natural) import Prelude hiding (Bool, Eq (..), Num (..), Ord (..), (&&)) import qualified Prelude as Haskell import Test.Hspec (Spec, describe, it) @@ -22,7 +21,7 @@ import ZkFold.Base.Protocol.NonInteractiveProof (CoreFuncti NonInteractiveProof (..)) import ZkFold.Symbolic.Class import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), compile, compileForceOne, - eval) + eval, witnessGenerator) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Var (..)) import ZkFold.Symbolic.Data.Bool (Bool (..)) import ZkFold.Symbolic.Data.Eq (Eq (..)) @@ -54,13 +53,12 @@ testOnlyOutputZKP x ps targetValue = (omega, k1, k2) = getParams 32 witnessInputs = V.singleton targetValue - indexOutputBool = V.singleton $ case unPar1 $ acOutput ac of - NewVar ix -> ix + 1 - InVar _ -> 1 + witnessNewVars = witnessGenerator ac witnessInputs + indexOutputBool = V.singleton $ unPar1 $ acOutput ac plonk = Plonk @1 @32 omega k1 k2 indexOutputBool ac x setupP = setupProve @(PlonkBS N) @core plonk setupV = setupVerify @(PlonkBS N) @core plonk - witness = (PlonkWitnessInput witnessInputs, ps) + witness = (PlonkWitnessInput witnessInputs witnessNewVars, ps) (input, proof) = prove @(PlonkBS N) @core setupP witness -- `one` corresponds to `True` @@ -74,11 +72,12 @@ testSafeOneInputZKP x ps targetValue = (omega, k1, k2) = getParams 32 witnessInputs = V.singleton targetValue - indexTargetValue = V.singleton (1 :: Natural) + witnessNewVars = witnessGenerator ac witnessInputs + indexTargetValue = V.singleton (InVar zero) plonk = Plonk @1 @32 omega k1 k2 indexTargetValue ac x setupP = setupProve @(PlonkBS N) @core plonk setupV = setupVerify @(PlonkBS N) @core plonk - witness = (PlonkWitnessInput witnessInputs, ps) + witness = (PlonkWitnessInput witnessInputs witnessNewVars, ps) (input, proof) = prove @(PlonkBS N) @core setupP witness onePublicInput = plonkVerifierInput $ V.singleton targetValue @@ -91,11 +90,12 @@ testAttackSafeOneInputZKP x ps targetValue = (omega, k1, k2) = getParams 32 witnessInputs = V.singleton (targetValue + 1) - indexTargetValue = V.singleton (1 :: Natural) + witnessNewVars = witnessGenerator ac witnessInputs + indexTargetValue = V.singleton (InVar zero) plonk = Plonk @1 @32 omega k1 k2 indexTargetValue ac x setupP = setupProve @(PlonkBS N) @core plonk setupV = setupVerify @(PlonkBS N) @core plonk - witness = (PlonkWitnessInput witnessInputs, ps) + witness = (PlonkWitnessInput witnessInputs witnessNewVars, ps) (input, proof) = prove @(PlonkBS N) @core setupP witness onePublicInput = plonkVerifierInput $ V.singleton $ targetValue + 1 diff --git a/tests/Tests/NonInteractiveProof/Plonk.hs b/tests/Tests/NonInteractiveProof/Plonk.hs index 6ec46fc5e..3df054e4d 100644 --- a/tests/Tests/NonInteractiveProof/Plonk.hs +++ b/tests/Tests/NonInteractiveProof/Plonk.hs @@ -4,12 +4,13 @@ module Tests.NonInteractiveProof.Plonk (PlonkBS, specPlonk) where import Data.ByteString (ByteString) +import Data.Functor ((<&>)) import Data.Functor.Rep (Representable (..)) import Data.List (transpose) +import qualified Data.Map as Map import Data.Maybe (fromJust) import qualified Data.Vector as V import GHC.IsList (IsList (..)) -import GHC.Natural (Natural) import Prelude hiding (Fractional (..), Num (..), drop, length, replicate, take) import Test.Hspec @@ -28,22 +29,26 @@ import ZkFold.Base.Protocol.ARK.Plonk.Constraint import ZkFold.Base.Protocol.ARK.Plonk.Relation (PlonkRelation (..), toPlonkRelation) import ZkFold.Base.Protocol.NonInteractiveProof (HaskellCore, NonInteractiveProof (..), NonInteractiveProofTestData (..)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal type PlonkPolyLengthBS = 32 type PlonkBS n = Plonk 1 PlonkPolyLengthBS n BLS12_381_G1 BLS12_381_G2 ByteString type PlonkPolyExtendedLengthBS = PlonkPolyExtendedLength PlonkPolyLengthBS -propPlonkConstraintConversion :: (Eq a, FiniteField a) => PlonkConstraint a -> Bool +propPlonkConstraintConversion :: (Eq a, FiniteField a) => PlonkConstraint 1 a -> Bool propPlonkConstraintConversion p = toPlonkConstraint (fromPlonkConstraint p) == p propPlonkConstraintSatisfaction :: forall n core . KnownNat n => NonInteractiveProofTestData (PlonkBS n) core -> Bool propPlonkConstraintSatisfaction (TestData (Plonk _ _ _ iPub ac _) w) = let pr = fromJust $ toPlonkRelation @1 @PlonkPolyLengthBS iPub ac - (PlonkWitnessInput wInput, _) = w - (w1', w2', w3') = wmap pr wInput + (PlonkWitnessInput wInput wNewVars, _) = w + (w1', w2', w3') = wmap pr wInput wNewVars - wPub = toPolyVec @_ @PlonkPolyLengthBS $ fmap (negate . index wInput . fromIntegral) $ fromList @(V.Vector Natural) $ fromVector iPub + wPub = toPolyVec @_ @PlonkPolyLengthBS $ + fromList (fromVector iPub) <&> negate . \case + InVar j -> index wInput j + NewVar j -> wNewVars Map.! j qm' = V.toList $ fromPolyVec $ qM pr ql' = V.toList $ fromPolyVec $ qL pr @@ -63,11 +68,15 @@ propPlonkPolyIdentity (TestData plonk w) = s = setupProve @(PlonkBS n) @core plonk (PlonkSetupParamsProve {..}, _, PlonkCircuitPolynomials {..}, PlonkWitnessMap wmap) = s - (PlonkWitnessInput wInput, ps) = w + (pw@(PlonkWitnessInput wInput wNewVars), ps) = w PlonkProverSecret b1 b2 b3 b4 b5 b6 _ _ _ _ _ = ps - (w1, w2, w3) = wmap wInput + (w1, w2, w3) = wmap pw - wPub = fmap (negate . index wInput . fromIntegral) iPub' + wPub = iPub' <&> negate . \case + InVar j -> index wInput j + NewVar j -> wNewVars Map.! j + + -- wPub = fmap (negate . index wInput . fromIntegral) iPub' pubPoly = polyVecInLagrangeBasis @(ScalarField BLS12_381_G1) @PlonkPolyLengthBS @PlonkPolyExtendedLengthBS omega' $ toPolyVec @(ScalarField BLS12_381_G1) @PlonkPolyLengthBS wPub From eb33981e3fa51c8021333a8394ca0d367c59b0d5 Mon Sep 17 00:00:00 2001 From: echatav Date: Tue, 27 Aug 2024 19:28:02 +0000 Subject: [PATCH 42/48] stylish-haskell auto-commit --- src/ZkFold/Base/Protocol/ARK/Plonk.hs | 40 ++++++++--------- .../Base/Protocol/ARK/Plonk/Constraint.hs | 27 +++++------ .../Base/Protocol/ARK/Plonk/Internal.hs | 29 ++++++------ .../Base/Protocol/ARK/Plonk/Relation.hs | 2 +- tests/Tests/NonInteractiveProof/Plonk.hs | 45 ++++++++++--------- 5 files changed, 73 insertions(+), 70 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index 8859279d3..152d5c1a2 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -12,31 +12,31 @@ module ZkFold.Base.Protocol.ARK.Plonk ( plonkVerifierInput ) where -import Data.Functor ((<&>)) -import Data.Functor.Rep (Representable (..)) -import Data.Maybe (fromJust) -import qualified Data.Map as Map -import qualified Data.Vector as V -import GHC.Generics (Par1) -import GHC.IsList (IsList (..)) -import Prelude hiding (Num (..), div, drop, length, replicate, sum, take, - (!!), (/), (^)) -import qualified Prelude as P hiding (length) -import Test.QuickCheck (Arbitrary (..), Gen) +import Data.Functor ((<&>)) +import Data.Functor.Rep (Representable (..)) +import qualified Data.Map as Map +import Data.Maybe (fromJust) +import qualified Data.Vector as V +import GHC.Generics (Par1) +import GHC.IsList (IsList (..)) +import Prelude hiding (Num (..), div, drop, length, replicate, + sum, take, (!!), (/), (^)) +import qualified Prelude as P hiding (length) +import Test.QuickCheck (Arbitrary (..), Gen) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number -import ZkFold.Base.Algebra.Basic.Permutations (fromPermutation) -import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), Pairing (..), Point, PointCompressed, - compress) -import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr) -import ZkFold.Base.Data.Vector (Vector (..), fromVector) +import ZkFold.Base.Algebra.Basic.Permutations (fromPermutation) +import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), Pairing (..), Point, + PointCompressed, compress) +import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr) +import ZkFold.Base.Data.Vector (Vector (..), fromVector) import ZkFold.Base.Protocol.ARK.Plonk.Internal -import ZkFold.Base.Protocol.ARK.Plonk.Relation (PlonkRelation (..), toPlonkRelation) +import ZkFold.Base.Protocol.ARK.Plonk.Relation (PlonkRelation (..), toPlonkRelation) import ZkFold.Base.Protocol.NonInteractiveProof -import ZkFold.Prelude (log2ceiling) -import ZkFold.Symbolic.Compiler (ArithmeticCircuitTest (..)) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal +import ZkFold.Prelude (log2ceiling) +import ZkFold.Symbolic.Compiler (ArithmeticCircuitTest (..)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal {- NOTE: we need to parametrize the type of transcripts because we use BuiltinByteString on-chain and ByteString off-chain. diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs index c70c836a6..8ad477aa2 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs @@ -3,22 +3,23 @@ module ZkFold.Base.Protocol.ARK.Plonk.Constraint where -import Control.Monad (guard) -import Data.Containers.ListUtils (nubOrd) -import Data.List (find, permutations, sort) -import Data.Map (Map, empty, fromListWith) -import Data.Maybe (mapMaybe) -import GHC.IsList (IsList (..)) -import GHC.TypeNats (KnownNat) -import Numeric.Natural (Natural) -import Prelude hiding (Num (..), drop, length, sum, take, (!!), (/), (^)) -import Test.QuickCheck (Arbitrary (..)) +import Control.Monad (guard) +import Data.Containers.ListUtils (nubOrd) +import Data.List (find, permutations, sort) +import Data.Map (Map, empty, fromListWith) +import Data.Maybe (mapMaybe) +import GHC.IsList (IsList (..)) +import GHC.TypeNats (KnownNat) +import Numeric.Natural (Natural) +import Prelude hiding (Num (..), drop, length, sum, take, (!!), + (/), (^)) +import Test.QuickCheck (Arbitrary (..)) import ZkFold.Base.Algebra.Basic.Class -import ZkFold.Base.Algebra.Polynomials.Multivariate (Poly, polynomial, variables) -import ZkFold.Prelude (length, take, (!!)) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal +import ZkFold.Base.Algebra.Polynomials.Multivariate (Poly, polynomial, variables) import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Prelude (length, take, (!!)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal data PlonkConstraint i a = PlonkConstraint { qm :: a diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs index 01325897d..9fd6acf31 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs @@ -3,23 +3,24 @@ module ZkFold.Base.Protocol.ARK.Plonk.Internal where -import Data.Bifunctor (first) -import Data.Bool (bool) -import Data.Map.Strict (Map) -import qualified Data.Vector as V -import GHC.Generics (Generic) -import GHC.IsList (IsList (..)) -import Prelude hiding (Num (..), drop, length, sum, take, (!!), (/), (^)) -import System.Random (RandomGen, mkStdGen, uniformR) -import Test.QuickCheck (Arbitrary (..), Gen, shuffle) +import Data.Bifunctor (first) +import Data.Bool (bool) +import Data.Map.Strict (Map) +import qualified Data.Vector as V +import GHC.Generics (Generic) +import GHC.IsList (IsList (..)) +import Prelude hiding (Num (..), drop, length, sum, take, (!!), + (/), (^)) +import System.Random (RandomGen, mkStdGen, uniformR) +import Test.QuickCheck (Arbitrary (..), Gen, shuffle) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number -import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), Point) -import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr) -import ZkFold.Base.Data.Vector (Vector) -import ZkFold.Prelude (log2ceiling, take) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal +import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..), Point) +import ZkFold.Base.Algebra.Polynomials.Univariate hiding (qr) +import ZkFold.Base.Data.Vector (Vector) +import ZkFold.Prelude (log2ceiling, take) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal getParams :: forall a . (Eq a, FiniteField a) => Natural -> (a, a, a) getParams n = findK' $ mkStdGen 0 diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs index a33d8c46a..a842f64ed 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs @@ -3,7 +3,7 @@ module ZkFold.Base.Protocol.ARK.Plonk.Relation where -import Data.Map (elems, Map) +import Data.Map (Map, elems) import GHC.Generics (Par1) import GHC.IsList (IsList (..)) import Prelude hiding (Num (..), drop, length, replicate, sum, diff --git a/tests/Tests/NonInteractiveProof/Plonk.hs b/tests/Tests/NonInteractiveProof/Plonk.hs index 3df054e4d..42ec5d2dd 100644 --- a/tests/Tests/NonInteractiveProof/Plonk.hs +++ b/tests/Tests/NonInteractiveProof/Plonk.hs @@ -3,33 +3,34 @@ module Tests.NonInteractiveProof.Plonk (PlonkBS, specPlonk) where -import Data.ByteString (ByteString) -import Data.Functor ((<&>)) -import Data.Functor.Rep (Representable (..)) -import Data.List (transpose) -import qualified Data.Map as Map -import Data.Maybe (fromJust) -import qualified Data.Vector as V -import GHC.IsList (IsList (..)) -import Prelude hiding (Fractional (..), Num (..), drop, length, replicate, - take) +import Data.ByteString (ByteString) +import Data.Functor ((<&>)) +import Data.Functor.Rep (Representable (..)) +import Data.List (transpose) +import qualified Data.Map as Map +import Data.Maybe (fromJust) +import qualified Data.Vector as V +import GHC.IsList (IsList (..)) +import Prelude hiding (Fractional (..), Num (..), drop, length, + replicate, take) import Test.Hspec import Test.QuickCheck -import ZkFold.Base.Algebra.Basic.Class (AdditiveGroup (..), AdditiveSemigroup (..), FiniteField, - MultiplicativeSemigroup (..), negate, zero, (-!)) -import ZkFold.Base.Algebra.Basic.Number (KnownNat, value) -import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS12_381_G2) -import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..)) -import ZkFold.Base.Algebra.Polynomials.Univariate (evalPolyVec, fromPolyVec, polyVecInLagrangeBasis, - polyVecLinear, polyVecZero, toPolyVec) -import ZkFold.Base.Data.Vector (fromVector) +import ZkFold.Base.Algebra.Basic.Class (AdditiveGroup (..), AdditiveSemigroup (..), + FiniteField, MultiplicativeSemigroup (..), negate, + zero, (-!)) +import ZkFold.Base.Algebra.Basic.Number (KnownNat, value) +import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS12_381_G2) +import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..)) +import ZkFold.Base.Algebra.Polynomials.Univariate (evalPolyVec, fromPolyVec, polyVecInLagrangeBasis, + polyVecLinear, polyVecZero, toPolyVec) +import ZkFold.Base.Data.Vector (fromVector) import ZkFold.Base.Protocol.ARK.Plonk import ZkFold.Base.Protocol.ARK.Plonk.Constraint -import ZkFold.Base.Protocol.ARK.Plonk.Relation (PlonkRelation (..), toPlonkRelation) -import ZkFold.Base.Protocol.NonInteractiveProof (HaskellCore, NonInteractiveProof (..), - NonInteractiveProofTestData (..)) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal +import ZkFold.Base.Protocol.ARK.Plonk.Relation (PlonkRelation (..), toPlonkRelation) +import ZkFold.Base.Protocol.NonInteractiveProof (HaskellCore, NonInteractiveProof (..), + NonInteractiveProofTestData (..)) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal type PlonkPolyLengthBS = 32 type PlonkBS n = Plonk 1 PlonkPolyLengthBS n BLS12_381_G1 BLS12_381_G2 ByteString From d6e3d18d7971970b75f7331ddf9f6a86501239bb Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 27 Aug 2024 12:44:44 -0700 Subject: [PATCH 43/48] Update Internal.hs --- src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index 69143fb2d..f55143f9a 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -290,7 +290,7 @@ apply xs ac = ac -- zoom #acWitness . modify . union . fromList $ zip inputs (map const xs) getAllVars :: (MultiplicativeMonoid a, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuit a i o -> [Var i] -getAllVars ac = nubOrd $ sort $ NewVar 0 : toList acInput ++ concatMap (toList . variables) (elems $ acSystem ac) +getAllVars ac = nubOrd $ sort $ toList acInput ++ concatMap (toList . variables) (elems $ acSystem ac) -- TODO: Add proper symbolic application functions From 4be6eb75e88e0fd68aca42a8a86e1b7742f7b2fc Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 27 Aug 2024 12:51:39 -0700 Subject: [PATCH 44/48] Update Map.hs --- src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs index 1b9684d9d..2fa95ed7a 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs @@ -53,7 +53,7 @@ mapVarArithmeticCircuit (ArithmeticCircuitTest ac wi) = { acSystem = fromList $ zip [0..] $ evalPolynomial evalMonomial (var . varF) <$> elems (acSystem ac), -- TODO: the new arithmetic circuit expects the old input variables! We should make this safer. - acWitness = (`Map.compose` backward) $ acWitness ac + acWitness = (`Map.compose` backward) $ (\f i m -> f i (Map.compose m forward)) <$> acWitness ac } mappedOutputs = varF <$> acOutput ac in ArithmeticCircuitTest (mappedCircuit {acOutput = mappedOutputs}) wi From 9d627b13353d33845e8bd80a38b322c74ca9c4f0 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 27 Aug 2024 13:48:21 -0700 Subject: [PATCH 45/48] fix a test --- .../Base/Protocol/ARK/Plonk/Constraint.hs | 29 +++++++++++-------- tests/Tests/NonInteractiveProof/Plonk.hs | 4 +-- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs index 8ad477aa2..2874fe6d4 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs @@ -16,7 +16,7 @@ import Prelude hiding (Num import Test.QuickCheck (Arbitrary (..)) import ZkFold.Base.Algebra.Basic.Class -import ZkFold.Base.Algebra.Polynomials.Multivariate (Poly, polynomial, variables) +import ZkFold.Base.Algebra.Polynomials.Multivariate (Poly, polynomial, var, variables) import ZkFold.Base.Data.Vector (Vector) import ZkFold.Prelude (length, take, (!!)) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal @@ -33,16 +33,16 @@ data PlonkConstraint i a = PlonkConstraint } deriving (Show, Eq) -instance (Arbitrary a, Finite a, KnownNat i) => Arbitrary (PlonkConstraint i a) where +instance (Arbitrary a, Finite a, ToConstant a Natural, KnownNat i) => Arbitrary (PlonkConstraint i a) where arbitrary = do qm <- arbitrary ql <- arbitrary qr <- arbitrary qo <- arbitrary qc <- arbitrary - x1 <- InVar <$> arbitrary - x2 <- InVar <$> arbitrary - x3 <- InVar <$> arbitrary + x1 <- NewVar . toConstant @a <$> arbitrary + x2 <- NewVar . toConstant @a <$> arbitrary + x3 <- NewVar . toConstant @a <$> arbitrary let xs = sort [x1, x2, x3] return $ PlonkConstraint qm ql qr qo qc (xs !! 0) (xs !! 1) (xs !! 2) @@ -79,11 +79,16 @@ toPlonkConstraint p = in head $ mapMaybe getCoefs perms -fromPlonkConstraint :: (Eq a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var (Vector i)) Natural +fromPlonkConstraint :: (Eq a, Scale a a, FromConstant a a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var (Vector i)) Natural fromPlonkConstraint (PlonkConstraint qm ql qr qo qc a b c) = - let xa = [(a, 1)] - xb = [(b, 1)] - xc = [(c, 1)] - xaxb = [(a, 1), (b, 1)] - - in polynomial [(qm, xaxb), (ql, xa), (qr, xb), (qo, xc), (qc, one)] + let xvar v = if v == NewVar zero then zero else var v + xa = xvar a + xb = xvar b + xc = xvar c + xaxb = xa * xb + in + scale qm xaxb + + scale ql xa + + scale qr xb + + scale qo xc + + fromConstant qc diff --git a/tests/Tests/NonInteractiveProof/Plonk.hs b/tests/Tests/NonInteractiveProof/Plonk.hs index 42ec5d2dd..1da4b9a8a 100644 --- a/tests/Tests/NonInteractiveProof/Plonk.hs +++ b/tests/Tests/NonInteractiveProof/Plonk.hs @@ -18,7 +18,7 @@ import Test.QuickCheck import ZkFold.Base.Algebra.Basic.Class (AdditiveGroup (..), AdditiveSemigroup (..), FiniteField, MultiplicativeSemigroup (..), negate, - zero, (-!)) + zero, (-!), Scale (..), FromConstant (..)) import ZkFold.Base.Algebra.Basic.Number (KnownNat, value) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS12_381_G2) import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..)) @@ -36,7 +36,7 @@ type PlonkPolyLengthBS = 32 type PlonkBS n = Plonk 1 PlonkPolyLengthBS n BLS12_381_G1 BLS12_381_G2 ByteString type PlonkPolyExtendedLengthBS = PlonkPolyExtendedLength PlonkPolyLengthBS -propPlonkConstraintConversion :: (Eq a, FiniteField a) => PlonkConstraint 1 a -> Bool +propPlonkConstraintConversion :: (Eq a, Scale a a, FromConstant a a, FiniteField a) => PlonkConstraint 1 a -> Bool propPlonkConstraintConversion p = toPlonkConstraint (fromPlonkConstraint p) == p From 50d58ab1da16b7fe2b508a3526179b3cdb5740f8 Mon Sep 17 00:00:00 2001 From: echatav Date: Tue, 27 Aug 2024 20:51:21 +0000 Subject: [PATCH 46/48] stylish-haskell auto-commit --- src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs | 2 +- tests/Tests/NonInteractiveProof/Plonk.hs | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs index 2874fe6d4..e07b1d766 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Constraint.hs @@ -86,7 +86,7 @@ fromPlonkConstraint (PlonkConstraint qm ql qr qo qc a b c) = xb = xvar b xc = xvar c xaxb = xa * xb - in + in scale qm xaxb + scale ql xa + scale qr xb diff --git a/tests/Tests/NonInteractiveProof/Plonk.hs b/tests/Tests/NonInteractiveProof/Plonk.hs index 1da4b9a8a..58c618f2e 100644 --- a/tests/Tests/NonInteractiveProof/Plonk.hs +++ b/tests/Tests/NonInteractiveProof/Plonk.hs @@ -17,8 +17,9 @@ import Test.Hspec import Test.QuickCheck import ZkFold.Base.Algebra.Basic.Class (AdditiveGroup (..), AdditiveSemigroup (..), - FiniteField, MultiplicativeSemigroup (..), negate, - zero, (-!), Scale (..), FromConstant (..)) + FiniteField, FromConstant (..), + MultiplicativeSemigroup (..), Scale (..), negate, + zero, (-!)) import ZkFold.Base.Algebra.Basic.Number (KnownNat, value) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS12_381_G2) import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..)) From c8f62b08d3c15ddecaa68b6642ce9208ae961e68 Mon Sep 17 00:00:00 2001 From: Eitan Chatav Date: Tue, 27 Aug 2024 13:58:54 -0700 Subject: [PATCH 47/48] fix 2 more tests --- src/ZkFold/Base/Protocol/ARK/Plonk.hs | 16 ++++++---------- src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs | 4 ++-- src/ZkFold/Prelude.hs | 3 ++- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index 152d5c1a2..9853c298a 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -58,21 +58,17 @@ instance (Show (ScalarField c1), Arithmetic (ScalarField c1), KnownNat l, KnownN instance (KnownNat i, KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1)) => Arbitrary (Plonk i n l c1 c2 t) where arbitrary = do ac <- arbitrary - let fullInp = value @i - vecPubInp <- genSubset (value @l) fullInp - let vecPubInVars = [InVar (fromConstant ix) | ix <- vecPubInp] - (omega, k1, k2) = getParams (value @n) - Plonk omega k1 k2 (Vector vecPubInVars) ac <$> arbitrary + vecPubInp <- genSubset (getAllVars ac) (value @l) + let (omega, k1, k2) = getParams (value @n) + Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary instance forall i n l c1 c2 t core . (KnownNat i, KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (ScalarField c1), Witness (Plonk i n l c1 c2 t) ~ (PlonkWitnessInput i c1, PlonkProverSecret c1), NonInteractiveProof (Plonk i n l c1 c2 t) core) => Arbitrary (NonInteractiveProofTestData (Plonk i n l c1 c2 t) core) where arbitrary = do ArithmeticCircuitTest ac wi <- arbitrary :: Gen (ArithmeticCircuitTest (ScalarField c1) (Vector i) Par1) - let inputLen = value @i - vecPubInp <- genSubset (value @l) inputLen - let vecPubInVars = [InVar (fromConstant ix) | ix <- vecPubInp] - (omega, k1, k2) = getParams $ value @n - pl <- Plonk omega k1 k2 (Vector vecPubInVars) ac <$> arbitrary + vecPubInp <- genSubset (getAllVars ac) (value @l) + let (omega, k1, k2) = getParams $ value @n + pl <- Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary secret <- arbitrary return $ TestData pl (PlonkWitnessInput wi (witnessGenerator ac wi), secret) diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs index 9fd6acf31..492d950af 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk/Internal.hs @@ -39,8 +39,8 @@ getParams n = findK' $ mkStdGen 0 all (`notElem` hGroup) (hGroup' k1) && all (`notElem` hGroup' k1) (hGroup' k2) -genSubset :: Natural -> Natural -> Gen [Natural] -genSubset maxLength maxValue = take maxLength <$> shuffle [1..maxValue] +genSubset :: [Var (Vector i)] -> Natural -> Gen [Var (Vector i)] +genSubset vars maxLength = take maxLength <$> shuffle vars type PlonkPermutationSize n = 3 * n diff --git a/src/ZkFold/Prelude.hs b/src/ZkFold/Prelude.hs index 34b5e7552..0dc8a6e4a 100644 --- a/src/ZkFold/Prelude.hs +++ b/src/ZkFold/Prelude.hs @@ -5,6 +5,7 @@ import Data.ByteString.Lazy (readFile, writeFile) import Data.List (foldl', genericIndex) import Data.Map (Map, lookup) import GHC.Num (Natural, integerToNatural) +import GHC.Stack (HasCallStack) import Prelude hiding (drop, lookup, readFile, replicate, take, writeFile, (!!)) import Test.QuickCheck (Gen, chooseInteger) @@ -14,7 +15,7 @@ log2ceiling = ceiling @Double . logBase 2 . fromIntegral length :: Foldable t => t a -> Natural length = foldl' (\c _ -> c + 1) 0 -take :: Natural -> [a] -> [a] +take :: HasCallStack => Natural -> [a] -> [a] take 0 _ = [] take n (x:xs) = x : take (n - 1) xs take _ [] = error "ZkFold.Prelude.take: empty list" From 85258e5087dc7463f0658ea82e97681f05662029 Mon Sep 17 00:00:00 2001 From: TurtlePU Date: Wed, 28 Aug 2024 01:37:49 +0300 Subject: [PATCH 48/48] reduce diff --- src/ZkFold/Symbolic/Data/UInt.hs | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/ZkFold/Symbolic/Data/UInt.hs b/src/ZkFold/Symbolic/Data/UInt.hs index 4a9d741bb..306230e0b 100644 --- a/src/ZkFold/Symbolic/Data/UInt.hs +++ b/src/ZkFold/Symbolic/Data/UInt.hs @@ -141,6 +141,22 @@ instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Arbitrary (UInt n r c) return $ UInt $ embed $ V.unsafeToVector (lo <> [hi]) where toss b = fromConstant <$> chooseInteger (0, 2 ^ b - 1) +instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Iso (ByteString n c) (UInt n r c) where + from (ByteString b) = UInt $ fromCircuitF b solve + where + solve :: forall i m. MonadCircuit i (BaseField c) m => Vector n i -> m (Vector (NumberOfRegisters (BaseField c) n r) i) + solve bits = do + let bsBits = V.fromVector bits + V.unsafeToVector . Haskell.reverse <$> fromBits (highRegisterSize @(BaseField c) @n @r) (registerSize @(BaseField c) @n @r) bsBits + +instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Iso (UInt n r c) (ByteString n c) where + from (UInt v) = ByteString $ fromCircuitF v solve + where + solve :: forall i m. MonadCircuit i (BaseField c) m => Vector (NumberOfRegisters (BaseField c) n r) i -> m (Vector n i) + solve ui = do + let regs = V.fromVector ui + V.unsafeToVector <$> toBits (Haskell.reverse regs) (highRegisterSize @(BaseField c) @n @r) (registerSize @(BaseField c) @n @r) + -- -------------------------------------------------------------------------------- instance @@ -204,22 +220,6 @@ instance let rs = force $ addBit (r' + r') (value @n -! i -! 1) in bool @(Bool c) (q', rs) (q' + fromConstant ((2 :: Natural) ^ i), rs - d) (rs >= d) -instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Iso (ByteString n c) (UInt n r c) where - from (ByteString bits) = UInt $ symbolicF bits (\v -> naturalToVector @c @n @r $ vectorToNatural v (registerSize @(BaseField c) @n @r)) solve - where - solve :: MonadCircuit v a m => Vector n v -> m (Vector (NumberOfRegisters a n r) v) - solve xv = do - let bsBits = V.fromVector xv - V.unsafeToVector . Haskell.reverse <$> fromBits (highRegisterSize @(BaseField c) @n @r) (registerSize @(BaseField c) @n @r) bsBits - -instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Iso (UInt n r c) (ByteString n c) where - from (UInt ac) = ByteString $ symbolicF ac (\v -> V.unsafeToVector $ fromConstant <$> toBsBits (vectorToNatural v (registerSize @(BaseField c) @n @r)) (value @n)) solve - where - solve :: MonadCircuit v (BaseField c) m => Vector (NumberOfRegisters (BaseField c) n r) v -> m (Vector n v) - solve xv = do - let regs = V.fromVector xv - V.unsafeToVector <$> toBits (Haskell.reverse regs) (highRegisterSize @(BaseField c) @n @r) (registerSize @(BaseField c) @n @r) - instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Ord (Bool c) (UInt n r c) where x <= y = y >= x