diff --git a/bench/BenchCompiler.hs b/bench/BenchCompiler.hs index 6e9d5a833..56df25f67 100644 --- a/bench/BenchCompiler.hs +++ b/bench/BenchCompiler.hs @@ -30,7 +30,7 @@ metrics name circuit = benchmark :: - (NFData a, AdditiveMonoid a, NFData (o (Var i)), NFData (Rep i), i ~ Vector n_i, KnownNat n_i) => + (NFData a, AdditiveMonoid a, NFData (o (Var a i)), NFData (Rep i), i ~ Vector n_i, KnownNat n_i) => String -> (() -> ArithmeticCircuit a i o) -> Benchmark benchmark name circuit = bgroup name [ bench "compilation" $ nf circuit () diff --git a/examples/Examples/Constant.hs b/examples/Examples/Constant.hs new file mode 100644 index 000000000..becc7d8e6 --- /dev/null +++ b/examples/Examples/Constant.hs @@ -0,0 +1,15 @@ +module Examples.Constant (exampleConst5, exampleEq5) where + +import Examples.Eq (exampleEq) +import Numeric.Natural (Natural) + +import ZkFold.Base.Algebra.Basic.Class (FromConstant (..)) +import ZkFold.Symbolic.Class (Symbolic) +import ZkFold.Symbolic.Data.Bool (Bool) +import ZkFold.Symbolic.Data.FieldElement (FieldElement) + +exampleConst5 :: Symbolic c => FieldElement c +exampleConst5 = fromConstant (5 :: Natural) + +exampleEq5 :: Symbolic c => FieldElement c -> Bool c +exampleEq5 = exampleEq exampleConst5 diff --git a/examples/ZkFold/Symbolic/Examples.hs b/examples/ZkFold/Symbolic/Examples.hs index 218bbb7ad..792e94fc5 100644 --- a/examples/ZkFold/Symbolic/Examples.hs +++ b/examples/ZkFold/Symbolic/Examples.hs @@ -9,6 +9,7 @@ import Data.Type.Equality (type (~)) import Examples.BatchTransfer (exampleBatchTransfer) import Examples.ByteString import Examples.Conditional (exampleConditional) +import Examples.Constant (exampleConst5, exampleEq5) import Examples.Eq (exampleEq) import Examples.FFA import Examples.Fibonacci (exampleFibonacci) @@ -52,6 +53,8 @@ examples :: [(String, ExampleOutput)] examples = [ ("Eq", exampleOutput exampleEq) , ("Conditional", exampleOutput exampleConditional) + , ("Constant.5", exampleOutput exampleConst5) + , ("Eq.Constant.5", exampleOutput exampleEq5) , ("ByteString.And.32", exampleOutput $ exampleByteStringAnd @32) , ("ByteString.Or.64", exampleOutput $ exampleByteStringOr @64) , ("UInt.Mul.64.Auto", exampleOutput $ exampleUIntMul @64 @Auto) diff --git a/src/ZkFold/Base/Protocol/Plonkup/LookupConstraint.hs b/src/ZkFold/Base/Protocol/Plonkup/LookupConstraint.hs index d5d1eb528..c20ba6204 100644 --- a/src/ZkFold/Base/Protocol/Plonkup/LookupConstraint.hs +++ b/src/ZkFold/Base/Protocol/Plonkup/LookupConstraint.hs @@ -10,7 +10,7 @@ import ZkFold.Base.Data.ByteString (toByteStri import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal -newtype LookupConstraint i a = LookupConstraint { lkVar :: Var (Vector i) } +newtype LookupConstraint i a = LookupConstraint { lkVar :: SysVar (Vector i) } deriving (Show, Eq) instance (Arbitrary a, Binary a) => Arbitrary (LookupConstraint i a) where diff --git a/src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs b/src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs index 4bcf58911..ac1311dba 100644 --- a/src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs +++ b/src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs @@ -3,29 +3,28 @@ module ZkFold.Base.Protocol.Plonkup.PlonkConstraint where -import Control.Monad (guard, return) +import Control.Monad (guard, replicateM, return) import Data.Binary (Binary) import Data.Containers.ListUtils (nubOrd) import Data.Eq (Eq (..)) import Data.Function (($), (.)) import Data.Functor ((<$>)) -import Data.List (find, head, map, permutations, sort, (++)) +import Data.List (find, head, map, permutations, sort, (!!), (++)) import Data.Map (Map) import qualified Data.Map as Map -import Data.Maybe (Maybe (..), mapMaybe, maybe) +import Data.Maybe (Maybe (..), mapMaybe) +import Data.Ord (Ord) import GHC.IsList (IsList (..)) import GHC.TypeNats (KnownNat) import Numeric.Natural (Natural) -import Prelude (error) -import Test.QuickCheck (Arbitrary (..), elements) +import Test.QuickCheck (Arbitrary (..)) import Text.Show (Show) import ZkFold.Base.Algebra.Basic.Class -import ZkFold.Base.Algebra.Polynomials.Multivariate (Poly, evalMonomial, evalPolynomial, polynomial, - var, variables) +import ZkFold.Base.Algebra.Polynomials.Multivariate (Poly, polynomial, var, variables) import ZkFold.Base.Data.ByteString (toByteString) import ZkFold.Base.Data.Vector (Vector) -import ZkFold.Prelude (length, replicate, replicateA, take) +import ZkFold.Prelude (length, take) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal data PlonkConstraint i a = PlonkConstraint @@ -34,41 +33,39 @@ data PlonkConstraint i a = PlonkConstraint , qr :: a , qo :: a , qc :: a - , x1 :: Maybe (Var (Vector i)) - , x2 :: Maybe (Var (Vector i)) - , x3 :: Maybe (Var (Vector i)) + , x1 :: Var a (Vector i) + , x2 :: Var a (Vector i) + , x3 :: Var a (Vector i) } deriving (Show, Eq) -instance (Arbitrary a, Binary a, KnownNat i) => Arbitrary (PlonkConstraint i a) where +instance (Ord a, Arbitrary a, Binary a, KnownNat i) => Arbitrary (PlonkConstraint i a) where arbitrary = do qm <- arbitrary ql <- arbitrary qr <- arbitrary qo <- arbitrary qc <- arbitrary - k <- elements [1, 2, 3] - xs0 <- sort <$> replicateA k (Just . NewVar . toByteString @a <$> arbitrary) - let (x, y, z) = case replicate (3 -! k) Nothing ++ xs0 of - [x', y', z'] -> (x', y', z') - _ -> error "impossible" - return $ PlonkConstraint qm ql qr qo qc x y z + let arbitraryNewVar = SysVar . NewVar . toByteString @a <$> arbitrary + xs <- sort <$> replicateM 3 arbitraryNewVar + let x1 = xs !! 0; x2 = xs !! 1; x3 = xs !! 2 + return $ PlonkConstraint qm ql qr qo qc x1 x2 x3 -toPlonkConstraint :: forall a i . (Eq a, FiniteField a, KnownNat i) => Poly a (Var (Vector i)) Natural -> PlonkConstraint i a +toPlonkConstraint :: forall a i . (Ord a, FiniteField a, KnownNat i) => Poly a (Var a (Vector i)) Natural -> PlonkConstraint i a toPlonkConstraint p = - let xs = map Just $ toList (variables p) + let xs = toList (variables p) perms = nubOrd $ map (take 3) $ permutations $ case length xs of - 0 -> [Nothing, Nothing, Nothing] - 1 -> [Nothing, Nothing, head xs, head xs] - 2 -> [Nothing] ++ xs ++ xs + 0 -> [ConstVar one, ConstVar one, ConstVar one] + 1 -> [ConstVar one, ConstVar one, head xs, head xs] + 2 -> [ConstVar one] ++ xs ++ xs _ -> xs ++ xs - getCoef :: Map (Maybe (Var (Vector i))) Natural -> a - getCoef m = case find (\(_, as) -> m == Map.mapKeys Just as) (toList p) of + getCoef :: Map (Var a (Vector i)) Natural -> a + getCoef m = case find (\(_, as) -> m == as) (toList p) of Just (c, _) -> c _ -> zero - getCoefs :: [Maybe (Var (Vector i))] -> Maybe (PlonkConstraint i a) + getCoefs :: [Var a (Vector i)] -> Maybe (PlonkConstraint i a) getCoefs [a, b, c] = do let xa = [(a, 1)] xb = [(b, 1)] @@ -80,18 +77,17 @@ toPlonkConstraint p = qr = getCoef $ fromList xb qo = getCoef $ fromList xc qc = getCoef Map.empty - guard $ evalPolynomial evalMonomial (var . Just) p - polynomial [(qm, fromList xaxb), (ql, fromList xa), (qr, fromList xb), (qo, fromList xc), (qc, one)] == zero + guard $ p - polynomial [(qm, fromList xaxb), (ql, fromList xa), (qr, fromList xb), (qo, fromList xc), (qc, one)] == zero return $ PlonkConstraint qm ql qr qo qc a b c getCoefs _ = Nothing in head $ mapMaybe getCoefs perms -fromPlonkConstraint :: (Eq a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var (Vector i)) Natural +fromPlonkConstraint :: (Ord a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var a (Vector i)) Natural fromPlonkConstraint (PlonkConstraint qm ql qr qo qc a b c) = - let xvar = maybe zero var - xa = xvar a - xb = xvar b - xc = xvar c + let xa = var a + xb = var b + xc = var c xaxb = xa * xb in scale qm xaxb diff --git a/src/ZkFold/Base/Protocol/Plonkup/PlonkupConstraint.hs b/src/ZkFold/Base/Protocol/Plonkup/PlonkupConstraint.hs index 313d1dc81..287edb93d 100644 --- a/src/ZkFold/Base/Protocol/Plonkup/PlonkupConstraint.hs +++ b/src/ZkFold/Base/Protocol/Plonkup/PlonkupConstraint.hs @@ -12,7 +12,7 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal data PlonkupConstraint i a = ConsPlonk (PlonkConstraint i a) | ConsLookup (LookupConstraint i a) | ConsExtra -getPlonkConstraint :: (Eq a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> PlonkConstraint i a +getPlonkConstraint :: (Ord a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> PlonkConstraint i a getPlonkConstraint (ConsPlonk c) = c getPlonkConstraint _ = toPlonkConstraint zero @@ -20,17 +20,17 @@ isLookupConstraint :: FiniteField a => PlonkupConstraint i a -> a isLookupConstraint (ConsLookup _) = one isLookupConstraint _ = zero -getA :: forall a i . (Eq a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Maybe (Var (Vector i)) +getA :: forall a i . (Ord a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Var a (Vector i) getA (ConsPlonk c) = x1 c -getA (ConsLookup c) = Just $ lkVar c +getA (ConsLookup c) = SysVar $ lkVar c getA ConsExtra = x1 (toPlonkConstraint @a zero) -getB :: forall a i . (Eq a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Maybe (Var (Vector i)) +getB :: forall a i . (Ord a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Var a (Vector i) getB (ConsPlonk c) = x2 c -getB (ConsLookup c) = Just $ lkVar c +getB (ConsLookup c) = SysVar $ lkVar c getB ConsExtra = x2 (toPlonkConstraint @a zero) -getC :: forall a i . (Eq a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Maybe (Var (Vector i)) +getC :: forall a i . (Ord a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Var a (Vector i) getC (ConsPlonk c) = x3 c -getC (ConsLookup c) = Just $ lkVar c +getC (ConsLookup c) = SysVar $ lkVar c getC ConsExtra = x3 (toPlonkConstraint @a zero) diff --git a/src/ZkFold/Base/Protocol/Plonkup/Relation.hs b/src/ZkFold/Base/Protocol/Plonkup/Relation.hs index f810bfc45..fe401e34b 100644 --- a/src/ZkFold/Base/Protocol/Plonkup/Relation.hs +++ b/src/ZkFold/Base/Protocol/Plonkup/Relation.hs @@ -6,8 +6,7 @@ module ZkFold.Base.Protocol.Plonkup.Relation where import Data.Binary (Binary) import Data.Bool (bool) -import Data.Functor.Rep (index) -import Data.Map (elems, keys, (!)) +import Data.Map (elems, keys) import Data.Maybe (fromJust) import GHC.IsList (IsList (..)) import Prelude hiding (Num (..), drop, length, replicate, sum, @@ -17,7 +16,7 @@ import Test.QuickCheck (Arbitrary import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number import ZkFold.Base.Algebra.Basic.Permutations (Permutation, fromCycles, mkIndexPartition) -import ZkFold.Base.Algebra.Polynomials.Multivariate (var) +import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial, var) import ZkFold.Base.Algebra.Polynomials.Univariate (PolyVec, toPolyVec) import ZkFold.Base.Data.Vector (Vector, fromVector) import ZkFold.Base.Protocol.Plonkup.Internal (PlonkupPermutationSize) @@ -76,7 +75,7 @@ toPlonkupRelation :: forall i n l a . toPlonkupRelation ac = let xPub = acOutput ac pubInputConstraints = map var (fromVector xPub) - plonkConstraints = elems (acSystem ac) + plonkConstraints = map (evalPolynomial evalMonomial (var . SysVar)) (elems (acSystem ac)) rs = map toConstant $ elems $ acRange ac -- TODO: We are expecting at most one range. t = toPolyVec $ fromList $ map fromConstant $ bool [] (replicate (value @n -! length rs + 1) 0 ++ [ 0 .. head rs ]) (not $ null rs) @@ -107,15 +106,11 @@ toPlonkupRelation ac = -- TODO: Permutation code is not particularly safe. We rely on the list being of length 3*n. sigma = fromCycles @(3*n) $ mkIndexPartition $ fromList $ a ++ b ++ c - indexW _ Nothing = one - indexW i (Just (InVar v)) = index i v - indexW i (Just (NewVar v)) = witnessGenerator ac i ! v - - w1 i = toPolyVec $ fromList $ fmap (indexW i) a - w2 i = toPolyVec $ fromList $ fmap (indexW i) b - w3 i = toPolyVec $ fromList $ fmap (indexW i) c + w1 i = toPolyVec $ fromList $ fmap (indexW ac i) a + w2 i = toPolyVec $ fromList $ fmap (indexW ac i) b + w3 i = toPolyVec $ fromList $ fmap (indexW ac i) c witness i = (w1 i, w2 i, w3 i) - pubInput i = fmap (indexW i . Just) xPub + pubInput i = fmap (indexW ac i) xPub in if max n' nLookup <= value @n then Just $ PlonkupRelation {..} diff --git a/src/ZkFold/Symbolic/Class.hs b/src/ZkFold/Symbolic/Class.hs index 33d06036b..beb2b016e 100644 --- a/src/ZkFold/Symbolic/Class.hs +++ b/src/ZkFold/Symbolic/Class.hs @@ -2,17 +2,18 @@ module ZkFold.Symbolic.Class (module ZkFold.Symbolic.Class, Arithmetic) where +import Control.Monad import Data.Foldable (Foldable) -import Data.Function (const, ($), (.)) -import Data.Functor (Functor (fmap), (<$>)) +import Data.Function ((.)) +import Data.Functor ((<$>)) import Data.Kind (Type) import Data.Type.Equality (type (~)) -import GHC.Generics (Par1 (Par1), type (:.:) (unComp1)) +import GHC.Generics (type (:.:) (unComp1)) import Numeric.Natural (Natural) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Control.HApplicative (HApplicative (hpair, hunit)) -import ZkFold.Base.Data.Package (Package (pack), packed) +import ZkFold.Base.Data.Package (Package (pack)) import ZkFold.Base.Data.Product (uncurryP) import ZkFold.Symbolic.MonadCircuit @@ -53,9 +54,8 @@ class (HApplicative c, Package c, Arithmetic (BaseField c)) => Symbolic c where fromCircuitF x f = symbolicF x (runWitnesses @Natural @(BaseField c) . f) f -- | Embeds the pure value(s) into generic context @c@. -embed :: (Symbolic c, Foldable f, Functor f) => f (BaseField c) -> c f -embed = packed . fmap (\x -> - fromCircuitF hunit $ const $ Par1 <$> newAssigned (const $ fromConstant x)) +embed :: (Symbolic c, Functor f) => f (BaseField c) -> c f +embed cs = fromCircuitF hunit (\_ -> return (fromConstant <$> cs)) symbolic2F :: (Symbolic c, BaseField c ~ a) => c f -> c g -> (f a -> g a -> h a) -> diff --git a/src/ZkFold/Symbolic/Compiler.hs b/src/ZkFold/Symbolic/Compiler.hs index 2ffdeaf14..c4d9b8d25 100644 --- a/src/ZkFold/Symbolic/Compiler.hs +++ b/src/ZkFold/Symbolic/Compiler.hs @@ -10,7 +10,7 @@ module ZkFold.Symbolic.Compiler ( solder, ) where -import Data.Aeson (ToJSON) +import Data.Aeson (FromJSON, ToJSON) import Data.Binary (Binary) import Data.Function (const, (.)) import Data.Functor (($>)) @@ -101,6 +101,7 @@ compileIO :: ( KnownNat ni , ni ~ TypeSize (Support f) , c ~ ArithmeticCircuit a (Vector ni) + , FromJSON a , ToJSON a , SymbolicData f , Context f ~ c diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index 0b261ab3e..c7302735e 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -52,8 +52,8 @@ import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomi import ZkFold.Prelude (length) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance () import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint, - Var (..), acInput, eval, eval1, exec, exec1, - witnessGenerator) + SysVar (..), Var (..), acInput, eval, eval1, exec, + exec1, witnessGenerator) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map import ZkFold.Symbolic.Data.Combinators (expansion) import ZkFold.Symbolic.MonadCircuit (MonadCircuit (..)) @@ -87,7 +87,7 @@ desugarRanges :: (Arithmetic a, Binary a, Binary (Rep i), Ord (Rep i), Representable i) => ArithmeticCircuit a i o -> ArithmeticCircuit a i o desugarRanges c = - let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ [(NewVar k, v) | (k,v) <- toList (acRange c)] + let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ [(SysVar (NewVar k), v) | (k,v) <- toList (acRange c)] in r' { acRange = mempty, acOutput = acOutput c } ----------------------------------- Information ----------------------------------- @@ -111,7 +111,7 @@ acValue r = eval r U1 -- -- TODO: Move this elsewhere (?) -- TODO: Check that all arguments have been applied. -acPrint :: (Show a, Show (o (Var U1)), Show (o a), Functor o) => ArithmeticCircuit a U1 o -> IO () +acPrint :: (Show a, Show (o (Var a U1)), Show (o a), Functor o) => ArithmeticCircuit a U1 o -> IO () acPrint ac = do let m = elems (acSystem ac) w = witnessGenerator ac U1 diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs index e02073db3..930958f0b 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs @@ -12,7 +12,7 @@ import Data.Functor.Rep (Representa import Data.Map hiding (drop, foldl, foldl', foldr, map, null, splitAt, take, toList) import GHC.Generics (Par1 (..)) -import Prelude (Show, mempty, pure, return, show, ($), (++), +import Prelude (Show, mempty, pure, return, show, ($), (++), (.), (<$>)) import qualified Prelude as Haskell import Test.QuickCheck (Arbitrary (arbitrary), Gen, elements) @@ -37,7 +37,7 @@ instance , Haskell.Foldable i ) => Arbitrary (ArithmeticCircuit a i Par1) where arbitrary = do - outVar <- InVar <$> arbitrary + outVar <- SysVar . InVar <$> arbitrary let ac = mempty {acOutput = Par1 outVar} fromFieldElement <$> arbitrary' (FieldElement ac) 10 @@ -55,7 +55,7 @@ instance arbitrary = do ac <- arbitrary @(ArithmeticCircuit a i Par1) o <- unsafeToVector <$> genSubset (value @l) (getAllVars ac) - return ac {acOutput = o} + return ac {acOutput = SysVar <$> o} arbitrary' :: forall a i . @@ -68,8 +68,8 @@ arbitrary' ac iter = do let vars = getAllVars (fromFieldElement ac) li <- elements vars ri <- elements vars - let (l, r) = ( FieldElement (fromFieldElement ac) { acOutput = pure li } - , FieldElement (fromFieldElement ac) { acOutput = pure ri }) + let (l, r) = ( FieldElement (fromFieldElement ac) { acOutput = pure (SysVar li) } + , FieldElement (fromFieldElement ac) { acOutput = pure (SysVar ri) }) ac' <- elements [ l + r , l * r @@ -79,14 +79,14 @@ arbitrary' ac iter = do arbitrary' ac' (iter -! 1) -- TODO: make it more readable -instance (FiniteField a, Haskell.Eq a, Show a, Show (o (Var i)), Haskell.Ord (Rep i), Show (Var i)) => Show (ArithmeticCircuit a i o) where +instance (FiniteField a, Haskell.Eq a, Show a, Show (o (Var a i)), Haskell.Ord (Rep i), Show (Var a i), Show (Rep i)) => Show (ArithmeticCircuit a i o) where show r = "ArithmeticCircuit { acSystem = " ++ show (acSystem r) ++ "\n, acRange = " ++ show (acRange r) ++ "\n, acOutput = " ++ show (acOutput r) ++ " }" -- TODO: add witness generation info to the JSON object -instance (ToJSON a, ToJSON (o (Var i)), ToJSONKey (Var i), FromJSONKey (Var i)) => ToJSON (ArithmeticCircuit a i o) where +instance (ToJSON a, ToJSON (o (Var a i)), ToJSONKey (Var a i), FromJSONKey (Var a i), ToJSON (Rep i)) => ToJSON (ArithmeticCircuit a i o) where toJSON r = object [ "system" .= acSystem r, @@ -95,7 +95,7 @@ instance (ToJSON a, ToJSON (o (Var i)), ToJSONKey (Var i), FromJSONKey (Var i)) ] -- TODO: properly restore the witness generation function -instance (FromJSON a, FromJSON (o (Var i)), ToJSONKey (Var i), FromJSONKey (Var i), Haskell.Ord (Rep i)) => FromJSON (ArithmeticCircuit a i o) where +instance (FromJSON a, FromJSON (o (Var a i)), ToJSONKey (Var a i), FromJSONKey (Var a i), Haskell.Ord (Rep i), FromJSON (Rep i)) => FromJSON (ArithmeticCircuit a i o) where parseJSON = withObject "ArithmeticCircuit" $ \v -> do acSystem <- v .: "system" diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index 4344f3904..5341e1122 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -8,6 +8,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal ( ArithmeticCircuit(..), Var (..), + SysVar (..), VarField, Arithmetic, Constraint, @@ -21,6 +22,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal ( exec, exec1, apply, + indexW ) where import Control.DeepSeq (NFData) @@ -32,6 +34,7 @@ import Data.Foldable (fold, to import Data.Functor.Rep import Data.Map.Strict hiding (drop, foldl, foldr, map, null, splitAt, take, toList) +import Data.Maybe (fromMaybe) import Data.Semialign (unzipDefault) import Data.Semigroup.Generic (GenericSemigroupMonoid (..)) import GHC.Generics (Generic, Par1 (..), U1 (..), (:*:) (..)) @@ -51,7 +54,7 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MerkleHash import ZkFold.Symbolic.MonadCircuit -- | The type that represents a constraint in the arithmetic circuit. -type Constraint c i = Poly c (Var i) Natural +type Constraint c i = Poly c (SysVar i) Natural -- | Arithmetic circuit in the form of a system of polynomial constraints. data ArithmeticCircuit a i o = ArithmeticCircuit @@ -62,7 +65,7 @@ data ArithmeticCircuit a i o = ArithmeticCircuit -- ^ The range constraints [0, a] for the selected variables acWitness :: Map ByteString (i a -> Map ByteString a -> a), -- ^ The witness generation functions - acOutput :: o (Var i) + acOutput :: o (Var a i) -- ^ The output variables } deriving (Generic) @@ -72,39 +75,64 @@ deriving via (GenericSemigroupMonoid (ArithmeticCircuit a i o)) deriving via (GenericSemigroupMonoid (ArithmeticCircuit a i o)) instance o ~ U1 => Monoid (ArithmeticCircuit a i o) -instance (NFData a, NFData (o (Var i)), NFData (Rep i)) +instance (NFData a, NFData (o (Var a i)), NFData (Rep i)) => NFData (ArithmeticCircuit a i o) -- | Variables are SHA256 digests (32 bytes) type VarField = Zp (2 ^ (32 * 8)) -data Var i +data SysVar i = InVar (Rep i) | NewVar ByteString deriving Generic -deriving anyclass instance FromJSON (Rep i) => FromJSON (Var i) -deriving anyclass instance FromJSON (Rep i) => FromJSONKey (Var i) -deriving anyclass instance ToJSON (Rep i) => ToJSONKey (Var i) -deriving anyclass instance ToJSON (Rep i) => ToJSON (Var i) -deriving stock instance Show (Rep i) => Show (Var i) -deriving stock instance Eq (Rep i) => Eq (Var i) -deriving stock instance Ord (Rep i) => Ord (Var i) -deriving instance NFData (Rep i) => NFData (Var i) +deriving anyclass instance FromJSON (Rep i) => FromJSON (SysVar i) +deriving anyclass instance FromJSON (Rep i) => FromJSONKey (SysVar i) +deriving anyclass instance ToJSON (Rep i) => ToJSONKey (SysVar i) +deriving anyclass instance ToJSON (Rep i) => ToJSON (SysVar i) +deriving stock instance Show (Rep i) => Show (SysVar i) +deriving stock instance Eq (Rep i) => Eq (SysVar i) +deriving stock instance Ord (Rep i) => Ord (SysVar i) +deriving instance NFData (Rep i) => NFData (SysVar i) + +data Var a i + = SysVar (SysVar i) + | ConstVar a + deriving Generic +deriving anyclass instance (FromJSON (Rep i), FromJSON a) => FromJSON (Var a i) +deriving anyclass instance (FromJSON (Rep i), FromJSON a) => FromJSONKey (Var a i) +deriving anyclass instance (ToJSON (Rep i), ToJSON a) => ToJSONKey (Var a i) +deriving anyclass instance (ToJSON (Rep i), ToJSON a) => ToJSON (Var a i) +deriving stock instance (Show (Rep i), Show a) => Show (Var a i) +deriving stock instance (Eq (Rep i), Eq a) => Eq (Var a i) +deriving stock instance (Ord (Rep i), Ord a) => Ord (Var a i) +deriving instance (NFData (Rep i), NFData a) => NFData (Var a i) +instance FromConstant a (Var a i) where + fromConstant = ConstVar ---------------------------------- Variables ----------------------------------- -acInput :: Representable i => i (Var i) -acInput = tabulate InVar +acInput :: Representable i => i (Var a i) +acInput = fmapRep (SysVar . InVar) (tabulate id) + +getAllVars :: forall a i o. (Representable i, Foldable i) => ArithmeticCircuit a i o -> [SysVar i] +getAllVars ac = toList acInput0 ++ map NewVar (keys $ acWitness ac) where + acInput0 :: i (SysVar i) + acInput0 = fmapRep InVar (tabulate @i id) -getAllVars :: (Representable i, Foldable i) => ArithmeticCircuit a i o -> [Var i] -getAllVars ac = toList acInput ++ map NewVar (keys $ acWitness ac) +indexW :: Representable i => ArithmeticCircuit a i o -> i a -> Var a i -> a +indexW circuit inputs = \case + SysVar (InVar inV) -> index inputs inV + SysVar (NewVar newV) -> fromMaybe + (error ("no such NewVar: " <> show newV)) + (witnessGenerator circuit inputs !? newV) + ConstVar cV -> cV --------------------------- Symbolic compiler context -------------------------- -crown :: ArithmeticCircuit a i g -> f (Var i) -> ArithmeticCircuit a i f +crown :: ArithmeticCircuit a i g -> f (Var a i) -> ArithmeticCircuit a i f crown = flip (set #acOutput) -behead :: ArithmeticCircuit a i f -> (ArithmeticCircuit a i U1, f (Var i)) +behead :: ArithmeticCircuit a i f -> (ArithmeticCircuit a i U1, f (Var a i)) behead = liftA2 (,) (set #acOutput U1) acOutput instance HFunctor (ArithmeticCircuit a i) where @@ -128,19 +156,26 @@ instance instance ( Arithmetic a, Binary a, Representable i, Binary (Rep i), Ord (Rep i) - , o ~ U1) => MonadCircuit (Var i) a (State (ArithmeticCircuit a i o)) where + , o ~ U1) => MonadCircuit (Var a i) a (State (ArithmeticCircuit a i o)) where unconstrained witness = do let v = toVar @a witness -- TODO: forbid reassignment of variables zoom #acWitness . modify $ insert v $ \i w -> witness $ \case - InVar inV -> index i inV - NewVar newV -> w ! newV - return (NewVar v) - - constraint p = zoom #acSystem . modify $ insert (toVar @a p) (p var) - - rangeConstraint (NewVar v) upperBound = + SysVar (InVar inV) -> index i inV + SysVar (NewVar newV) -> w ! newV + ConstVar cV -> fromConstant cV + return (SysVar (NewVar v)) + + constraint p = + let + evalConstVar = \case + SysVar sysV -> var sysV + ConstVar cV -> fromConstant cV + in + zoom #acSystem . modify $ insert (toVar @a p) (p evalConstVar) + + rangeConstraint (SysVar (NewVar v)) upperBound = zoom #acRange . modify $ insert v upperBound -- FIXME range-constrain other variable types rangeConstraint _ _ = error "Cannot range-constrain this variable" @@ -168,10 +203,11 @@ instance -- 'WitnessField' is a root hash of a Merkle tree for a witness. toVar :: forall a i. (Finite a, Binary a, Binary (Rep i)) => - Witness (Var i) a -> ByteString + Witness (Var a i) a -> ByteString toVar witness = runHash @(Just (Order a)) $ witness $ \case - InVar inV -> merkleHash inV - NewVar newV -> M newV + SysVar (InVar inV) -> merkleHash inV + SysVar (NewVar newV) -> M newV + ConstVar cV -> fromConstant cV ----------------------------- Evaluation functions ----------------------------- @@ -186,9 +222,7 @@ eval1 ctx i = unPar1 (eval ctx i) -- | Evaluates the arithmetic circuit using the supplied input map. eval :: (Representable i, Functor o) => ArithmeticCircuit a i o -> i a -> o a -eval ctx i = acOutput ctx <&> \case - NewVar k -> witnessGenerator ctx i ! k - InVar j -> index i j +eval ctx i = indexW ctx i <$> acOutput ctx -- | Evaluates the arithmetic circuit with no inputs and one output. exec1 :: ArithmeticCircuit a U1 Par1 -> a diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs index 710062899..6ce2dfb4d 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs @@ -19,8 +19,8 @@ import Test.QuickCheck (Arbitrary import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Polynomials.Multivariate import ZkFold.Base.Data.ByteString (toByteString) -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Var (..), - VarField, getAllVars) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), SysVar (..), + Var (..), VarField, getAllVars) -- This module contains functions for mapping variables in arithmetic circuits. @@ -57,5 +57,8 @@ mapVarArithmeticCircuit (ArithmeticCircuitTest ac wi) = -- TODO: the new arithmetic circuit expects the old input variables! We should make this safer. acWitness = (`Map.compose` backward) $ (\f i m -> f i (Map.compose m forward)) <$> acWitness ac } - mappedOutputs = varF <$> acOutput ac + varG = \case + SysVar v -> SysVar (varF v) + ConstVar c -> ConstVar c + mappedOutputs = varG <$> acOutput ac in ArithmeticCircuitTest (mappedCircuit {acOutput = mappedOutputs}) wi diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MerkleHash.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MerkleHash.hs index bb4021150..98b5709b4 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MerkleHash.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/MerkleHash.hs @@ -17,7 +17,7 @@ import ZkFold.Base.Data.ByteString (toByteString) newtype MerkleHash (n :: Maybe Natural) = M { runHash :: ByteString } -data Prec = Add | Mul | Div | Mod | Exp deriving (Generic, Binary) +data Prec = Add | Mul | Div | Mod | Exp | Const deriving (Generic, Binary) merkleHash :: Binary a => a -> MerkleHash n merkleHash = M . hash . toByteString @@ -34,7 +34,7 @@ instance {-# OVERLAPPING #-} FromConstant (MerkleHash n) (MerkleHash n) instance {-# OVERLAPPING #-} Scale (MerkleHash n) (MerkleHash n) instance Binary a => FromConstant a (MerkleHash n) where - fromConstant = merkleHash + fromConstant x = merkleHash (Const, x) instance Binary a => Scale a (MerkleHash n) diff --git a/src/ZkFold/Symbolic/MonadCircuit.hs b/src/ZkFold/Symbolic/MonadCircuit.hs index 29bfd2b9b..93f258911 100644 --- a/src/ZkFold/Symbolic/MonadCircuit.hs +++ b/src/ZkFold/Symbolic/MonadCircuit.hs @@ -28,7 +28,7 @@ type WitnessField n a = ( FiniteField a, ToConstant a, Const a ~ n -- -- NOTE: the property above is correct by construction for each function of a -- suitable type, you don't have to check it yourself. -type Witness i a = forall x n . (Algebra a x, WitnessField n x) => (i -> x) -> x +type Witness var a = forall x n . (Algebra a x, WitnessField n x) => (var -> x) -> x -- | A type of polynomial expressions. -- @i@ is a type of variables, @a@ is a base field. @@ -39,10 +39,10 @@ type Witness i a = forall x n . (Algebra a x, WitnessField n x) => (i -> x) -> x -- -- NOTE: the property above is correct by construction for each function of a -- suitable type, you don't have to check it yourself. -type ClosedPoly i a = forall x . Algebra a x => (i -> x) -> x +type ClosedPoly var a = forall x . Algebra a x => (var -> x) -> x -- | A type of constraints for new variables. --- @i@ is a type of variables, @a@ is a base field. +-- @var@ is a type of variables, @a@ is a base field. -- -- A function is a constraint for a new variable if, given an arbitrary algebra -- @x@ over @a@, a function mapping known variables to their witnesses in that @@ -51,7 +51,7 @@ type ClosedPoly i a = forall x . Algebra a x => (i -> x) -> x -- -- NOTE: the property above is correct by construction for each function of a -- suitable type, you don't have to check it yourself. -type NewConstraint i a = forall x . Algebra a x => (i -> x) -> i -> x +type NewConstraint var a = forall x . Algebra a x => (var -> x) -> var -> x -- | A monadic DSL for constructing arithmetic circuits. -- @i@ is a type of variables, @a@ is a base field @@ -69,23 +69,23 @@ type NewConstraint i a = forall x . Algebra a x => (i -> x) -> i -> x -- * That provided witnesses satisfy the provided constraints. To check this, -- you can use 'ZkFold.Symbolic.Compiler.ArithmeticCircuit.checkCircuit'. -- * That introduced constraints are supported by the zk-SNARK utilized for later proving. -class Monad m => MonadCircuit i a m | m -> i, m -> a where +class (Monad m, FromConstant a var) => MonadCircuit var a m | m -> var, m -> a where -- | Creates new variable from witness. -- -- NOTE: this does not add any constraints to the system, -- use 'rangeConstraint' or 'constraint' to add them. - unconstrained :: Witness i a -> m i + unconstrained :: Witness var a -> m var -- | Adds new polynomial constraint to the system. -- E.g., @'constraint' (\\x -> x i)@ forces variable @i@ to be zero. -- -- NOTE: it is not checked (yet) whether provided constraint is in -- appropriate form for zkSNARK in use. - constraint :: ClosedPoly i a -> m () + constraint :: ClosedPoly var a -> m () -- | Adds new range constraint to the system. -- E.g., @'rangeConstraint' i B@ forces variable @i@ to be in range \([0; B]\). - rangeConstraint :: i -> a -> m () + rangeConstraint :: var -> a -> m () -- | Creates new variable given a polynomial witness -- AND adds a corresponding polynomial constraint. @@ -98,15 +98,15 @@ class Monad m => MonadCircuit i a m | m -> i, m -> a where -- -- NOTE: is is not checked (yet) whether the corresponding constraint is in -- appropriate form for zkSNARK in use. - newAssigned :: ClosedPoly i a -> m i - newAssigned p = newConstrained (\x i -> p x - x i) p + newAssigned :: ClosedPoly var a -> m var + newAssigned p = newConstrained (\x var -> p x - x var) p -- | Creates new variable from witness constrained with an inclusive upper bound. --- E.g., @'newRanged' b (\\x -> x i - one)@ creates new variable whose value --- is equal to @x i - one@ and which is expected to be in range @[0..b]@. +-- E.g., @'newRanged' b (\\x -> x var - one)@ creates new variable whose value +-- is equal to @x var - one@ and which is expected to be in range @[0..b]@. -- -- NOTE: this adds a range constraint to the system. -newRanged :: MonadCircuit i a m => a -> Witness i a -> m i +newRanged :: MonadCircuit var a m => a -> Witness var a -> m var newRanged upperBound witness = do v <- unconstrained witness rangeConstraint v upperBound @@ -121,7 +121,7 @@ newRanged upperBound witness = do -- -- NOTE: it is not checked (yet) whether provided constraint is in -- appropriate form for zkSNARK in use. -newConstrained :: MonadCircuit i a m => NewConstraint i a -> Witness i a -> m i +newConstrained :: MonadCircuit var a m => NewConstraint var a -> Witness var a -> m var newConstrained poly witness = do v <- unconstrained witness constraint (`poly` v) diff --git a/stats/Constant.5 b/stats/Constant.5 new file mode 100644 index 000000000..1d59a3f61 --- /dev/null +++ b/stats/Constant.5 @@ -0,0 +1,4 @@ +Constant.5 +Number of constraints: 0 +Number of variables: 0 +Number of range lookups: 0 \ No newline at end of file diff --git a/stats/Eq b/stats/Eq index d0a755681..73b62a0f1 100644 --- a/stats/Eq +++ b/stats/Eq @@ -1,4 +1,4 @@ Eq -Number of constraints: 5 -Number of variables: 5 +Number of constraints: 4 +Number of variables: 4 Number of range lookups: 0 \ No newline at end of file diff --git a/stats/Eq.Constant.5 b/stats/Eq.Constant.5 new file mode 100644 index 000000000..6b977fb9e --- /dev/null +++ b/stats/Eq.Constant.5 @@ -0,0 +1,4 @@ +Eq.Constant.5 +Number of constraints: 4 +Number of variables: 4 +Number of range lookups: 0 \ No newline at end of file diff --git a/stats/Fibonacci.100 b/stats/Fibonacci.100 index b1640ef09..42dc6e566 100644 --- a/stats/Fibonacci.100 +++ b/stats/Fibonacci.100 @@ -1,4 +1,4 @@ Fibonacci.100 -Number of constraints: 895 -Number of variables: 895 +Number of constraints: 794 +Number of variables: 794 Number of range lookups: 0 \ No newline at end of file diff --git a/stats/MiMCHash b/stats/MiMCHash index 22656f251..6a7770d49 100644 --- a/stats/MiMCHash +++ b/stats/MiMCHash @@ -1,4 +1,4 @@ MiMCHash -Number of constraints: 1980 -Number of variables: 1980 +Number of constraints: 1760 +Number of variables: 1760 Number of range lookups: 0 \ No newline at end of file diff --git a/stats/SHA256.32 b/stats/SHA256.32 index 4e274c970..b293f2578 100644 --- a/stats/SHA256.32 +++ b/stats/SHA256.32 @@ -1,4 +1,4 @@ SHA256.32 -Number of constraints: 63618 -Number of variables: 63892 +Number of constraints: 63616 +Number of variables: 63890 Number of range lookups: 1172 \ No newline at end of file diff --git a/stats/UInt.DivMod.32.Auto b/stats/UInt.DivMod.32.Auto index 25b3b2b4f..0536880d3 100644 --- a/stats/UInt.DivMod.32.Auto +++ b/stats/UInt.DivMod.32.Auto @@ -1,4 +1,4 @@ UInt.DivMod.32.Auto -Number of constraints: 4939 -Number of variables: 5033 +Number of constraints: 4906 +Number of variables: 5000 Number of range lookups: 256 \ No newline at end of file diff --git a/zkfold-base.cabal b/zkfold-base.cabal index 77622b761..6c9e32945 100644 --- a/zkfold-base.cabal +++ b/zkfold-base.cabal @@ -265,6 +265,7 @@ library zkfold-symbolic-examples Examples.BatchTransfer Examples.ByteString Examples.Conditional + Examples.Constant Examples.Eq Examples.FFA Examples.Fibonacci