Skip to content

Commit

Permalink
Merge pull request #334 from zkFold/TurtlePU/Witness-class
Browse files Browse the repository at this point in the history
+ Witness class for use in MonadCircuit
  • Loading branch information
vlasin authored Nov 8, 2024
2 parents 7e43e1d + 15cfc60 commit 725e9c6
Show file tree
Hide file tree
Showing 13 changed files with 349 additions and 311 deletions.
14 changes: 7 additions & 7 deletions symbolic-base/src/ZkFold/Symbolic/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import ZkFold.Symbolic.MonadCircuit
--
-- NOTE: the property above is correct by construction for each function of a
-- suitable type, you don't have to check it yourself.
type CircuitFun f g a = forall i m. MonadCircuit i a m => f i -> m (g i)
type CircuitFun f g a = forall i w m. MonadCircuit i a w m => f i -> m (g i)

-- | A Symbolic DSL for performant pure computations with arithmetic circuits.
-- @c@ is a generic context in which computations are performed.
Expand Down Expand Up @@ -59,37 +59,37 @@ embed cs = fromCircuitF hunit (\_ -> return (fromConstant <$> cs))

symbolic2F ::
(Symbolic c, BaseField c ~ a) => c f -> c g -> (f a -> g a -> h a) ->
(forall i m. MonadCircuit i a m => f i -> g i -> m (h i)) -> c h
(forall i w m. MonadCircuit i a w m => f i -> g i -> m (h i)) -> c h
-- | Runs the binary function from @f@ and @g@ into @h@ in a generic context @c@.
symbolic2F x y f m = symbolicF (hpair x y) (uncurryP f) (uncurryP m)

fromCircuit2F ::
Symbolic c => c f -> c g ->
(forall i m. MonadCircuit i (BaseField c) m => f i -> g i -> m (h i)) -> c h
(forall i w m. MonadCircuit i (BaseField c) w m => f i -> g i -> m (h i)) -> c h
-- | Runs the binary @'CircuitFun'@ in a generic context.
fromCircuit2F x y m = fromCircuitF (hpair x y) (uncurryP m)

symbolic3F ::
(Symbolic c, BaseField c ~ a) => c f -> c g -> c h -> (f a -> g a -> h a -> k a) ->
(forall i m. MonadCircuit i a m => f i -> g i -> h i -> m (k i)) -> c k
(forall i w m. MonadCircuit i a w m => f i -> g i -> h i -> m (k i)) -> c k
-- | Runs the ternary function from @f@, @g@ and @h@ into @k@ in a context @c@.
symbolic3F x y z f m = symbolic2F (hpair x y) z (uncurryP f) (uncurryP m)

fromCircuit3F ::
Symbolic c => c f -> c g -> c h ->
(forall i m. MonadCircuit i (BaseField c) m => f i -> g i -> h i -> m (k i)) -> c k
(forall i w m. MonadCircuit i (BaseField c) w m => f i -> g i -> h i -> m (k i)) -> c k
-- | Runs the ternary @'CircuitFun'@ in a generic context.
fromCircuit3F x y z m = fromCircuit2F (hpair x y) z (uncurryP m)

symbolicVF ::
(Symbolic c, BaseField c ~ a, Foldable f, Functor f) =>
f (c g) -> (f (g a) -> h a) ->
(forall i m. MonadCircuit i a m => f (g i) -> m (h i)) -> c h
(forall i w m. MonadCircuit i a w m => f (g i) -> m (h i)) -> c h
-- | Given a generic context @c@, runs the function from @f@ many @c g@'s into @c h@.
symbolicVF xs f m = symbolicF (pack xs) (f . unComp1) (m . unComp1)

fromCircuitVF ::
(Symbolic c, Foldable f, Functor f) => f (c g) ->
(forall i m. MonadCircuit i (BaseField c) m => f (g i) -> m (h i)) -> c h
(forall i w m. MonadCircuit i (BaseField c) w m => f (g i) -> m (h i)) -> c h
-- | Given a generic context @c@, runs the @'CircuitFun'@ from @f@ many @c g@'s into @c h@.
fromCircuitVF xs m = fromCircuitF (pack xs) (m . unComp1)
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ import ZkFold.Symbolic.MonadCircuit (MonadCircu
optimize :: ArithmeticCircuit a i o -> ArithmeticCircuit a i o
optimize = id

desugarRange :: (Arithmetic a, MonadCircuit i a m) => i -> a -> m ()
desugarRange :: (Arithmetic a, MonadCircuit i a w m) => i -> a -> m ()
desugarRange i b
| b == negate one = return ()
| otherwise = do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ arbitrary' ac iter = do
createRangeConstraint :: Symbolic c => FieldElement c -> BaseField c -> FieldElement c
createRangeConstraint (FieldElement x) a = FieldElement $ fromCircuitF x (\ (Par1 v) -> Par1 <$> solve v a)
where
solve :: MonadCircuit var a m => var -> a -> m var
solve :: MonadCircuit var a w m => var -> a -> m var
solve v b = do
v' <- newAssigned (Haskell.const zero)
rangeConstraint v' b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import Data.Semialign (unzipDef
import Data.Semigroup.Generic (GenericSemigroupMonoid (..))
import qualified Data.Set as S
import GHC.Generics (Generic, Par1 (..), U1 (..), (:*:) (..))
import Optics
import Optics hiding (at)
import Prelude hiding (Num (..), drop, length, product, splitAt,
sum, take, (!!), (^))

Expand All @@ -56,6 +56,7 @@ import ZkFold.Base.Data.HFunctor
import ZkFold.Base.Data.Package
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MerkleHash
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Witness
import ZkFold.Symbolic.MonadCircuit

-- | The type that represents a constraint in the arithmetic circuit.
Expand Down Expand Up @@ -186,12 +187,12 @@ instance

instance
( Arithmetic a, Binary a, Representable i, Binary (Rep i), Ord (Rep i)
, o ~ U1) => MonadCircuit (Var a i) a (State (ArithmeticCircuit a i o)) where
, o ~ U1) => MonadCircuit (Var a i) a (WitnessF (Var a i) a) (State (ArithmeticCircuit a i o)) where

unconstrained witness = do
let v = toVar @a witness
unconstrained wf = do
let v = toVar @a wf
-- TODO: forbid reassignment of variables
zoom #acWitness . modify $ insert v $ \i w -> witness $ \case
zoom #acWitness . modify $ insert v $ \i w -> witnessF wf $ \case
SysVar (InVar inV) -> index i inV
SysVar (NewVar newV) -> w ! newV
ConstVar cV -> fromConstant cV
Expand All @@ -203,7 +204,7 @@ instance
SysVar sysV -> var sysV
ConstVar cV -> fromConstant cV
in
zoom #acSystem . modify $ insert (toVar @a p) (p evalConstVar)
zoom #acSystem . modify $ insert (toVar (p at)) (p evalConstVar)

rangeConstraint (SysVar v) upperBound =
zoom #acRange . modify $ insertWith S.union upperBound (S.singleton v)
Expand Down Expand Up @@ -233,8 +234,8 @@ instance
-- 'WitnessField' is a root hash of a Merkle tree for a witness.
toVar ::
forall a i. (Finite a, Binary a, Binary (Rep i)) =>
Witness (Var a i) a -> ByteString
toVar witness = runHash @(Just (Order a)) $ witness $ \case
WitnessF (Var a i) a -> ByteString
toVar (WitnessF w) = runHash @(Just (Order a)) $ w $ \case
SysVar (InVar inV) -> merkleHash inV
SysVar (NewVar newV) -> M newV
ConstVar cV -> fromConstant cV
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{-# LANGUAGE DerivingStrategies #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Witness where

import Data.Function ((.))
import Numeric.Natural (Natural)
import Prelude (Integer)

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Symbolic.MonadCircuit

type IsWitness a n w = (Scale a w, FromConstant a w, WitnessField n w)

newtype WitnessF v a = WitnessF { witnessF :: forall n w. IsWitness a n w => (v -> w) -> w }

instance FromConstant Natural (WitnessF v a) where fromConstant x = WitnessF (fromConstant x)
instance FromConstant Integer (WitnessF v a) where fromConstant x = WitnessF (fromConstant x)
instance FromConstant a (WitnessF v a) where fromConstant x = WitnessF (fromConstant x)
instance Scale Natural (WitnessF v a) where scale k (WitnessF f) = WitnessF (scale k f)
instance Scale Integer (WitnessF v a) where scale k (WitnessF f) = WitnessF (scale k f)
instance Scale a (WitnessF v a) where scale k (WitnessF f) = WitnessF (scale k . f)
instance Exponent (WitnessF v a) Natural where WitnessF f ^ p = WitnessF (f ^ p)
instance Exponent (WitnessF v a) Integer where WitnessF f ^ p = WitnessF (f ^ p)
instance AdditiveSemigroup (WitnessF v a) where WitnessF f + WitnessF g = WitnessF (f + g)
instance AdditiveMonoid (WitnessF v a) where zero = WitnessF zero
instance AdditiveGroup (WitnessF v a) where
negate (WitnessF f) = WitnessF (negate f)
WitnessF f - WitnessF g = WitnessF (f - g)
instance MultiplicativeSemigroup (WitnessF v a) where WitnessF f * WitnessF g = WitnessF (f * g)
instance MultiplicativeMonoid (WitnessF v a) where one = WitnessF one
instance Semiring (WitnessF v a)
instance Ring (WitnessF v a)
instance Field (WitnessF v a) where
finv (WitnessF f) = WitnessF (finv . f)
WitnessF f // WitnessF g = WitnessF (\x -> f x // g x)
instance ToConstant (WitnessF v a) where
type Const (WitnessF v a) = EuclideanF v a
toConstant (WitnessF f) = EuclideanF (toConstant . f)
instance FromConstant (EuclideanF v a) (WitnessF v a) where fromConstant (EuclideanF f) = WitnessF (fromConstant . f)
instance Finite a => Finite (WitnessF v a) where type Order (WitnessF v a) = Order a

newtype EuclideanF v a = EuclideanF { euclideanF :: forall n w. IsWitness a n w => (v -> w) -> n }

instance FromConstant Natural (EuclideanF v a) where fromConstant x = EuclideanF (fromConstant x)
instance Scale Natural (EuclideanF v a) where scale k (EuclideanF f) = EuclideanF (scale k f)
instance Exponent (EuclideanF v a) Natural where EuclideanF f ^ p = EuclideanF (f ^ p)
instance AdditiveSemigroup (EuclideanF v a) where EuclideanF f + EuclideanF g = EuclideanF (f + g)
instance AdditiveMonoid (EuclideanF v a) where zero = EuclideanF zero
instance MultiplicativeSemigroup (EuclideanF v a) where EuclideanF f * EuclideanF g = EuclideanF (f * g)
instance MultiplicativeMonoid (EuclideanF v a) where one = EuclideanF one
instance Semiring (EuclideanF v a)
instance SemiEuclidean (EuclideanF v a) where
EuclideanF f `div` EuclideanF g = EuclideanF (\x -> f x `div` g x)
EuclideanF f `mod` EuclideanF g = EuclideanF (\x -> f x `mod` g x)

instance Arithmetic a => Witness v a (WitnessF v a) where
at i = WitnessF (\x -> x i)
129 changes: 55 additions & 74 deletions symbolic-base/src/ZkFold/Symbolic/Data/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import Data.Kind (Type)
import Data.List (reverse, unfoldr)
import Data.Maybe (Maybe (..))
import Data.String (IsString (..))
import Data.Traversable (for)
import Data.Traversable (for, mapM)
import GHC.Generics (Generic, Par1 (..))
import GHC.Natural (naturalFromInteger)
import Numeric (readHex, showHex)
Expand All @@ -59,7 +59,7 @@ import ZkFold.Symbolic.Data.Eq.Structural
import ZkFold.Symbolic.Data.FieldElement (FieldElement)
import ZkFold.Symbolic.Data.Input (SymbolicInput, isValid)
import ZkFold.Symbolic.Interpreter (Interpreter (..))
import ZkFold.Symbolic.MonadCircuit (ClosedPoly, MonadCircuit, newAssigned)
import ZkFold.Symbolic.MonadCircuit (ClosedPoly, newAssigned)

-- | A ByteString which stores @n@ bits and uses elements of @a@ as registers, one element per register.
-- Bit layout is Big-endian.
Expand Down Expand Up @@ -168,15 +168,9 @@ instance (Symbolic c, KnownNat n) => BoolType (ByteString n c) where
false = fromConstant (0 :: Natural)
true = not false

not (ByteString bits) = ByteString $ fromCircuitF bits solve
where
solve :: MonadCircuit i (BaseField c) m => Vector n i -> m (Vector n i)
solve xv = do
let xs = V.fromVector xv
ys <- for xs $ \i -> newAssigned (\p -> one - p i)
return $ V.unsafeToVector ys
not (ByteString bits) = ByteString $ fromCircuitF bits $ mapM (\i -> newAssigned (\p -> one - p i))

l || r = bitwiseOperation l r cons
l || r = bitwiseOperation l r cons
where
cons i j x =
let xi = x i
Expand All @@ -190,21 +184,22 @@ instance (Symbolic c, KnownNat n) => BoolType (ByteString n c) where
xj = x j
in xi * xj

xor (ByteString l) (ByteString r) = ByteString $ symbolic2F l r (\x y -> V.unsafeToVector $ fromConstant <$> toBsBits (vecToNat x `B.xor` vecToNat y) (value @n)) solve
where
vecToNat :: (ToConstant a, Const a ~ Natural) => Vector n a -> Natural
vecToNat = Haskell.foldl (\x p -> toConstant p + 2 * x :: Natural) 0

solve :: MonadCircuit i (BaseField c) m => Vector n i -> Vector n i -> m (Vector n i)
solve lv rv = do
xor (ByteString l) (ByteString r) =
ByteString $ symbolic2F l r
(\x y -> V.unsafeToVector $ fromConstant <$> toBsBits (vecToNat x `B.xor` vecToNat y) (value @n))
(\lv rv -> do
let varsLeft = lv
varsRight = rv
zipWithM (\i j -> newAssigned $ cons i j) varsLeft varsRight
)
where
vecToNat :: (ToConstant a, Const a ~ Natural) => Vector n a -> Natural
vecToNat = Haskell.foldl (\x p -> toConstant p + 2 * x :: Natural) 0

cons i j x =
let xi = x i
xj = x j
in xi + xj - (xi * xj + xi * xj)
cons i j x =
let xi = x i
xj = x j
in xi + xj - (xi * xj + xi * xj)

-- | A ByteString of length @n@ can only be split into words of length @wordSize@ if all of the following conditions are met:
-- 1. @wordSize@ is not greater than @n@;
Expand Down Expand Up @@ -234,18 +229,18 @@ instance (Symbolic c, KnownNat n) => ShiftBits (ByteString n c) where
shiftBits bs@(ByteString oldBits) s
| s == 0 = bs
| Haskell.abs s >= Haskell.fromIntegral (getNatural @n) = false
| otherwise = ByteString $ symbolicF oldBits (\v -> V.shift v s (fromConstant (0 :: Integer))) solve
where
solve :: forall a m. MonadCircuit a (BaseField c) m => Vector n a -> m (Vector n a)
solve bitsV = do
let bits = V.fromVector bitsV
zeros <- replicateM (Haskell.fromIntegral $ Haskell.abs s) $ newAssigned (Haskell.const zero)
| otherwise = ByteString $ symbolicF oldBits
(\v -> V.shift v s (fromConstant (0 :: Integer)))
(\bitsV -> do
let bits = V.fromVector bitsV
zeros <- replicateM (Haskell.fromIntegral $ Haskell.abs s) $ newAssigned (Haskell.const zero)

let newBits = case s < 0 of
Haskell.True -> take (Haskell.fromIntegral $ getNatural @n) $ zeros <> bits
Haskell.False -> drop (Haskell.fromIntegral s) $ bits <> zeros
let newBits = case s < 0 of
Haskell.True -> take (Haskell.fromIntegral $ getNatural @n) $ zeros <> bits
Haskell.False -> drop (Haskell.fromIntegral s) $ bits <> zeros

pure $ V.unsafeToVector newBits
pure $ V.unsafeToVector newBits
)

rotateBits (ByteString bits) s = ByteString $ hmap (`V.rotate` s) bits

Expand All @@ -254,58 +249,46 @@ instance
, KnownNat k
, KnownNat n
) => Resize (ByteString k c) (ByteString n c) where
resize (ByteString oldBits) = ByteString $ symbolicF oldBits (\v -> V.unsafeToVector $ zeroA <> takeMin (V.fromVector v)) solve
where
solve :: forall i m. MonadCircuit i (BaseField c) m => Vector k i -> m (Vector n i)
solve bitsV = do
resize (ByteString oldBits) = ByteString $ symbolicF oldBits
(\v -> V.unsafeToVector $ zeroA <> takeMin (V.fromVector v))
(\bitsV -> do
let bits = V.fromVector bitsV
zeros <- replicateM diff $ newAssigned (Haskell.const zero)
return $ V.unsafeToVector $ zeros <> takeMin bits
)
where
diff :: Haskell.Int
diff = Haskell.fromIntegral (getNatural @n) Haskell.- Haskell.fromIntegral (getNatural @k)

diff :: Haskell.Int
diff = Haskell.fromIntegral (getNatural @n) Haskell.- Haskell.fromIntegral (getNatural @k)

takeMin :: [a] -> [a]
takeMin = Haskell.take (Haskell.min (Haskell.fromIntegral $ getNatural @n) (Haskell.fromIntegral $ getNatural @k))
takeMin :: [a] -> [a]
takeMin = Haskell.take (Haskell.min (Haskell.fromIntegral $ getNatural @n) (Haskell.fromIntegral $ getNatural @k))

zeroA = Haskell.replicate diff (fromConstant (0 :: Integer ))
zeroA = Haskell.replicate diff (fromConstant (0 :: Integer ))

instance
( Symbolic c
, KnownNat n
) => SymbolicInput (ByteString n c) where
isValid (ByteString bits) = Bool $ fromCircuitF bits solve
where
solve :: MonadCircuit i (BaseField c) m => Vector n i -> m (Par1 i)
solve v = do
let vs = V.fromVector v
ys <- for vs $ \i -> newAssigned (\p -> p i * (one - p i))
us <-for ys $ \i -> isZero $ Par1 i
helper us

helper :: MonadCircuit i a m => [Par1 i] -> m (Par1 i)
helper xs = case xs of

isValid (ByteString bits) = Bool $ fromCircuitF bits $ \v -> do
let vs = V.fromVector v
ys <- for vs $ \i -> newAssigned (\p -> p i * (one - p i))
us <-for ys $ \i -> isZero $ Par1 i
case us of
[] -> Par1 <$> newAssigned (const one)
(b : bs) -> foldlM (\(Par1 v1) (Par1 v2) -> Par1 <$> newAssigned (($ v1) * ($ v2))) b bs


isSet :: forall c n. Symbolic c => ByteString n c -> Natural -> Bool c
isSet (ByteString bits) ix = Bool $ fromCircuitF bits solve
where
solve :: forall i m . MonadCircuit i (BaseField c) m => Vector n i -> m (Par1 i)
solve v = do
let vs = V.fromVector v
return $ Par1 $ (!! ix) vs
isSet (ByteString bits) ix = Bool $ fromCircuitF bits $ \v -> do
let vs = V.fromVector v
return $ Par1 $ (!! ix) vs

isUnset :: forall c n. Symbolic c => ByteString n c -> Natural -> Bool c
isUnset (ByteString bits) ix = Bool $ fromCircuitF bits solve
where
solve :: forall i m . MonadCircuit i (BaseField c) m => Vector n i -> m (Par1 i)
solve v = do
let vs = V.fromVector v
i = (!! ix) vs
j <- newAssigned $ \p -> one - p i
return $ Par1 j
isUnset (ByteString bits) ix = Bool $ fromCircuitF bits $ \v -> do
let vs = V.fromVector v
i = (!! ix) vs
j <- newAssigned $ \p -> one - p i
return $ Par1 j

--------------------------------------------------------------------------------

Expand Down Expand Up @@ -335,13 +318,11 @@ bitwiseOperation
-> ByteString n c
-> (forall i. i -> i -> ClosedPoly i (BaseField c))
-> ByteString n c
bitwiseOperation (ByteString bits1) (ByteString bits2) cons = ByteString $ fromCircuit2F bits1 bits2 solve
where
solve :: MonadCircuit i (BaseField c) m => Vector n i -> Vector n i -> m (Vector n i)
solve lv rv = do
let varsLeft = lv
varsRight = rv
zipWithM (\i j -> newAssigned $ cons i j) varsLeft varsRight
bitwiseOperation (ByteString bits1) (ByteString bits2) cons =
ByteString $ fromCircuit2F bits1 bits2 $ \lv rv -> do
let varsLeft = lv
varsRight = rv
zipWithM (\i j -> newAssigned $ cons i j) varsLeft varsRight

instance (Symbolic c, NumberOfBits (BaseField c) ~ n) => Iso (FieldElement c) (ByteString n c) where
from = ByteString . binaryExpansion
Expand Down
Loading

0 comments on commit 725e9c6

Please sign in to comment.