Skip to content

Commit

Permalink
Merge pull request #302 from zkFold/hov-symbolic-input
Browse files Browse the repository at this point in the history
Input validation for arithmetizable functions
  • Loading branch information
vlasin authored Oct 30, 2024
2 parents 6566feb + 5d0824d commit 22fc18c
Show file tree
Hide file tree
Showing 40 changed files with 285 additions and 110 deletions.
30 changes: 4 additions & 26 deletions symbolic-base/src/ZkFold/Base/Protocol/Protostar/Fold.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ import Control.Lens ((^.))
import Data.Binary (Binary)
import Data.Function ((.))
import Data.Functor (fmap)
import Data.Functor.Rep (Rep, Representable)
import Data.Kind (Type)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Ord (Ord)
import Data.Proxy (Proxy)
import GHC.Generics (Generic, Par1 (..), U1 (..), type (:*:) (..),
type (:.:) (..))
Expand All @@ -43,6 +41,7 @@ import ZkFold.Symbolic.Data.Bool
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Data.Eq
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
import ZkFold.Symbolic.Data.Input (SymbolicInput)


-- | These instances might seem off, but accumulator scheme requires this exact behaviour for ProverMessages which are Maps in this case.
Expand Down Expand Up @@ -118,34 +117,13 @@ ivcVerifier (i, pi_x, accTuple, acc'Tuple, pf) (a, ckTuple, dkTuple)
ivcVerifierAc
:: forall i f c m ctx a y t
. Symbolic ctx
=> SymbolicData i
=> SymbolicData f
=> SymbolicData c
=> SymbolicData m
=> SymbolicData a
=> SymbolicInput (i, c, (i, c, f, c, f), (i, c, f, c, f), c)
=> SymbolicInput (a, (f, (f, f)), ((i, c, f, c, f), m))
=> SymbolicData y
=> Context i ~ ctx
=> Context f ~ ctx
=> Context c ~ ctx
=> Context m ~ ctx
=> Context a ~ ctx
=> Context i ~ ctx
=> Context y ~ ctx
=> Support i ~ Proxy ctx
=> Support f ~ Proxy ctx
=> Support c ~ Proxy ctx
=> Support m ~ Proxy ctx
=> Support a ~ Proxy ctx
=> Support y ~ Proxy ctx
=> Representable (Layout i)
=> Representable (Layout c)
=> Representable (Layout f)
=> Representable (Layout a)
=> Representable (Layout m)
=> Ord (Rep (Layout i))
=> Ord (Rep (Layout c))
=> Ord (Rep (Layout f))
=> Ord (Rep (Layout a))
=> Ord (Rep (Layout m))
=> Layout y ~ Par1
=> t ~ ((i,c,(i,c,f,c,f),(i,c,f,c,f),c),(a,(f,f,f),(i,c,f,c,f),m),Proxy ctx)
=> ctx ~ ArithmeticCircuit a (Layout t)
Expand Down
44 changes: 21 additions & 23 deletions symbolic-base/src/ZkFold/Symbolic/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@ import Data.Function (const, (.))
import Data.Functor (($>))
import Data.Functor.Rep (Rep, Representable)
import Data.Ord (Ord)
import Data.Proxy (Proxy)
import Data.Proxy (Proxy (..))
import Data.Traversable (for)
import GHC.Generics (Par1 (Par1))
import Prelude (FilePath, IO, Monoid (mempty), Show (..), Traversable,
putStrLn, type (~), ($), (++))
putStrLn, return, type (~), ($), (++))

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Prelude (writeFileJSON)
import ZkFold.Symbolic.Class (Arithmetic, Symbolic (..))
import ZkFold.Symbolic.Class (Arithmetic, Symbolic (..), fromCircuit2F)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit
import ZkFold.Symbolic.Data.Bool (Bool (Bool))
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Data.Input
import ZkFold.Symbolic.MonadCircuit (MonadCircuit (..))

{-
Expand All @@ -43,21 +46,21 @@ forceOne r = fromCircuitF r (\fi -> for fi $ \i -> constraint (\x -> x i - one)

-- | Arithmetizes an argument by feeding an appropriate amount of inputs.
solder ::
forall a c f s l .
( c ~ ArithmeticCircuit a l
forall a c f s .
( c ~ ArithmeticCircuit a (Layout s)
, SymbolicData f
, Context f ~ c
, Support f ~ s
, SymbolicData s
, SymbolicInput s
, Context s ~ c
, Support s ~ Proxy c
, Layout s ~ l
, Representable l
, Ord (Rep l)
, Symbolic c
) => f -> c (Layout f)
solder f = pieces f (restore @(Support f) $ const inputC)
where
inputC = mempty { acOutput = acInput }
solder f = fromCircuit2F (pieces f input) b $ \r (Par1 i) -> do
constraint (\x -> one - x i)
return r
where
Bool b = isValid input
input = restore @(Support f) $ const mempty { acOutput = acInput }

-- | Compiles function `f` into an arithmetic circuit with all outputs equal to 1.
compileForceOne ::
Expand All @@ -68,9 +71,8 @@ compileForceOne ::
, SymbolicData f
, Context f ~ c
, Support f ~ s
, SymbolicData s
, SymbolicInput s
, Context s ~ c
, Support s ~ Proxy c
, Layout s ~ l
, Representable l
, Binary (Rep l)
Expand All @@ -90,16 +92,14 @@ compile ::
, SymbolicData f
, Context f ~ c
, Support f ~ s
, SymbolicData s
, SymbolicInput s
, Context s ~ c
, Support s ~ Proxy c
, Layout s ~ l
, Representable l
, Ord (Rep l)
, SymbolicData y
, Context y ~ c
, Support y ~ Proxy c
, Layout f ~ Layout y
, Symbolic c
) => f -> y
compile = restore . const . optimize . solder @a

Expand All @@ -113,14 +113,12 @@ compileIO ::
, Context f ~ c
, Support f ~ s
, ToJSON (Layout f (Var a l))
, SymbolicData s
, SymbolicInput s
, Context s ~ c
, Support s ~ Proxy c
, Layout s ~ l
, Representable l
, Ord (Rep l)
, FromJSON (Rep l)
, ToJSON (Rep l)
, Symbolic c
) => FilePath -> f -> IO ()
compileIO scriptFile f = do
let ac = optimize (solder @a f) :: c (Layout f)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE NoStarIsType #-}
Expand Down Expand Up @@ -129,6 +130,7 @@ imapVar _ (ConstVar c) = ConstVar c
acInput :: Representable i => i (Var a i)
acInput = fmapRep (SysVar . InVar) (tabulate id)


getAllVars :: forall a i o. (Representable i, Foldable i) => ArithmeticCircuit a i o -> [SysVar i]
getAllVars ac = toList acInput0 ++ map NewVar (keys $ acWitness ac) where
acInput0 :: i (SysVar i)
Expand Down
25 changes: 23 additions & 2 deletions symbolic-base/src/ZkFold/Symbolic/Data/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import Control.Monad (replicateM)
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Bits as B
import qualified Data.ByteString as Bytes
import Data.Foldable (foldlM)
import Data.Kind (Type)
import Data.List (reverse, unfoldr)
import Data.Maybe (Maybe (..))
Expand All @@ -34,8 +35,8 @@ import Data.Traversable (for)
import GHC.Generics (Generic, Par1 (..))
import GHC.Natural (naturalFromInteger)
import Numeric (readHex, showHex)
import Prelude (Integer, drop, fmap, otherwise, pure, return, take, type (~), ($),
(.), (<$>), (<), (<>), (==), (>=))
import Prelude (Integer, const, drop, fmap, otherwise, pure, return, take,
type (~), ($), (.), (<$>), (<), (<>), (==), (>=))
import qualified Prelude as Haskell
import Test.QuickCheck (Arbitrary (..), chooseInteger)

Expand All @@ -55,6 +56,7 @@ import ZkFold.Symbolic.Data.Combinators
import ZkFold.Symbolic.Data.Eq (Eq)
import ZkFold.Symbolic.Data.Eq.Structural
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
import ZkFold.Symbolic.Data.Input (SymbolicInput, isValid)
import ZkFold.Symbolic.Interpreter (Interpreter (..))
import ZkFold.Symbolic.MonadCircuit (ClosedPoly, MonadCircuit, newAssigned)

Expand Down Expand Up @@ -265,6 +267,25 @@ instance

zeroA = Haskell.replicate diff (fromConstant (0 :: Integer ))

instance
( Symbolic c
, KnownNat n
) => SymbolicInput (ByteString n c) where
isValid (ByteString bits) = Bool $ fromCircuitF bits solve
where
solve :: MonadCircuit i (BaseField c) m => Vector n i -> m (Par1 i)
solve v = do
let vs = V.fromVector v
ys <- for vs $ \i -> newAssigned (\p -> p i * (one - p i))
us <-for ys $ \i -> isZero $ Par1 i
helper us

helper :: MonadCircuit i a m => [Par1 i] -> m (Par1 i)
helper xs = case xs of
[] -> Par1 <$> newAssigned (const one)
(b : bs) -> foldlM (\(Par1 v1) (Par1 v2) -> Par1 <$> newAssigned (($ v1) * ($ v2))) b bs


isSet :: forall c n. Symbolic c => ByteString n c -> Natural -> Bool c
isSet (ByteString bits) ix = Bool $ fromCircuitF bits solve
where
Expand Down
3 changes: 3 additions & 0 deletions symbolic-base/src/ZkFold/Symbolic/Data/Combinators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,6 @@ runInvert is = do
js <- for is $ \i -> newConstrained (\x j -> x i * x j) (\x -> let xi = x i in one - xi // xi)
ks <- for (mzipRep is js) $ \(i, j) -> newConstrained (\x k -> x i * x k + x j - one) (finv . ($ i))
return (js, ks)

isZero :: (MonadCircuit i a m, Representable f, Traversable f) => f i -> m (f i)
isZero is = Haskell.fst <$> runInvert is
7 changes: 6 additions & 1 deletion symbolic-base/src/ZkFold/Symbolic/Data/FFA.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ import ZkFold.Base.Data.Utils (zipWithM)
import ZkFold.Base.Data.Vector
import ZkFold.Prelude (iterateM, length)
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Bool (Bool)
import ZkFold.Symbolic.Data.Bool (Bool, BoolType (..))
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Data.Combinators (expansionW, log2, maxBitsPerFieldElement, splitExpansion)
import ZkFold.Symbolic.Data.Eq
import ZkFold.Symbolic.Data.Input
import ZkFold.Symbolic.Data.Ord (blueprintGE)
import ZkFold.Symbolic.Interpreter
import ZkFold.Symbolic.MonadCircuit (MonadCircuit, newAssigned)
Expand Down Expand Up @@ -212,3 +213,7 @@ instance Finite (Zp p) => Finite (FFA p b) where

-- FIXME: This Eq instance is wrong
deriving newtype instance Symbolic c => Eq (Bool c) (FFA p c)

-- | TODO: fix when rewrite is done
instance (Symbolic c) => SymbolicInput (FFA p c) where
isValid _ = true
6 changes: 5 additions & 1 deletion symbolic-base/src/ZkFold/Symbolic/Data/FieldElement.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ import ZkFold.Base.Data.HFunctor (hmap)
import ZkFold.Base.Data.Par1 ()
import ZkFold.Base.Data.Vector (Vector, fromVector, unsafeToVector)
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Bool (Bool)
import ZkFold.Symbolic.Data.Bool (Bool, BoolType (true))
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Data.Combinators (expansion, horner, runInvert)
import ZkFold.Symbolic.Data.Eq (Eq)
import ZkFold.Symbolic.Data.Input
import ZkFold.Symbolic.Data.Ord
import ZkFold.Symbolic.MonadCircuit (newAssigned)

Expand Down Expand Up @@ -105,3 +106,6 @@ instance Symbolic c => BinaryExpansion (FieldElement c) where
fromBinary bits =
FieldElement $ symbolicF bits (Par1 . foldr (\x y -> x + y + y) zero)
$ fmap Par1 . horner . fromVector

instance (Symbolic c) => SymbolicInput (FieldElement c) where
isValid _ = true
79 changes: 79 additions & 0 deletions symbolic-base/src/ZkFold/Symbolic/Data/Input.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Data.Input (
SymbolicInput (..)
) where

import Control.Monad.Representable.Reader (Rep)
import Data.Functor.Rep (Representable)
import Data.Ord (Ord)
import Data.Type.Equality (type (~))
import Data.Typeable (Proxy (..))
import GHC.Generics (Par1 (..))
import GHC.TypeLits (KnownNat)
import Prelude (foldl, ($))

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Data.ByteString (Binary)
import ZkFold.Base.Data.Vector (Vector, fromVector)
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Bool
import ZkFold.Symbolic.Data.Class
import ZkFold.Symbolic.Data.Combinators
import ZkFold.Symbolic.MonadCircuit


-- | A class for Symbolic input.
class
( SymbolicData d
, Support d ~ Proxy (Context d)
, Representable (Layout d)
, Binary (Rep (Layout d))
, Ord (Rep (Layout d))
) => SymbolicInput d where
isValid :: d -> Bool (Context d)


instance Symbolic c => SymbolicInput (Bool c) where
isValid (Bool b) = Bool $ fromCircuitF b $
\(Par1 v) -> do
u <- newAssigned (\x -> x v * (one - x v))
isZero $ Par1 u


instance
( Symbolic c
, Binary (Rep f)
, Ord (Rep f)
, Representable f) => SymbolicInput (c f) where
isValid _ = true


instance Symbolic c => SymbolicInput (Proxy c) where
isValid _ = true

instance (
Symbolic (Context x)
, Context x ~ Context y
, SymbolicInput x
, SymbolicInput y
) => SymbolicInput (x, y) where
isValid (l, r) = isValid l && isValid r

instance (
Symbolic (Context x)
, Context x ~ Context y
, Context y ~ Context z
, SymbolicInput x
, SymbolicInput y
, SymbolicInput z
) => SymbolicInput (x, y, z) where
isValid (l, m, r) = isValid l && isValid m && isValid r

instance (
Symbolic (Context x)
, KnownNat n
, SymbolicInput x
) => SymbolicInput (Vector n x) where
isValid v = foldl (\l r -> l && isValid r) true $ fromVector v
Loading

0 comments on commit 22fc18c

Please sign in to comment.