Skip to content

Commit

Permalink
Merge pull request #201 from zkFold/eitan-unify-circuit-types
Browse files Browse the repository at this point in the history
Unify ArithmeticCircuit type
  • Loading branch information
TurtlePU authored Aug 2, 2024
2 parents b7b5e32 + fcff79b commit 2e2b915
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 125 deletions.
4 changes: 2 additions & 2 deletions src/ZkFold/Base/Protocol/ARK/Plonk.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import ZkFold.Base.Protocol.ARK.Plonk.Relation (PlonkRelation (..),
import ZkFold.Base.Protocol.Commitment.KZG (com)
import ZkFold.Base.Protocol.NonInteractiveProof
import ZkFold.Prelude (length, (!))
import ZkFold.Symbolic.Compiler (ArithmeticCircuit, inputVariables)
import ZkFold.Symbolic.Compiler (ArithmeticCircuit (acInput))
import ZkFold.Symbolic.MonadCircuit (Arithmetic)

{-
Expand All @@ -55,7 +55,7 @@ instance (KnownNat n, KnownNat l, Arithmetic (ScalarField c1), Arbitrary (Scalar
=> Arbitrary (Plonk n l c1 c2 t) where
arbitrary = do
ac <- arbitrary
let fullInp = length . inputVariables $ ac
let fullInp = length . acInput $ ac
vecPubInp <- genSubset (value @l) fullInp
let (omega, k1, k2) = getParams (value @n)
Plonk omega k1 k2 (Vector vecPubInp) ac <$> arbitrary
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Base/Protocol/ARK/Plonk/Relation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ toPlonkRelation xPub ac0 =
evalX0 = evalPolynomial evalMonomial (\x -> if x == 0 then one else var x)

pubInputConstraints = map var (fromVector xPub)
acConstraints = map evalX0 $ elems (constraintSystem ac)
acConstraints = map evalX0 $ elems (acSystem ac)
extraConstraints = replicate (value @n -! acSizeN ac -! value @l) zero

system = map toPlonkConstraint $ pubInputConstraints ++ acConstraints ++ extraConstraints
Expand Down
4 changes: 2 additions & 2 deletions src/ZkFold/Base/Protocol/ARK/Protostar.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ instance Arithmetic a => SpecialSoundProtocol a (RecursiveCircuit n a) where
-- One round for Plonk
rounds = P.const 1

outputLength (RecursiveCircuit _ c) = P.fromIntegral $ M.size $ constraintSystem c
outputLength (RecursiveCircuit _ c) = P.fromIntegral $ M.size $ acSystem c

-- The transcript will be empty at this point, it is a one-round protocol
--
prover rc _ i _ = eval (circuit rc) (M.fromList $ P.zip [1..] (V.fromVector i))

-- We can use the polynomial system from the circuit, no need to build it from scratch
--
algebraicMap rc _ _ _ = M.elems $ constraintSystem (circuit rc)
algebraicMap rc _ _ _ = M.elems $ acSystem (circuit rc)

-- The transcript is only one prover message since this is a one-round protocol
--
Expand Down
19 changes: 9 additions & 10 deletions src/ZkFold/Symbolic/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@ module ZkFold.Symbolic.Compiler (
compileIO
) where

import Data.Aeson (ToJSON)
import Data.Eq (Eq)
import Data.Function (const, (.))
import Prelude (FilePath, IO, Monoid (mempty), Show (..),
putStrLn, type (~), ($), (++))
import Data.Aeson (ToJSON)
import Data.Eq (Eq)
import Data.Function (const, (.))
import Prelude (FilePath, IO, Monoid (mempty), Show (..), putStrLn,
type (~), ($), (++))

import ZkFold.Base.Algebra.Basic.Class (MultiplicativeMonoid)
import ZkFold.Base.Algebra.Basic.Class (MultiplicativeMonoid)
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Data.Vector (Vector, unsafeToVector)
import ZkFold.Prelude (writeFileJSON)
import ZkFold.Base.Data.Vector (Vector, unsafeToVector)
import ZkFold.Prelude (writeFileJSON)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (acInput)
import ZkFold.Symbolic.Data.Class

{-
Expand Down Expand Up @@ -47,7 +46,7 @@ solder ::
solder f = pieces f (restore @c @(Support c f) $ const inputC)
where
inputList = [1..(typeSize @c @(Support c f))]
inputC = withOutputs (mempty { acInput = inputList }) (unsafeToVector inputList)
inputC = mempty { acInput = inputList, acOutput = unsafeToVector inputList }

-- | Compiles function `f` into an arithmetic circuit.
compile ::
Expand Down
44 changes: 21 additions & 23 deletions src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit (
ArithmeticCircuit,
Constraint,

withOutputs,
constraintSystem,
inputVariables,
witnessGenerator,
varOrder,
-- high-level functions
applyArgs,
optimize,
Expand All @@ -30,6 +26,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit (
-- Arithmetization type fields
acWitness,
acVarOrder,
acInput,
acOutput,
-- Testing functions
checkCircuit,
Expand All @@ -39,6 +36,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit (
import Control.Monad.State (execState)
import Data.Map hiding (drop, foldl, foldr, map, null, splitAt,
take)
import GHC.Generics (U1 (..))
import Numeric.Natural (Natural)
import Prelude hiding (Num (..), drop, length, product,
splitAt, sum, take, (!!), (^))
Expand All @@ -51,18 +49,17 @@ import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMon
import ZkFold.Prelude (length)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (desugarRange)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance ()
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..),
Circuit (..), Constraint, apply,
constraintSystem, eval, eval1, exec, exec1,
forceZero, inputVariables, varOrder,
withOutputs, witnessGenerator)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint,
apply, eval, eval1, exec, exec1, forceZero,
witnessGenerator)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map

--------------------------------- High-level functions --------------------------------

-- TODO: make this work for different input types.
applyArgs :: ArithmeticCircuit a f -> [a] -> ArithmeticCircuit a f
applyArgs r args = r { acCircuit = execState (apply args) (acCircuit r) }
applyArgs r args =
(execState (apply args) r {acOutput = U1}) {acOutput = acOutput r}

-- | Optimizes the constraint system.
--
Expand All @@ -72,20 +69,20 @@ optimize = id

-- | Desugars range constraints into polynomial constraints
desugarRanges :: Arithmetic a => ArithmeticCircuit a f -> ArithmeticCircuit a f
desugarRanges c@(ArithmeticCircuit r _) =
let r' = flip execState r . traverse (uncurry desugarRange) $ toList (acRange r)
in c { acCircuit = r' { acRange = mempty } }
desugarRanges c =
let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ toList (acRange c)
in r' { acRange = mempty, acOutput = acOutput c }

----------------------------------- Information -----------------------------------

-- | Calculates the number of constraints in the system.
acSizeN :: ArithmeticCircuit a f -> Natural
acSizeN = length . acSystem . acCircuit
acSizeN = length . acSystem

-- | Calculates the number of variables in the system.
-- The constant `1` is not counted.
acSizeM :: ArithmeticCircuit a f -> Natural
acSizeM = length . acVarOrder . acCircuit
acSizeM = length . acVarOrder

acValue :: Functor f => ArithmeticCircuit a f -> f a
acValue r = eval r mempty
Expand All @@ -95,12 +92,13 @@ acValue r = eval r mempty
-- TODO: Move this elsewhere (?)
-- TODO: Check that all arguments have been applied.
acPrint :: (Show a, Show (f Natural), Show (f a), Functor f) => ArithmeticCircuit a f -> IO ()
acPrint ac@(ArithmeticCircuit r o) = do
let m = elems (acSystem r)
i = acInput r
acPrint ac = do
let m = elems (acSystem ac)
i = acInput ac
w = witnessGenerator ac empty
v = acValue ac
vo = acVarOrder r
vo = acVarOrder ac
o = acOutput ac
putStr "System size: "
pPrint $ acSizeN ac
putStr "Variable size: "
Expand All @@ -127,7 +125,7 @@ checkClosedCircuit
=> Show a
=> ArithmeticCircuit a n
-> Property
checkClosedCircuit c@(ArithmeticCircuit r _) = withMaxSuccess 1 $ conjoin [ testPoly p | p <- elems (acSystem r) ]
checkClosedCircuit c = withMaxSuccess 1 $ conjoin [ testPoly p | p <- elems (acSystem c) ]
where
w = witnessGenerator c empty
testPoly p = evalPolynomial evalMonomial (w !) p === zero
Expand All @@ -139,9 +137,9 @@ checkCircuit
=> Show a
=> ArithmeticCircuit a n
-> Property
checkCircuit c@(ArithmeticCircuit r _) = conjoin [ property (testPoly p) | p <- elems (acSystem r) ]
checkCircuit c = conjoin [ property (testPoly p) | p <- elems (acSystem c) ]
where
testPoly p = do
ins <- vector . fromIntegral $ length (acInput r)
let w = witnessGenerator c . fromList $ zip (acInput r) ins
ins <- vector . fromIntegral $ length (acInput c)
let w = witnessGenerator c . fromList $ zip (acInput c) ins
return $ evalPolynomial evalMonomial (w !) p === zero
9 changes: 4 additions & 5 deletions src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ import ZkFold.Base.Algebra.Polynomials.Multivariate (vari
import qualified ZkFold.Base.Data.Vector as V
import ZkFold.Base.Data.Vector (Vector (..))
import ZkFold.Prelude (length, splitAt, (!!))
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (ArithmeticCircuit (..), Circuit (acSystem),
acInput)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (ArithmeticCircuit (..), acInput)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint
import ZkFold.Symbolic.MonadCircuit

Expand Down Expand Up @@ -135,10 +134,10 @@ runInvert r = do
return (js, ks)

embedVarIndex :: Arithmetic a => Natural -> ArithmeticCircuit a Par1
embedVarIndex n = ArithmeticCircuit { acCircuit = mempty { acInput = [ n ]}, acOutput = pure n}
embedVarIndex n = mempty { acInput = [ n ], acOutput = pure n}

embedVarIndexV :: (Arithmetic a, KnownNat n) => Natural -> ArithmeticCircuit a (Vector n)
embedVarIndexV n = ArithmeticCircuit { acCircuit = mempty { acInput = [ n ]}, acOutput = pure n}
embedVarIndexV n = mempty { acInput = [ n ], acOutput = pure n}

getAllVars :: MultiplicativeMonoid a => Circuit a -> [Natural]
getAllVars :: MultiplicativeMonoid a => ArithmeticCircuit a o -> [Natural]
getAllVars ac = nubOrd $ sort $ 0 : acInput ac ++ concatMap (toList . variables) (elems $ acSystem ac)
13 changes: 6 additions & 7 deletions src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,13 @@ instance
instance (Arithmetic a, Arbitrary a) => Arbitrary (ArithmeticCircuit a Par1) where
arbitrary = do
k <- integerToNatural <$> chooseInteger (2, 10)
let ac = ArithmeticCircuit { acCircuit = mempty {acInput = [1..k]}, acOutput = pure k }
let ac = mempty { acInput = [1..k], acOutput = pure k }
arbitrary' ac 10

arbitrary' :: forall a . (Arithmetic a, Arbitrary a, FromConstant a a) => ArithmeticCircuit a Par1 -> Natural -> Gen (ArithmeticCircuit a Par1)
arbitrary' ac 0 = return ac
arbitrary' ac iter = do
let vars = getAllVars . acCircuit $ ac
let vars = getAllVars ac
li <- elements vars
ri <- elements vars
let (l, r) =( ac { acOutput = pure li }, ac { acOutput = pure ri })
Expand All @@ -215,16 +215,16 @@ arbitrary' ac iter = do

-- TODO: make it more readable
instance (FiniteField a, Haskell.Eq a, Show a, Show (f Natural)) => Show (ArithmeticCircuit a f) where
show (ArithmeticCircuit r o) = "ArithmeticCircuit { acInput = " ++ show (acInput r)
++ "\n, acSystem = " ++ show (acSystem r) ++ "\n, acOutput = " ++ show o ++ "\n, acVarOrder = " ++ show (acVarOrder r) ++ " }"
show r = "ArithmeticCircuit { acInput = " ++ show (acInput r)
++ "\n, acSystem = " ++ show (acSystem r) ++ "\n, acOutput = " ++ show (acOutput r) ++ "\n, acVarOrder = " ++ show (acVarOrder r) ++ " }"

-- TODO: add witness generation info to the JSON object
instance (ToJSON a, ToJSON (f Natural)) => ToJSON (ArithmeticCircuit a f) where
toJSON (ArithmeticCircuit r o) = object
toJSON r = object
[
"system" .= acSystem r,
"input" .= acInput r,
"output" .= o,
"output" .= acOutput r,
"order" .= acVarOrder r
]

Expand All @@ -240,5 +240,4 @@ instance (FromJSON a, FromJSON (f Natural)) => FromJSON (ArithmeticCircuit a f)
acOutput <- v .: "output"
let acWitness = empty
acRNG = mkStdGen 0
acCircuit = Circuit{..}
pure ArithmeticCircuit{..}
Loading

0 comments on commit 2e2b915

Please sign in to comment.