Skip to content

Commit

Permalink
Merge pull request #248 from zkFold/eitan-const-vars
Browse files Browse the repository at this point in the history
Constant Variables
  • Loading branch information
echatav authored Sep 19, 2024
2 parents d80ed48 + 813cb56 commit 61a2242
Show file tree
Hide file tree
Showing 23 changed files with 190 additions and 134 deletions.
2 changes: 1 addition & 1 deletion bench/BenchCompiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ metrics name circuit =


benchmark ::
(NFData a, AdditiveMonoid a, NFData (o (Var i)), NFData (Rep i), i ~ Vector n_i, KnownNat n_i) =>
(NFData a, AdditiveMonoid a, NFData (o (Var a i)), NFData (Rep i), i ~ Vector n_i, KnownNat n_i) =>
String -> (() -> ArithmeticCircuit a i o) -> Benchmark
benchmark name circuit = bgroup name
[ bench "compilation" $ nf circuit ()
Expand Down
15 changes: 15 additions & 0 deletions examples/Examples/Constant.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module Examples.Constant (exampleConst5, exampleEq5) where

import Examples.Eq (exampleEq)
import Numeric.Natural (Natural)

import ZkFold.Base.Algebra.Basic.Class (FromConstant (..))
import ZkFold.Symbolic.Class (Symbolic)
import ZkFold.Symbolic.Data.Bool (Bool)
import ZkFold.Symbolic.Data.FieldElement (FieldElement)

exampleConst5 :: Symbolic c => FieldElement c
exampleConst5 = fromConstant (5 :: Natural)

exampleEq5 :: Symbolic c => FieldElement c -> Bool c
exampleEq5 = exampleEq exampleConst5
3 changes: 3 additions & 0 deletions examples/ZkFold/Symbolic/Examples.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Data.Type.Equality (type (~))
import Examples.BatchTransfer (exampleBatchTransfer)
import Examples.ByteString
import Examples.Conditional (exampleConditional)
import Examples.Constant (exampleConst5, exampleEq5)
import Examples.Eq (exampleEq)
import Examples.FFA
import Examples.Fibonacci (exampleFibonacci)
Expand Down Expand Up @@ -52,6 +53,8 @@ examples :: [(String, ExampleOutput)]
examples =
[ ("Eq", exampleOutput exampleEq)
, ("Conditional", exampleOutput exampleConditional)
, ("Constant.5", exampleOutput exampleConst5)
, ("Eq.Constant.5", exampleOutput exampleEq5)
, ("ByteString.And.32", exampleOutput $ exampleByteStringAnd @32)
, ("ByteString.Or.64", exampleOutput $ exampleByteStringOr @64)
, ("UInt.Mul.64.Auto", exampleOutput $ exampleUIntMul @64 @Auto)
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Base/Protocol/Plonkup/LookupConstraint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import ZkFold.Base.Data.ByteString (toByteStri
import ZkFold.Base.Data.Vector (Vector)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

newtype LookupConstraint i a = LookupConstraint { lkVar :: Var (Vector i) }
newtype LookupConstraint i a = LookupConstraint { lkVar :: SysVar (Vector i) }
deriving (Show, Eq)

instance (Arbitrary a, Binary a) => Arbitrary (LookupConstraint i a) where
Expand Down
60 changes: 28 additions & 32 deletions src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,28 @@

module ZkFold.Base.Protocol.Plonkup.PlonkConstraint where

import Control.Monad (guard, return)
import Control.Monad (guard, replicateM, return)
import Data.Binary (Binary)
import Data.Containers.ListUtils (nubOrd)
import Data.Eq (Eq (..))
import Data.Function (($), (.))
import Data.Functor ((<$>))
import Data.List (find, head, map, permutations, sort, (++))
import Data.List (find, head, map, permutations, sort, (!!), (++))
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (Maybe (..), mapMaybe, maybe)
import Data.Maybe (Maybe (..), mapMaybe)
import Data.Ord (Ord)
import GHC.IsList (IsList (..))
import GHC.TypeNats (KnownNat)
import Numeric.Natural (Natural)
import Prelude (error)
import Test.QuickCheck (Arbitrary (..), elements)
import Test.QuickCheck (Arbitrary (..))
import Text.Show (Show)

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Polynomials.Multivariate (Poly, evalMonomial, evalPolynomial, polynomial,
var, variables)
import ZkFold.Base.Algebra.Polynomials.Multivariate (Poly, polynomial, var, variables)
import ZkFold.Base.Data.ByteString (toByteString)
import ZkFold.Base.Data.Vector (Vector)
import ZkFold.Prelude (length, replicate, replicateA, take)
import ZkFold.Prelude (length, take)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

data PlonkConstraint i a = PlonkConstraint
Expand All @@ -34,41 +33,39 @@ data PlonkConstraint i a = PlonkConstraint
, qr :: a
, qo :: a
, qc :: a
, x1 :: Maybe (Var (Vector i))
, x2 :: Maybe (Var (Vector i))
, x3 :: Maybe (Var (Vector i))
, x1 :: Var a (Vector i)
, x2 :: Var a (Vector i)
, x3 :: Var a (Vector i)
}
deriving (Show, Eq)

instance (Arbitrary a, Binary a, KnownNat i) => Arbitrary (PlonkConstraint i a) where
instance (Ord a, Arbitrary a, Binary a, KnownNat i) => Arbitrary (PlonkConstraint i a) where
arbitrary = do
qm <- arbitrary
ql <- arbitrary
qr <- arbitrary
qo <- arbitrary
qc <- arbitrary
k <- elements [1, 2, 3]
xs0 <- sort <$> replicateA k (Just . NewVar . toByteString @a <$> arbitrary)
let (x, y, z) = case replicate (3 -! k) Nothing ++ xs0 of
[x', y', z'] -> (x', y', z')
_ -> error "impossible"
return $ PlonkConstraint qm ql qr qo qc x y z
let arbitraryNewVar = SysVar . NewVar . toByteString @a <$> arbitrary
xs <- sort <$> replicateM 3 arbitraryNewVar
let x1 = xs !! 0; x2 = xs !! 1; x3 = xs !! 2
return $ PlonkConstraint qm ql qr qo qc x1 x2 x3

toPlonkConstraint :: forall a i . (Eq a, FiniteField a, KnownNat i) => Poly a (Var (Vector i)) Natural -> PlonkConstraint i a
toPlonkConstraint :: forall a i . (Ord a, FiniteField a, KnownNat i) => Poly a (Var a (Vector i)) Natural -> PlonkConstraint i a
toPlonkConstraint p =
let xs = map Just $ toList (variables p)
let xs = toList (variables p)
perms = nubOrd $ map (take 3) $ permutations $ case length xs of
0 -> [Nothing, Nothing, Nothing]
1 -> [Nothing, Nothing, head xs, head xs]
2 -> [Nothing] ++ xs ++ xs
0 -> [ConstVar one, ConstVar one, ConstVar one]
1 -> [ConstVar one, ConstVar one, head xs, head xs]
2 -> [ConstVar one] ++ xs ++ xs
_ -> xs ++ xs

getCoef :: Map (Maybe (Var (Vector i))) Natural -> a
getCoef m = case find (\(_, as) -> m == Map.mapKeys Just as) (toList p) of
getCoef :: Map (Var a (Vector i)) Natural -> a
getCoef m = case find (\(_, as) -> m == as) (toList p) of
Just (c, _) -> c
_ -> zero

getCoefs :: [Maybe (Var (Vector i))] -> Maybe (PlonkConstraint i a)
getCoefs :: [Var a (Vector i)] -> Maybe (PlonkConstraint i a)
getCoefs [a, b, c] = do
let xa = [(a, 1)]
xb = [(b, 1)]
Expand All @@ -80,18 +77,17 @@ toPlonkConstraint p =
qr = getCoef $ fromList xb
qo = getCoef $ fromList xc
qc = getCoef Map.empty
guard $ evalPolynomial evalMonomial (var . Just) p - polynomial [(qm, fromList xaxb), (ql, fromList xa), (qr, fromList xb), (qo, fromList xc), (qc, one)] == zero
guard $ p - polynomial [(qm, fromList xaxb), (ql, fromList xa), (qr, fromList xb), (qo, fromList xc), (qc, one)] == zero
return $ PlonkConstraint qm ql qr qo qc a b c
getCoefs _ = Nothing

in head $ mapMaybe getCoefs perms

fromPlonkConstraint :: (Eq a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var (Vector i)) Natural
fromPlonkConstraint :: (Ord a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var a (Vector i)) Natural
fromPlonkConstraint (PlonkConstraint qm ql qr qo qc a b c) =
let xvar = maybe zero var
xa = xvar a
xb = xvar b
xc = xvar c
let xa = var a
xb = var b
xc = var c
xaxb = xa * xb
in
scale qm xaxb
Expand Down
14 changes: 7 additions & 7 deletions src/ZkFold/Base/Protocol/Plonkup/PlonkupConstraint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal

data PlonkupConstraint i a = ConsPlonk (PlonkConstraint i a) | ConsLookup (LookupConstraint i a) | ConsExtra

getPlonkConstraint :: (Eq a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> PlonkConstraint i a
getPlonkConstraint :: (Ord a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> PlonkConstraint i a
getPlonkConstraint (ConsPlonk c) = c
getPlonkConstraint _ = toPlonkConstraint zero

isLookupConstraint :: FiniteField a => PlonkupConstraint i a -> a
isLookupConstraint (ConsLookup _) = one
isLookupConstraint _ = zero

getA :: forall a i . (Eq a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Maybe (Var (Vector i))
getA :: forall a i . (Ord a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Var a (Vector i)
getA (ConsPlonk c) = x1 c
getA (ConsLookup c) = Just $ lkVar c
getA (ConsLookup c) = SysVar $ lkVar c
getA ConsExtra = x1 (toPlonkConstraint @a zero)

getB :: forall a i . (Eq a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Maybe (Var (Vector i))
getB :: forall a i . (Ord a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Var a (Vector i)
getB (ConsPlonk c) = x2 c
getB (ConsLookup c) = Just $ lkVar c
getB (ConsLookup c) = SysVar $ lkVar c
getB ConsExtra = x2 (toPlonkConstraint @a zero)

getC :: forall a i . (Eq a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Maybe (Var (Vector i))
getC :: forall a i . (Ord a, FiniteField a, KnownNat i) => PlonkupConstraint i a -> Var a (Vector i)
getC (ConsPlonk c) = x3 c
getC (ConsLookup c) = Just $ lkVar c
getC (ConsLookup c) = SysVar $ lkVar c
getC ConsExtra = x3 (toPlonkConstraint @a zero)
19 changes: 7 additions & 12 deletions src/ZkFold/Base/Protocol/Plonkup/Relation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ module ZkFold.Base.Protocol.Plonkup.Relation where

import Data.Binary (Binary)
import Data.Bool (bool)
import Data.Functor.Rep (index)
import Data.Map (elems, keys, (!))
import Data.Map (elems, keys)
import Data.Maybe (fromJust)
import GHC.IsList (IsList (..))
import Prelude hiding (Num (..), drop, length, replicate, sum,
Expand All @@ -17,7 +16,7 @@ import Test.QuickCheck (Arbitrary
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Algebra.Basic.Permutations (Permutation, fromCycles, mkIndexPartition)
import ZkFold.Base.Algebra.Polynomials.Multivariate (var)
import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial, var)
import ZkFold.Base.Algebra.Polynomials.Univariate (PolyVec, toPolyVec)
import ZkFold.Base.Data.Vector (Vector, fromVector)
import ZkFold.Base.Protocol.Plonkup.Internal (PlonkupPermutationSize)
Expand Down Expand Up @@ -76,7 +75,7 @@ toPlonkupRelation :: forall i n l a .
toPlonkupRelation ac =
let xPub = acOutput ac
pubInputConstraints = map var (fromVector xPub)
plonkConstraints = elems (acSystem ac)
plonkConstraints = map (evalPolynomial evalMonomial (var . SysVar)) (elems (acSystem ac))
rs = map toConstant $ elems $ acRange ac
-- TODO: We are expecting at most one range.
t = toPolyVec $ fromList $ map fromConstant $ bool [] (replicate (value @n -! length rs + 1) 0 ++ [ 0 .. head rs ]) (not $ null rs)
Expand Down Expand Up @@ -107,15 +106,11 @@ toPlonkupRelation ac =
-- TODO: Permutation code is not particularly safe. We rely on the list being of length 3*n.
sigma = fromCycles @(3*n) $ mkIndexPartition $ fromList $ a ++ b ++ c

indexW _ Nothing = one
indexW i (Just (InVar v)) = index i v
indexW i (Just (NewVar v)) = witnessGenerator ac i ! v

w1 i = toPolyVec $ fromList $ fmap (indexW i) a
w2 i = toPolyVec $ fromList $ fmap (indexW i) b
w3 i = toPolyVec $ fromList $ fmap (indexW i) c
w1 i = toPolyVec $ fromList $ fmap (indexW ac i) a
w2 i = toPolyVec $ fromList $ fmap (indexW ac i) b
w3 i = toPolyVec $ fromList $ fmap (indexW ac i) c
witness i = (w1 i, w2 i, w3 i)
pubInput i = fmap (indexW i . Just) xPub
pubInput i = fmap (indexW ac i) xPub

in if max n' nLookup <= value @n
then Just $ PlonkupRelation {..}
Expand Down
14 changes: 7 additions & 7 deletions src/ZkFold/Symbolic/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@

module ZkFold.Symbolic.Class (module ZkFold.Symbolic.Class, Arithmetic) where

import Control.Monad
import Data.Foldable (Foldable)
import Data.Function (const, ($), (.))
import Data.Functor (Functor (fmap), (<$>))
import Data.Function ((.))
import Data.Functor ((<$>))
import Data.Kind (Type)
import Data.Type.Equality (type (~))
import GHC.Generics (Par1 (Par1), type (:.:) (unComp1))
import GHC.Generics (type (:.:) (unComp1))
import Numeric.Natural (Natural)

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Control.HApplicative (HApplicative (hpair, hunit))
import ZkFold.Base.Data.Package (Package (pack), packed)
import ZkFold.Base.Data.Package (Package (pack))
import ZkFold.Base.Data.Product (uncurryP)
import ZkFold.Symbolic.MonadCircuit

Expand Down Expand Up @@ -53,9 +54,8 @@ class (HApplicative c, Package c, Arithmetic (BaseField c)) => Symbolic c where
fromCircuitF x f = symbolicF x (runWitnesses @Natural @(BaseField c) . f) f

-- | Embeds the pure value(s) into generic context @c@.
embed :: (Symbolic c, Foldable f, Functor f) => f (BaseField c) -> c f
embed = packed . fmap (\x ->
fromCircuitF hunit $ const $ Par1 <$> newAssigned (const $ fromConstant x))
embed :: (Symbolic c, Functor f) => f (BaseField c) -> c f
embed cs = fromCircuitF hunit (\_ -> return (fromConstant <$> cs))

symbolic2F ::
(Symbolic c, BaseField c ~ a) => c f -> c g -> (f a -> g a -> h a) ->
Expand Down
3 changes: 2 additions & 1 deletion src/ZkFold/Symbolic/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module ZkFold.Symbolic.Compiler (
solder,
) where

import Data.Aeson (ToJSON)
import Data.Aeson (FromJSON, ToJSON)
import Data.Binary (Binary)
import Data.Function (const, (.))
import Data.Functor (($>))
Expand Down Expand Up @@ -101,6 +101,7 @@ compileIO ::
( KnownNat ni
, ni ~ TypeSize (Support f)
, c ~ ArithmeticCircuit a (Vector ni)
, FromJSON a
, ToJSON a
, SymbolicData f
, Context f ~ c
Expand Down
8 changes: 4 additions & 4 deletions src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomi
import ZkFold.Prelude (length)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance ()
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint,
Var (..), acInput, eval, eval1, exec, exec1,
witnessGenerator)
SysVar (..), Var (..), acInput, eval, eval1, exec,
exec1, witnessGenerator)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map
import ZkFold.Symbolic.Data.Combinators (expansion)
import ZkFold.Symbolic.MonadCircuit (MonadCircuit (..))
Expand Down Expand Up @@ -87,7 +87,7 @@ desugarRanges ::
(Arithmetic a, Binary a, Binary (Rep i), Ord (Rep i), Representable i) =>
ArithmeticCircuit a i o -> ArithmeticCircuit a i o
desugarRanges c =
let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ [(NewVar k, v) | (k,v) <- toList (acRange c)]
let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ [(SysVar (NewVar k), v) | (k,v) <- toList (acRange c)]
in r' { acRange = mempty, acOutput = acOutput c }

----------------------------------- Information -----------------------------------
Expand All @@ -111,7 +111,7 @@ acValue r = eval r U1
--
-- TODO: Move this elsewhere (?)
-- TODO: Check that all arguments have been applied.
acPrint :: (Show a, Show (o (Var U1)), Show (o a), Functor o) => ArithmeticCircuit a U1 o -> IO ()
acPrint :: (Show a, Show (o (Var a U1)), Show (o a), Functor o) => ArithmeticCircuit a U1 o -> IO ()
acPrint ac = do
let m = elems (acSystem ac)
w = witnessGenerator ac U1
Expand Down
Loading

0 comments on commit 61a2242

Please sign in to comment.