diff --git a/symbolic-base/src/ZkFold/Symbolic/Compiler.hs b/symbolic-base/src/ZkFold/Symbolic/Compiler.hs index 316356b32..609093cba 100644 --- a/symbolic-base/src/ZkFold/Symbolic/Compiler.hs +++ b/symbolic-base/src/ZkFold/Symbolic/Compiler.hs @@ -19,8 +19,8 @@ import Data.Ord (Ord) import Data.Proxy (Proxy (..)) import Data.Traversable (for) import GHC.Generics (Par1 (Par1)) -import Prelude (FilePath, IO, Monoid (mempty), Show (..), Traversable, - putStrLn, return, type (~), ($), (++)) +import Prelude (FilePath, IO, Show (..), Traversable, putStrLn, return, + type (~), ($), (++)) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Prelude (writeFileJSON) @@ -60,7 +60,7 @@ solder f = fromCircuit2F (pieces f input) b $ \r (Par1 i) -> do return r where Bool b = isValid input - input = restore @(Support f) $ const mempty { acOutput = acInput } + input = restore @(Support f) $ const idCircuit -- | Compiles function `f` into an arithmetic circuit with all outputs equal to 1. compileForceOne :: diff --git a/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index 7c9f2964b..5d2b41896 100644 --- a/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE TypeOperators #-} module ZkFold.Symbolic.Compiler.ArithmeticCircuit ( ArithmeticCircuit, Constraint, @@ -6,6 +7,8 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit ( -- high-level functions optimize, desugarRanges, + idCircuit, + guessOutput, -- low-level functions eval, eval1, @@ -30,16 +33,18 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit ( checkClosedCircuit ) where +import Control.DeepSeq (NFData) import Control.Monad (foldM) import Control.Monad.State (execState) import Data.Binary (Binary) -import Data.Functor.Rep (Representable (..)) +import Data.Foldable (for_) +import Data.Functor.Rep (Representable (..), mzipRep) import Data.Map hiding (drop, foldl, foldr, map, null, splitAt, take) import qualified Data.Map.Monoidal as M import qualified Data.Set as S import Data.Void (absurd) -import GHC.Generics (U1 (..)) +import GHC.Generics (U1 (..), (:*:)) import Numeric.Natural (Natural) import Prelude hiding (Num (..), drop, length, product, splitAt, sum, take, (!!), (^)) @@ -49,7 +54,10 @@ import Text.Pretty.Simple (pPrint) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Polynomials.Multivariate (evalMonomial, evalPolynomial) +import ZkFold.Base.Data.HFunctor (hmap) +import ZkFold.Base.Data.Product (fstP, sndP) import ZkFold.Prelude (length) +import ZkFold.Symbolic.Class (fromCircuit2F) import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance () import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit (..), Constraint, SysVar (..), Var (..), acInput, eval, eval1, exec, @@ -90,6 +98,23 @@ desugarRanges c = let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ [(SysVar v, k) | (k, s) <- M.toList (acRange c), v <- S.toList s] in r' { acRange = mempty, acOutput = acOutput c } +idCircuit :: Representable i => ArithmeticCircuit a p i i +idCircuit = ArithmeticCircuit + { acSystem = empty + , acRange = M.empty + , acWitness = empty + , acOutput = acInput + } + +guessOutput :: + (Arithmetic a, Binary a, Binary (Rep p), Binary (Rep i), Binary (Rep o)) => + (Ord (Rep i), Ord (Rep o), NFData (Rep i), NFData (Rep o)) => + (Representable i, Representable o, Foldable o) => + ArithmeticCircuit a p i o -> ArithmeticCircuit a p (i :*: o) U1 +guessOutput c = fromCircuit2F (hlmap fstP c) (hmap sndP idCircuit) $ \o o' -> do + for_ (mzipRep o o') $ \(i, j) -> constraint (\x -> x i - x j) + return U1 + ----------------------------------- Information ----------------------------------- -- | Calculates the number of constraints in the system.