Skip to content

Commit

Permalink
Merge pull request #405 from zkFold/TurtlePU/circuit-fold
Browse files Browse the repository at this point in the history
+ `acFold` field in `ArithmeticCircuit`
  • Loading branch information
vlasin authored Dec 24, 2024
2 parents bad45c1 + 08546e0 commit f457833
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 36 deletions.
4 changes: 2 additions & 2 deletions symbolic-base/src/ZkFold/Base/Data/Vector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

module ZkFold.Base.Data.Vector where

import Control.DeepSeq (NFData)
import Control.DeepSeq (NFData, NFData1)
import Control.Monad.State.Strict (runState, state)
import Data.Aeson (FromJSON (..), ToJSON (..))
import Data.Distributive (Distributive (..))
Expand All @@ -31,7 +31,7 @@ import ZkFold.Base.Data.ByteString (Binary (..))
import ZkFold.Prelude (length)

newtype Vector (size :: Natural) a = Vector {toV :: V.Vector a}
deriving (Show, Show1, Eq, Functor, Foldable, Traversable, Generic, NFData, Ord)
deriving (Show, Show1, Eq, Functor, Foldable, Traversable, Generic, NFData, NFData1, Ord)
deriving newtype (FromJSON, ToJSON)

-- helper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ desugarRanges c =
in r' { acRange = mempty, acOutput = acOutput c }

emptyCircuit :: ArithmeticCircuit a p i U1
emptyCircuit = ArithmeticCircuit empty M.empty empty U1
emptyCircuit = ArithmeticCircuit empty M.empty empty empty U1

-- | Given a natural transformation
-- from payload @p@ and input @i@ to output @o@,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,5 @@ instance (FromJSON a, FromJSON (o (Var a i)), ToJSONKey (Var a i), FromJSONKey a
acRange <- v .: "range"
acOutput <- v .: "output"
let acWitness = empty
acFold = empty
pure ArithmeticCircuit{..}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE TypeApplications #-}
Expand All @@ -8,6 +7,7 @@

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (
ArithmeticCircuit(..),
CircuitFold (..),
Var (..),
SysVar (..),
WitVar (..),
Expand All @@ -32,12 +32,14 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (
witToVar
) where

import Control.DeepSeq (NFData)
import Control.DeepSeq (NFData (..), NFData1 (..))
import Control.Monad.State (State, modify, runState)
import Data.Bifunctor (Bifunctor (..))
import Data.Binary (Binary)
import Data.ByteString (ByteString)
import Data.Foldable (fold, toList)
import Data.Functor.Rep
import Data.List.Infinite (Infinite)
import Data.Map.Monoidal (MonoidalMap, insertWith)
import Data.Map.Strict hiding (drop, foldl, foldr, insertWith,
map, null, splitAt, take, toList)
Expand Down Expand Up @@ -68,15 +70,49 @@ import ZkFold.Symbolic.MonadCircuit
-- | The type that represents a constraint in the arithmetic circuit.
type Constraint c i = Poly c (SysVar i) Natural

type CircuitWitness a p i = WitnessF a (WitVar p i)

data CircuitFold a v w =
forall s j.
( Functor s, NFData1 s, Binary (Rep s), NFData (Rep s), Ord (Rep s)
, Functor j, Binary (Rep j), NFData (Rep j), Ord (Rep j)) =>
CircuitFold
{ foldStep :: ArithmeticCircuit a U1 (j :*: s) s
, foldSeed :: s v
, foldStream :: Infinite (j w)
, foldCount :: v
, foldResult :: s v
}

instance Functor (CircuitFold a v) where
fmap = second

instance Bifunctor (CircuitFold a) where
bimap f g CircuitFold {..} = CircuitFold
{ foldStep = foldStep
, foldSeed = f <$> foldSeed
, foldStream = fmap g <$> foldStream
, foldCount = f foldCount
, foldResult = f <$> foldResult
}

instance (NFData a, NFData v) => NFData (CircuitFold a v w) where
rnf CircuitFold {..} =
rnf (foldStep, foldCount)
`seq` liftRnf rnf foldSeed
`seq` liftRnf rnf foldResult

-- | Arithmetic circuit in the form of a system of polynomial constraints.
data ArithmeticCircuit a p i o = ArithmeticCircuit
{
acSystem :: Map ByteString (Constraint a i),
-- ^ The system of polynomial constraints
acRange :: MonoidalMap a (S.Set (SysVar i)),
-- ^ The range constraints [0, a] for the selected variables
acWitness :: Map ByteString (WitnessF a (WitVar p i)),
acWitness :: Map ByteString (CircuitWitness a p i),
-- ^ The witness generation functions
acFold :: Map ByteString (CircuitFold a (Var a i) (CircuitWitness a p i)),
-- ^ The set of folding operations
acOutput :: o (Var a i)
-- ^ The output variables
} deriving (Generic)
Expand All @@ -87,8 +123,9 @@ deriving via (GenericSemigroupMonoid (ArithmeticCircuit a p i o))
deriving via (GenericSemigroupMonoid (ArithmeticCircuit a p i o))
instance (Ord a, Ord (Rep i), o ~ U1) => Monoid (ArithmeticCircuit a p i o)

instance (NFData a, NFData (o (Var a i)), NFData (Rep i))
=> NFData (ArithmeticCircuit a p i o)
instance (NFData a, NFData1 o, NFData (Rep i))
=> NFData (ArithmeticCircuit a p i o) where
rnf (ArithmeticCircuit s r w f o) = rnf (s, r, w, f) `seq` liftRnf rnf o

-- | Variables are SHA256 digests (32 bytes)
type VarField = Zp (2 ^ (32 * 8))
Expand Down Expand Up @@ -134,17 +171,21 @@ indexW circuit payload inputs = \case
hlmap ::
(Representable i, Representable j, Ord (Rep j), Functor o) =>
(forall x . j x -> i x) -> ArithmeticCircuit a p i o -> ArithmeticCircuit a p j o
hlmap f (ArithmeticCircuit s r w o) = ArithmeticCircuit
hlmap f (ArithmeticCircuit s r w d o) = ArithmeticCircuit
{ acSystem = mapVars (imapSysVar f) <$> s
, acRange = S.map (imapSysVar f) <$> r
, acWitness = fmap (imapWitVar f) <$> w
, acFold = bimap (imapVar f) (imapWitVar f <$>) <$> d
, acOutput = imapVar f <$> o
}

hpmap ::
(Representable p, Representable q) => (forall x. q x -> p x) ->
ArithmeticCircuit a p i o -> ArithmeticCircuit a q i o
hpmap f ac = ac { acWitness = fmap (pmapWitVar f) <$> acWitness ac }
hpmap f ac = ac
{ acWitness = fmap (pmapWitVar f) <$> acWitness ac
, acFold = fmap (pmapWitVar f <$>) <$> acFold ac
}

--------------------------- Symbolic compiler context --------------------------

Expand Down Expand Up @@ -175,13 +216,13 @@ instance

----------------------------- MonadCircuit instance ----------------------------

instance Finite a => Witness (Var a i) (WitnessF a (WitVar p i)) where
instance Finite a => Witness (Var a i) (CircuitWitness a p i) where
at (ConstVar cV) = fromConstant cV
at (LinVar k sV b) = fromConstant k * pure (WSysVar sV) + fromConstant b

instance
( Arithmetic a, Binary a, Binary (Rep p), Binary (Rep i), Ord (Rep i)
, o ~ U1) => MonadCircuit (Var a i) a (WitnessF a (WitVar p i)) (State (ArithmeticCircuit a p i o)) where
(Arithmetic a, Binary a, Binary (Rep p), Binary (Rep i), Ord (Rep i), o ~ U1)
=> MonadCircuit (Var a i) a (CircuitWitness a p i) (State (ArithmeticCircuit a p i o)) where

unconstrained wf = case runWitnessF wf $ \case
WSysVar sV -> LinUVar one sV zero
Expand Down Expand Up @@ -289,15 +330,21 @@ exec ac = eval ac U1 U1

-- | Applies the values of the first couple of inputs to the arithmetic circuit.
apply ::
(Eq a, Field a, Ord (Rep j), Representable i) =>
i a -> ArithmeticCircuit a p (i :*: j) U1 -> ArithmeticCircuit a p j U1
(Eq a, Field a, Ord (Rep j), Representable i, Functor o) =>
i a -> ArithmeticCircuit a p (i :*: j) o -> ArithmeticCircuit a p j o
apply xs ac = ac
{ acSystem = fmap (evalPolynomial evalMonomial varF) (acSystem ac)
, acRange = S.fromList . catMaybes . toList . filterSet <$> acRange ac
, acWitness = (>>= witF) <$> acWitness ac
, acOutput = U1
, acFold = bimap outF (>>= witF) <$> acFold ac
, acOutput = outF <$> acOutput ac
}
where
outF (LinVar k (InVar (Left v)) b) = ConstVar (k * index xs v + b)
outF (LinVar k (InVar (Right v)) b) = LinVar k (InVar v) b
outF (LinVar k (NewVar v) b) = LinVar k (NewVar v) b
outF (ConstVar a) = ConstVar a

varF (InVar (Left v)) = fromConstant (index xs v)
varF (InVar (Right v)) = var (InVar v)
varF (NewVar v) = var (NewVar v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map (
mapVarArithmeticCircuit,
) where

import Data.Functor ((<&>))
import Data.Bifunctor (bimap)
import Data.Functor.Rep (Representable (..))
import Data.Map hiding (drop, foldl, foldr, fromList, map, null,
splitAt, take, toList)
Expand Down Expand Up @@ -34,13 +34,14 @@ mapVarArithmeticCircuit ac =
backward = Map.fromAscList $ zip asc vars
varF (InVar v) = InVar v
varF (NewVar v) = NewVar (forward ! v)
oVarF (LinVar k v b) = LinVar k (varF v) b
oVarF (ConstVar c) = ConstVar c
witF (WSysVar v) = WSysVar (varF v)
witF (WExVar v) = WExVar v
in ArithmeticCircuit
{ acRange = Set.map varF <$> acRange ac
, acSystem = fromList $ zip asc $ evalPolynomial evalMonomial (var . varF) <$> elems (acSystem ac)
, acWitness = (`Map.compose` backward) $ fmap witF <$> acWitness ac
, acOutput = acOutput ac <&> \case
LinVar k v b -> LinVar k (varF v) b
ConstVar c -> ConstVar c
, acWitness = (fmap witF <$> acWitness ac) `Map.compose` backward
, acFold = bimap oVarF (fmap witF) <$> acFold ac
, acOutput = oVarF <$> acOutput ac
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Optimization where

import Data.Bifunctor (bimap)
import Data.Binary (Binary)
import Data.Bool (bool)
import Data.ByteString (ByteString)
import Data.Functor ((<&>))
import Data.Functor.Rep (Representable (..))
import Data.Map hiding (drop, foldl, foldr, map, null, splitAt,
take)
Expand Down Expand Up @@ -33,13 +33,13 @@ import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Witness (Witnes
optimize :: forall a p i o.
(Arithmetic a, Ord (Rep i), Functor o, Binary (Rep i), Binary a, Binary (Rep p)) =>
ArithmeticCircuit a p i o -> ArithmeticCircuit a p i o
optimize (ArithmeticCircuit s r w o) = ArithmeticCircuit {
optimize (ArithmeticCircuit s r w f o) = ArithmeticCircuit {
acSystem = addInVarConstraints newS,
acRange = optRanges vs r,
acWitness = (>>= optWitVar vs) <$> M.filterWithKey (\k _ -> notMember (NewVar k) vs) w,
acOutput = o <&> \case
lv@(LinVar k sV b) -> maybe lv (ConstVar . (\t -> k * t + b)) (M.lookup sV vs)
so -> so}
acFold = optimizeFold . bimap varF (>>= optWitVar vs) <$> f,
acOutput = varF <$> o
}
where
(newS, vs) = varsToReplace (s, M.empty)

Expand All @@ -63,6 +63,13 @@ optimize (ArithmeticCircuit s r w o) = ArithmeticCircuit {
Nothing -> pure $ WSysVar sv
we -> pure we

optimizeFold CircuitFold {..} =
CircuitFold { foldStep = optimize foldStep, .. }

varF lv@(LinVar k sV b) = maybe lv (ConstVar . (\t -> k * t + b)) (M.lookup sV vs)
varF (ConstVar c) = ConstVar c


varsToReplace :: (Arithmetic a, Ord (Rep i)) => (Map ByteString (Constraint a i) , Map (SysVar i) a) -> (Map ByteString (Constraint a i) , Map (SysVar i) a)
varsToReplace (s, l) = if newVars == M.empty then (s, l) else varsToReplace (M.filter (/= zero) $ optimizeSystems newVars s, M.union newVars l)
where
Expand Down
5 changes: 2 additions & 3 deletions symbolic-examples/bench/BenchCompiler.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Main where

import Control.DeepSeq (NFData, force)
import Control.DeepSeq (NFData, NFData1, force)
import Control.Monad (return)
import Data.ByteString.Lazy (ByteString)
import Data.Function (const, ($))
Expand All @@ -25,8 +25,7 @@ metrics name circuit =
<> "\nNumber of range lookups: " <> fromString (show $ acSizeR circuit)

benchmark ::
( Arithmetic a, NFData (o (Var a i)), NFData (Rep i)
, Representable p, Representable i) =>
(Arithmetic a, NFData1 o, NFData (Rep i), Representable p, Representable i) =>
String -> (() -> ArithmeticCircuit a p i o) -> Benchmark
benchmark name circuit = bgroup name
[ bench "compilation" $ nf circuit ()
Expand Down
14 changes: 7 additions & 7 deletions symbolic-examples/src/ZkFold/Symbolic/Examples.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

module ZkFold.Symbolic.Examples (ExampleOutput (..), examples) where

import Control.DeepSeq (NFData)
import Control.DeepSeq (NFData, NFData1)
import Data.Function (const, ($), (.))
import Data.Functor (Functor)
import Data.Functor.Rep (Rep, Representable)
import Data.Proxy (Proxy)
import Data.String (String)
Expand All @@ -27,7 +28,6 @@ import GHC.Generics (Par1, (:*:), (:.:)
import ZkFold.Base.Algebra.Basic.Field (Zp)
import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar)
import ZkFold.Symbolic.Compiler (ArithmeticCircuit, compile)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit (Var)
import ZkFold.Symbolic.Data.ByteString (ByteString)
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Data.Combinators (RegisterSize (Auto))
Expand All @@ -39,7 +39,7 @@ type C = ArithmeticCircuit A
data ExampleOutput where
ExampleOutput ::
forall p i o.
(Representable p, Representable i, NFData (Rep i), NFData (o (Var A i))) =>
(Representable p, Representable i, NFData (Rep i), NFData1 o) =>
(() -> C p i o) -> ExampleOutput

exampleOutput ::
Expand All @@ -55,14 +55,14 @@ exampleOutput ::
, Payload (Support f) ~ p
, Representable i
, NFData (Rep i)
, NFData (o (Var A i))
, NFData1 o
) => f -> ExampleOutput
exampleOutput = ExampleOutput @p @i @o . const . compile

-- | TODO: Maybe there is a better place for these orphans?
instance NFData a => NFData (Par1 a)
instance (NFData (f a), NFData (g a)) => NFData ((f :*: g) a)
instance NFData (f (g a)) => NFData ((f :.: g) a)
instance NFData1 Par1
instance (NFData1 f, NFData1 g) => NFData1 (f :*: g)
instance (Functor f, NFData1 f, NFData1 g) => NFData1 (f :.: g)

examples :: [(String, ExampleOutput)]
examples =
Expand Down

0 comments on commit f457833

Please sign in to comment.