diff --git a/examples/Examples/Eq.hs b/examples/Examples/Eq.hs index dce3e323e..78beef5c7 100644 --- a/examples/Examples/Eq.hs +++ b/examples/Examples/Eq.hs @@ -7,10 +7,10 @@ import Prelude hiding (Bool, Eq (. import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) +import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool (..)) import ZkFold.Symbolic.Data.Eq (Eq (..)) -import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Data.FieldElement (FieldElement) -- | (==) operation diff --git a/examples/Examples/LEQ.hs b/examples/Examples/LEQ.hs index 077a6e112..94ec0e436 100644 --- a/examples/Examples/LEQ.hs +++ b/examples/Examples/LEQ.hs @@ -7,13 +7,14 @@ import Prelude hiding (Bool, Eq (. import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_Scalar) +import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler -import ZkFold.Symbolic.Data.Bool (Bool (..)) +import ZkFold.Symbolic.Data.Bool (Bool) import ZkFold.Symbolic.Data.FieldElement (FieldElement) -import ZkFold.Symbolic.Data.Ord (Ord (..)) +import ZkFold.Symbolic.Data.Ord ((<=)) -- | (<=) operation -leq :: Ord (Bool c) (FieldElement c) => FieldElement c -> FieldElement c -> Bool c +leq :: Symbolic c => FieldElement c -> FieldElement c -> Bool c leq x y = x <= y exampleLEQ :: IO () diff --git a/src/ZkFold/Base/Protocol/ARK/Plonk.hs b/src/ZkFold/Base/Protocol/ARK/Plonk.hs index c3c3e55e0..a5114b29e 100644 --- a/src/ZkFold/Base/Protocol/ARK/Plonk.hs +++ b/src/ZkFold/Base/Protocol/ARK/Plonk.hs @@ -106,11 +106,7 @@ instance forall n l c1 c2 t plonk f g1. , KnownNat l , KnownNat (PlonkPermutationSize n) , KnownNat (PlonkPolyExtendedLength n) - , Eq (ScalarField c1) - , Scale (ScalarField c1) (ScalarField c1) - , BinaryExpansion (ScalarField c1) - , Bits (ScalarField c1) ~ [ScalarField c1] - , FiniteField (ScalarField c1) + , Arithmetic f , AdditiveGroup (BaseField c1) , Pairing c1 c2 , ToTranscript t (ScalarField c1) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs index 7e1fa895b..5e05ae1ec 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Combinators.hs @@ -68,7 +68,7 @@ embedVar x = newAssigned $ const (fromConstant x) embedAll :: forall a n . (Arithmetic a, KnownNat n) => a -> ArithmeticCircuit a (Vector n) embedAll x = circuitF $ Vector <$> replicateM (fromIntegral $ value @n) (newAssigned $ const (fromConstant x)) -expansion :: MonadBlueprint i a m => Natural -> i -> m [i] +expansion :: MonadCircuit i a m => Natural -> i -> m [i] -- ^ @expansion n k@ computes a binary expansion of @k@ if it fits in @n@ bits. expansion n k = do bits <- bitsOf n k @@ -88,7 +88,7 @@ splitExpansion n1 n2 k = do constraint (\x -> x k - x l - scale (2 ^ n1 :: Natural) (x h)) return (l, h) -bitsOf :: MonadBlueprint i a m => Natural -> i -> m [i] +bitsOf :: MonadCircuit i a m => Natural -> i -> m [i] -- ^ @bitsOf n k@ creates @n@ bits and sets their witnesses equal to @n@ smaller -- bits of @k@. bitsOf n k = for [0 .. n -! 1] $ \j -> @@ -97,7 +97,7 @@ bitsOf n k = for [0 .. n -! 1] $ \j -> repr :: forall b . (BinaryExpansion b, Bits b ~ [b], Finite b) => b -> [b] repr = padBits (numberOfBits @b) . binaryExpansion -horner :: MonadBlueprint i a m => [i] -> m i +horner :: MonadCircuit i a m => [i] -> m i -- ^ @horner [b0,...,bn]@ computes the sum @b0 + 2 b1 + ... + 2^n bn@ using -- Horner's scheme. horner xs = case reverse xs of diff --git a/src/ZkFold/Symbolic/Data/Eq/Structural.hs b/src/ZkFold/Symbolic/Data/Eq/Structural.hs index fbe0ab248..9b57d166e 100644 --- a/src/ZkFold/Symbolic/Data/Eq/Structural.hs +++ b/src/ZkFold/Symbolic/Data/Eq/Structural.hs @@ -5,9 +5,9 @@ module ZkFold.Symbolic.Data.Eq.Structural where import Prelude (type (~)) +import ZkFold.Symbolic.Class import ZkFold.Symbolic.Data.Bool import ZkFold.Symbolic.Data.Class -import ZkFold.Symbolic.Class import ZkFold.Symbolic.Data.Eq newtype Structural a = Structural a diff --git a/src/ZkFold/Symbolic/Data/Ord.hs b/src/ZkFold/Symbolic/Data/Ord.hs index 1b999a21b..d26454ae0 100644 --- a/src/ZkFold/Symbolic/Data/Ord.hs +++ b/src/ZkFold/Symbolic/Data/Ord.hs @@ -3,28 +3,29 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module ZkFold.Symbolic.Data.Ord (Ord (..), Lexicographical (..), blueprintGE, circuitGE, circuitGT, getBitsBE) where - -import Control.Monad (foldM) -import qualified Data.Bool as Haskell -import Data.Foldable (Foldable) -import Data.Function ((.)) -import qualified Data.Zip as Z -import GHC.Generics (Par1 (..)) -import Prelude (type (~), ($)) -import qualified Prelude as Haskell +module ZkFold.Symbolic.Data.Ord (Ord (..), Lexicographical (..), blueprintGE, bitwiseGE, bitwiseGT, getBitsBE) where + +import Control.Monad (foldM) +import qualified Data.Bool as Haskell +import Data.Foldable (Foldable, toList) +import Data.Function ((.)) +import Data.Functor ((<$>)) +import qualified Data.Zip as Z +import GHC.Generics (Par1 (..)) +import Prelude (type (~), ($)) +import qualified Prelude as Haskell import ZkFold.Base.Algebra.Basic.Class -import ZkFold.Base.Data.HFunctor (hmap) -import qualified ZkFold.Base.Data.Vector as V -import ZkFold.Symbolic.Compiler -import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint (MonadBlueprint (..), circuit) -import ZkFold.Symbolic.Data.Bool (Bool (..), BoolType (..)) +import ZkFold.Base.Data.HFunctor (hmap) +import qualified ZkFold.Base.Data.Vector as V +import ZkFold.Base.Data.Vector (unsafeToVector) +import ZkFold.Symbolic.Class (Symbolic (BaseField, symbolicF), symbolic2F) +import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (expansion) +import ZkFold.Symbolic.Data.Bool (Bool (..)) import ZkFold.Symbolic.Data.Class -import ZkFold.Symbolic.Data.Conditional (Conditional (..)) -import ZkFold.Symbolic.Data.FieldElement (FieldElement (..)) -import ZkFold.Symbolic.Interpreter (Interpreter (..)) -import ZkFold.Symbolic.MonadCircuit (Arithmetic, newAssigned) +import ZkFold.Symbolic.Data.Conditional (Conditional (..)) +import ZkFold.Symbolic.Data.FieldElement (FieldElement (..)) +import ZkFold.Symbolic.MonadCircuit (MonadCircuit, newAssigned) -- TODO (Issue #23): add `compare` class Ord b a where @@ -55,79 +56,68 @@ instance Haskell.Ord a => Ord Haskell.Bool a where min = Haskell.min -toValue :: Interpreter a Par1 -> a -toValue (Interpreter (Par1 v)) = v - -fromValue :: a -> Interpreter a Par1 -fromValue = Interpreter Haskell.. Par1 - -instance (Arithmetic a, Haskell.Ord a) => Ord (Bool (Interpreter a)) (Interpreter a Par1) where - (toValue -> x) <= (toValue -> y) = Haskell.bool false true (x Haskell.<= y) - (toValue -> x) < (toValue -> y) = Haskell.bool false true (x Haskell.< y) - (toValue -> x) >= (toValue -> y) = Haskell.bool false true (x Haskell.>= y) - (toValue -> x) > (toValue -> y) = Haskell.bool false true (x Haskell.> y) - (toValue -> x) `max` (toValue -> y) = fromValue $ Haskell.max x y - (toValue -> x) `min` (toValue -> y) = fromValue $ Haskell.min x y - newtype Lexicographical a = Lexicographical a -- ^ A newtype wrapper for easy definition of Ord instances -- (though not necessarily a most effective one) -deriving newtype instance SymbolicData c x => SymbolicData c (Lexicographical x) - -deriving via (Lexicographical (ArithmeticCircuit a Par1)) - instance Arithmetic a => Ord (Bool (ArithmeticCircuit a)) (ArithmeticCircuit a Par1) +deriving newtype instance SymbolicData c a => SymbolicData c (Lexicographical a) -deriving newtype instance (Arithmetic a, Haskell.Ord a) => Ord (Bool (Interpreter a)) (FieldElement (Interpreter a)) -deriving newtype instance Arithmetic a => Ord (Bool (ArithmeticCircuit a)) (FieldElement (ArithmeticCircuit a)) +deriving via (Lexicographical (FieldElement c)) + instance Symbolic c => Ord (Bool c) (FieldElement c) -- | Every @SymbolicData@ type can be compared lexicographically. instance - ( Arithmetic a - , SymbolicData (ArithmeticCircuit a) x - , Support (ArithmeticCircuit a) x ~ () - , TypeSize (ArithmeticCircuit a) x ~ 1 - ) => Ord (Bool (ArithmeticCircuit a)) (Lexicographical x) where + ( Symbolic c + , SymbolicData c x + , Support c x ~ () + , TypeSize c x ~ 1 + ) => Ord (Bool c) (Lexicographical x) where x <= y = y >= x x < y = y > x - x >= y = circuitGE (getBitsBE x) (getBitsBE y) + x >= y = bitwiseGE (getBitsBE x) (getBitsBE y) - x > y = circuitGT (getBitsBE x) (getBitsBE y) + x > y = bitwiseGT (getBitsBE x) (getBitsBE y) - max x y = bool @(Bool (ArithmeticCircuit a)) x y $ x < y + max x y = bool @(Bool c) x y $ x < y - min x y = bool @(Bool (ArithmeticCircuit a)) x y $ x > y + min x y = bool @(Bool c) x y $ x > y -getBitsBE :: forall c a x . (Arithmetic a, c ~ ArithmeticCircuit a, SymbolicData c x, Support c x ~ (), TypeSize c x ~ 1) => x -> c (V.Vector (NumberOfBits a)) +getBitsBE :: + forall c x . + (Symbolic c, SymbolicData c x, Support c x ~ (), TypeSize c x ~ 1) => + x -> c (V.Vector (NumberOfBits (BaseField c))) -- ^ @getBitsBE x@ returns a list of circuits computing bits of @x@, eldest to -- youngest. -getBitsBE x = let expansion = binaryExpansion $ hmap (Par1 . V.item) (pieces @c @x x ()) - in expansion { acOutput = V.reverse $ acOutput expansion } +getBitsBE x = + hmap unsafeToVector + $ symbolicF (pieces x ()) (binaryExpansion . V.item) + $ expansion (numberOfBits @(BaseField c)) . V.item -circuitGE :: forall a f . (Arithmetic a, Z.Zip f, Foldable f) => ArithmeticCircuit a f -> ArithmeticCircuit a f -> Bool (ArithmeticCircuit a) +bitwiseGE :: forall c f . (Symbolic c, Z.Zip f, Foldable f) => c f -> c f -> Bool c -- ^ Given two lists of bits of equal length, compares them lexicographically. -circuitGE xs ys = Bool $ circuit $ do - is <- runCircuit xs - js <- runCircuit ys - blueprintGE is js +bitwiseGE xs ys = Bool $ + symbolic2F xs ys + (\us vs -> Par1 $ Haskell.bool zero one (toList us Haskell.>= toList vs)) + $ \is js -> Par1 <$> blueprintGE is js -blueprintGE :: (MonadBlueprint i a m, Z.Zip f, Foldable f) => f i -> f i -> m i +blueprintGE :: (MonadCircuit i a m, Z.Zip f, Foldable f) => f i -> f i -> m i blueprintGE xs ys = do (_, hasNegOne) <- circuitDelta xs ys newAssigned $ \p -> one - p hasNegOne -circuitGT :: forall a f . (Arithmetic a, Z.Zip f, Foldable f) => ArithmeticCircuit a f -> ArithmeticCircuit a f -> Bool (ArithmeticCircuit a) +bitwiseGT :: forall c f . (Symbolic c, Z.Zip f, Foldable f) => c f -> c f -> Bool c -- ^ Given two lists of bits of equal length, compares them lexicographically. -circuitGT xs ys = Bool $ circuit $ do - is <- runCircuit xs - js <- runCircuit ys - (hasOne, hasNegOne) <- circuitDelta is js - newAssigned $ \p -> p hasOne * (one - p hasNegOne) - -circuitDelta :: forall i a m f . (MonadBlueprint i a m, Z.Zip f, Foldable f) => f i -> f i -> m (i, i) +bitwiseGT xs ys = Bool $ + symbolic2F xs ys + (\us vs -> Par1 $ Haskell.bool zero one (toList us Haskell.> toList vs)) + $ \is js -> do + (hasOne, hasNegOne) <- circuitDelta is js + Par1 <$> newAssigned (\p -> p hasOne * (one - p hasNegOne)) + +circuitDelta :: forall i a m f . (MonadCircuit i a m, Z.Zip f, Foldable f) => f i -> f i -> m (i, i) circuitDelta l r = do z1 <- newAssigned (Haskell.const zero) z2 <- newAssigned (Haskell.const zero) diff --git a/src/ZkFold/Symbolic/Data/UInt.hs b/src/ZkFold/Symbolic/Data/UInt.hs index 7514f3ee3..8ba35eae8 100644 --- a/src/ZkFold/Symbolic/Data/UInt.hs +++ b/src/ZkFold/Symbolic/Data/UInt.hs @@ -293,12 +293,12 @@ instance (Arithmetic a, KnownNat n, KnownRegisterSize r, KnownNat (NumberOfRegis u1 >= u2 = let ByteString rs1 = from u1 :: ByteString n (ArithmeticCircuit a) ByteString rs2 = from u2 :: ByteString n (ArithmeticCircuit a) - in circuitGE rs1 rs2 + in bitwiseGE rs1 rs2 u1 > u2 = let ByteString rs1 = from u1 :: ByteString n (ArithmeticCircuit a) ByteString rs2 = from u2 :: ByteString n (ArithmeticCircuit a) - in circuitGT rs1 rs2 + in bitwiseGT rs1 rs2 max x y = bool @(Bool (ArithmeticCircuit a)) x y $ x < y diff --git a/src/ZkFold/Symbolic/MonadCircuit.hs b/src/ZkFold/Symbolic/MonadCircuit.hs index abeacf73e..b4de5df48 100644 --- a/src/ZkFold/Symbolic/MonadCircuit.hs +++ b/src/ZkFold/Symbolic/MonadCircuit.hs @@ -9,6 +9,7 @@ import Data.Eq (Eq) import Data.Function (id) import Data.Functor (Functor) import Data.Functor.Identity (Identity (..)) +import Data.Ord (Ord) import Data.Type.Equality (type (~)) import ZkFold.Base.Algebra.Basic.Class @@ -105,8 +106,9 @@ class Monad m => MonadCircuit i a m | m -> i, m -> a where newAssigned :: ClosedPoly i a -> m i newAssigned p = newConstrained (\x i -> p x - x i) p --- | Field of witnesses with decidable equality is called an ``arithmetic'' field. -type Arithmetic a = (WitnessField a, Eq a) +-- | Field of witnesses with decidable equality and ordering +-- is called an ``arithmetic'' field. +type Arithmetic a = (WitnessField a, Eq a, Ord a) -- | An example implementation of a @'MonadCircuit'@ which computes witnesses -- immediately and drops the constraints. diff --git a/tests/Tests/Arithmetization/Test3.hs b/tests/Tests/Arithmetization/Test3.hs index 4ff746804..e1726dc6e 100644 --- a/tests/Tests/Arithmetization/Test3.hs +++ b/tests/Tests/Arithmetization/Test3.hs @@ -9,16 +9,17 @@ import Test.Hspec import ZkFold.Base.Algebra.Basic.Class (fromConstant) import ZkFold.Base.Algebra.Basic.Field (Zp) +import ZkFold.Symbolic.Class (Symbolic) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool (Bool (..)) import ZkFold.Symbolic.Data.FieldElement (FieldElement) -import ZkFold.Symbolic.Data.Ord (Ord (..)) +import ZkFold.Symbolic.Data.Ord ((<=)) import ZkFold.Symbolic.Interpreter (Interpreter (Interpreter)) type R = ArithmeticCircuit (Zp 97) -- A comparison test -testFunc :: Ord (Bool c) (FieldElement c) => FieldElement c -> FieldElement c -> Bool c +testFunc :: Symbolic c => FieldElement c -> FieldElement c -> Bool c testFunc x y = x <= y specArithmetization3 :: Spec diff --git a/tests/Tests/Arithmetization/Test4.hs b/tests/Tests/Arithmetization/Test4.hs index cb3e9a7eb..4ffc79725 100644 --- a/tests/Tests/Arithmetization/Test4.hs +++ b/tests/Tests/Arithmetization/Test4.hs @@ -18,11 +18,11 @@ import ZkFold.Base.Protocol.ARK.Plonk (Plonk (..), PlonkP plonkVerifierInput) import ZkFold.Base.Protocol.ARK.Plonk.Internal (getParams) import ZkFold.Base.Protocol.NonInteractiveProof (NonInteractiveProof (..)) +import ZkFold.Symbolic.Class import ZkFold.Symbolic.Compiler (ArithmeticCircuit (..), acValue, applyArgs, compile) import ZkFold.Symbolic.Data.Bool (Bool (..)) import ZkFold.Symbolic.Data.Eq (Eq (..)) import ZkFold.Symbolic.Data.FieldElement (FieldElement) -import ZkFold.Symbolic.Class type N = 1