From 5c71e8a637178eb8c0c234979f76ed296902c641 Mon Sep 17 00:00:00 2001 From: TurtlePU Date: Fri, 6 Sep 2024 01:12:45 +0300 Subject: [PATCH 1/4] improved WitnessField --- src/ZkFold/Base/Algebra/Basic/Class.hs | 3 +- src/ZkFold/Base/Algebra/Basic/Sources.hs | 49 +++++++++++++++---- src/ZkFold/Symbolic/Class.hs | 3 +- .../Symbolic/Compiler/ArithmeticCircuit.hs | 2 +- .../Compiler/ArithmeticCircuit/Internal.hs | 2 +- src/ZkFold/Symbolic/Data/Combinators.hs | 44 ++++++++--------- src/ZkFold/Symbolic/Data/FieldElement.hs | 5 +- src/ZkFold/Symbolic/Data/Ord.hs | 7 ++- src/ZkFold/Symbolic/MonadCircuit.hs | 11 +++-- tests/Tests/ArithmeticCircuit.hs | 4 +- 10 files changed, 83 insertions(+), 47 deletions(-) diff --git a/src/ZkFold/Base/Algebra/Basic/Class.hs b/src/ZkFold/Base/Algebra/Basic/Class.hs index 7970405fd..7f69b0809 100644 --- a/src/ZkFold/Base/Algebra/Basic/Class.hs +++ b/src/ZkFold/Base/Algebra/Basic/Class.hs @@ -279,9 +279,10 @@ If @a@ and @b@ are in @R@ and @b@ is nonzero, then there exist @q@ and @r@ in @R The function @divMod@ associated with this class produces @q@ and @r@ given @a@ and @b@. -} class Semiring a => EuclideanDomain a where - {-# MINIMAL divMod #-} + {-# MINIMAL divMod | (div, mod) #-} divMod :: a -> a -> (a, a) + divMod n d = (n `div` d, n `mod` d) div :: a -> a -> a div n d = Haskell.fst $ divMod n d diff --git a/src/ZkFold/Base/Algebra/Basic/Sources.hs b/src/ZkFold/Base/Algebra/Basic/Sources.hs index 96da98cd1..d0569d3a6 100644 --- a/src/ZkFold/Base/Algebra/Basic/Sources.hs +++ b/src/ZkFold/Base/Algebra/Basic/Sources.hs @@ -1,9 +1,17 @@ {-# LANGUAGE DerivingStrategies #-} -module ZkFold.Base.Algebra.Basic.Sources where - +module ZkFold.Base.Algebra.Basic.Sources (Sources (..)) where + +import Data.Function (const, id, (.)) +import Data.Kind (Type) +import Data.Maybe (Maybe (..)) +import Data.Monoid (Monoid (..)) +import Data.Ord (Ord) +import Data.Semigroup (Semigroup (..)) import Data.Set (Set) -import Prelude hiding (replicate) +import qualified Data.Set as Set +import Numeric.Natural (Natural) +import Prelude (Integer) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Prelude @@ -11,15 +19,22 @@ import ZkFold.Prelude newtype Sources a i = Sources { runSources :: Set i } deriving newtype (Semigroup, Monoid) -instance MultiplicativeSemigroup c => Exponent (Sources a i) c where - (^) = const +empty :: Sources a i +empty = Sources Set.empty -instance MultiplicativeMonoid c => Scale c (Sources a i) where - scale = const id +instance {-# OVERLAPPING #-} FromConstant (Sources a i) (Sources a i) where + fromConstant = id + +instance {-# OVERLAPPING #-} Ord i => Scale (Sources a i) (Sources a i) where + scale = (<>) instance Ord i => AdditiveSemigroup (Sources a i) where (+) = (<>) +instance {-# OVERLAPPABLE #-} + MultiplicativeMonoid c => Scale c (Sources a i) where + scale = const id + instance Ord i => AdditiveMonoid (Sources a i) where zero = mempty @@ -32,14 +47,20 @@ instance Finite a => Finite (Sources a i) where instance Ord i => MultiplicativeSemigroup (Sources a i) where (*) = (<>) +instance Exponent (Sources a i) Natural where + (^) = const + instance Ord i => MultiplicativeMonoid (Sources a i) where one = mempty +instance Exponent (Sources a i) Integer where + (^) = const + instance Ord i => MultiplicativeGroup (Sources a i) where invert = id -instance Ord i => FromConstant c (Sources a i) where - fromConstant _ = mempty +instance {-# OVERLAPPABLE #-} FromConstant c (Sources a i) where + fromConstant = const empty instance Ord i => Semiring (Sources a i) @@ -47,8 +68,16 @@ instance Ord i => Ring (Sources a i) instance Ord i => Field (Sources a i) where finv = id - rootOfUnity _ = Just (Sources mempty) + rootOfUnity _ = Just mempty + +instance ToConstant (Sources (a :: Type) i) where + type Const (Sources a i) = Sources (Const a) i + toConstant = Sources . runSources instance (Finite a, Ord i) => BinaryExpansion (Sources a i) where type Bits (Sources a i) = [Sources a i] binaryExpansion = replicate (numberOfBits @a) + +instance Ord i => EuclideanDomain (Sources a i) where + div = (<>) + mod = (<>) diff --git a/src/ZkFold/Symbolic/Class.hs b/src/ZkFold/Symbolic/Class.hs index c8968dd88..33d06036b 100644 --- a/src/ZkFold/Symbolic/Class.hs +++ b/src/ZkFold/Symbolic/Class.hs @@ -8,6 +8,7 @@ import Data.Functor (Functor (fmap), (<$>)) import Data.Kind (Type) import Data.Type.Equality (type (~)) import GHC.Generics (Par1 (Par1), type (:.:) (unComp1)) +import Numeric.Natural (Natural) import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Control.HApplicative (HApplicative (hpair, hunit)) @@ -49,7 +50,7 @@ class (HApplicative c, Package c, Arithmetic (BaseField c)) => Symbolic c where -- | A wrapper around @'symbolicF'@ which extracts the pure computation -- from the circuit computation using the @'Witnesses'@ newtype. fromCircuitF :: c f -> CircuitFun f g (BaseField c) -> c g - fromCircuitF x f = symbolicF x (runWitnesses @(BaseField c) . f) f + fromCircuitF x f = symbolicF x (runWitnesses @Natural @(BaseField c) . f) f -- | Embeds the pure value(s) into generic context @c@. embed :: (Symbolic c, Foldable f, Functor f) => f (BaseField c) -> c f diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index b7177b955..0c4a5594f 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -69,7 +69,7 @@ desugarRange :: (Arithmetic a, MonadCircuit i a m) => i -> a -> m () desugarRange i b | b == negate one = return () | otherwise = do - let bs = binaryExpansion b + let bs = binaryExpansion (toConstant b) is <- expansion (length bs) i case dropWhile ((== one) . fst) (zip bs is) of [] -> return () diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index f4a6a6345..c106f144c 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -196,7 +196,7 @@ instance o ~ U1 => Monoid (ArithmeticCircuit a i o) where type VarField = Zp BLS12_381_Scalar toField :: Arithmetic a => a -> VarField -toField = toZp . fromConstant . fromBinary @Natural . castBits . binaryExpansion +toField = fromConstant . toConstant -- TODO: Remove the hardcoded constant. toVar :: diff --git a/src/ZkFold/Symbolic/Data/Combinators.hs b/src/ZkFold/Symbolic/Data/Combinators.hs index 77919c9b0..5a6854fb5 100644 --- a/src/ZkFold/Symbolic/Data/Combinators.hs +++ b/src/ZkFold/Symbolic/Data/Combinators.hs @@ -7,7 +7,7 @@ module ZkFold.Symbolic.Data.Combinators where import Control.Monad (mapM) -import Data.Foldable (Foldable (..), foldlM) +import Data.Foldable (foldlM) import Data.Kind (Type) import Data.List (find, splitAt) import Data.List.Split (chunksOf) @@ -16,7 +16,6 @@ import Data.Proxy (Proxy (..)) import Data.Ratio ((%)) import Data.Traversable (Traversable, for) import Data.Type.Bool (If) -import Data.Type.Equality (type (~)) import Data.Type.Ord import qualified Data.Zip as Z import GHC.Base (const, return) @@ -28,7 +27,6 @@ import Type.Errors import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Number (value) -import ZkFold.Prelude (drop, take, (!!)) import ZkFold.Symbolic.MonadCircuit -- | A class for isomorphic types. @@ -47,8 +45,6 @@ class Extend a b where class Shrink a b where shrink :: a -> b - - -- | Convert an @ArithmeticCircuit@ to bits and return their corresponding variables. -- toBits @@ -61,13 +57,10 @@ toBits toBits regs hiBits loBits = do let lows = tail regs high = head regs - bitsLow <- Haskell.concatMap Haskell.reverse <$> mapM (expansion loBits) lows bitsHigh <- Haskell.reverse <$> expansion hiBits high - pure $ bitsHigh <> bitsLow - -- | The inverse of @toBits@. -- fromBits @@ -78,10 +71,8 @@ fromBits fromBits hiBits loBits bits = do let (bitsHighNew, bitsLowNew) = splitAt (Haskell.fromIntegral hiBits) bits let lowVarsNew = chunksOf (Haskell.fromIntegral loBits) bitsLowNew - lowsNew <- mapM (horner . Haskell.reverse) lowVarsNew highNew <- horner . Haskell.reverse $ bitsHighNew - pure $ highNew : lowsNew data RegisterSize = Auto | Fixed Natural @@ -135,9 +126,7 @@ type family MaxRegisterSize (a :: Type) (regCount :: Natural) :: Natural where type family ListRange (from :: Natural) (to :: Natural) :: [Natural] where ListRange from from = '[from] - ListRange from to = from ': (ListRange (from + 1) to) - - + ListRange from to = from ': ListRange (from + 1) to numberOfRegisters :: forall a n r . ( Finite a, KnownNat n, KnownRegisterSize r) => Natural numberOfRegisters = case regSize @r of @@ -175,13 +164,13 @@ highRegisterBits :: forall p n. (Finite p, KnownNat n) => Natural highRegisterBits = case getNatural @n `mod` maxBitsPerFieldElement @p of 0 -> maxBitsPerFieldElement @p m -> m + -- | The lowest possible number of registers to encode @n@ bits using Field elements from @p@ -- assuming that each register storest the largest possible number of bits. -- minNumberOfRegisters :: forall p n. (Finite p, KnownNat n) => Natural minNumberOfRegisters = (getNatural @n + maxBitsPerRegister @p @n -! 1) `div` maxBitsPerRegister @p @n - --------------------------------------------------------------- expansion :: MonadCircuit i a m => Natural -> i -> m [i] @@ -196,10 +185,13 @@ bitsOf :: MonadCircuit i a m => Natural -> i -> m [i] -- ^ @bitsOf n k@ creates @n@ bits and sets their witnesses equal to @n@ smaller -- bits of @k@. bitsOf n k = for [0 .. n -! 1] $ \j -> - newConstrained (\x i -> let xi = x i in xi * (xi - one)) ((!! j) . repr . ($ k)) + newConstrained (\x i -> let xi = x i in xi * (xi - one)) (repr j . ($ k)) where - repr :: forall b . (BinaryExpansion b, Bits b ~ [b], Finite b) => b -> [b] - repr = padBits (numberOfBits @b) . binaryExpansion + repr j = + fromConstant + . (`mod` fromConstant @Natural 2) + . (`div` fromConstant @Natural (2 ^ j)) + . toConstant horner :: MonadCircuit i a m => [i] -> m i -- ^ @horner [b0,...,bn]@ computes the sum @b0 + 2 b1 + ... + 2^n bn@ using @@ -213,15 +205,21 @@ splitExpansion :: (MonadCircuit i a m, Arithmetic a) => Natural -> Natural -> i -- @k = 2^n1 h + l@, @l@ fits in @n1@ bits and @h@ fits in n2 bits (if such -- values exist). splitExpansion n1 n2 k = do - let f x y = x + y + y - l <- newRanged (fromConstant $ (2 :: Natural) ^ n1 -! 1) $ foldr f zero . take n1 . repr . ($ k) - h <- newRanged (fromConstant $ (2 :: Natural) ^ n2 -! 1) $ foldr f zero . take n2 . drop n1 . repr . ($ k) + l <- newRanged (fromConstant @Natural $ 2 ^ n1 -! 1) $ lower . ($ k) + h <- newRanged (fromConstant @Natural $ 2 ^ n2 -! 1) $ upper . ($ k) constraint (\x -> x k - x l - scale (2 ^ n1 :: Natural) (x h)) return (l, h) where - repr :: forall b . (BinaryExpansion b, Bits b ~ [b]) => b -> [b] - repr = padBits (n1 + n2) . binaryExpansion - + lower :: WitnessField n a => a -> a + lower = + fromConstant . (`mod` fromConstant @Natural (2 ^ n1)) . toConstant + + upper :: WitnessField n a => a -> a + upper = + fromConstant + . (`mod` fromConstant @Natural (2 ^ n2)) + . (`div` fromConstant @Natural (2 ^ n1)) + . toConstant runInvert :: (MonadCircuit i a m, Z.Zip f, Traversable f) => f i -> m (f i, f i) runInvert is = do diff --git a/src/ZkFold/Symbolic/Data/FieldElement.hs b/src/ZkFold/Symbolic/Data/FieldElement.hs index bdb111393..ad423bb90 100644 --- a/src/ZkFold/Symbolic/Data/FieldElement.hs +++ b/src/ZkFold/Symbolic/Data/FieldElement.hs @@ -92,8 +92,9 @@ instance instance Symbolic c => BinaryExpansion (FieldElement c) where type Bits (FieldElement c) = c (Vector (NumberOfBits (BaseField c))) binaryExpansion (FieldElement c) = hmap unsafeToVector $ symbolicF c - (\(Par1 v) -> padBits (numberOfBits @(BaseField c)) $ binaryExpansion v) - (\(Par1 i) -> expansion (numberOfBits @(BaseField c)) i) + (padBits n . fmap fromConstant . binaryExpansion . toConstant . unPar1) + (expansion n . unPar1) + where n = numberOfBits @(BaseField c) fromBinary bits = FieldElement $ symbolicF bits (Par1 . foldr (\x y -> x + y + y) zero) $ fmap Par1 . horner . fromVector diff --git a/src/ZkFold/Symbolic/Data/Ord.hs b/src/ZkFold/Symbolic/Data/Ord.hs index 2fcfabd8d..0574ad53c 100644 --- a/src/ZkFold/Symbolic/Data/Ord.hs +++ b/src/ZkFold/Symbolic/Data/Ord.hs @@ -11,6 +11,7 @@ import Data.Data (Proxy (..)) import Data.Foldable (Foldable, toList) import Data.Function ((.)) import Data.Functor ((<$>)) +import Data.List (map) import qualified Data.Zip as Z import GHC.Generics (Par1 (..)) import Prelude (type (~), ($)) @@ -91,8 +92,10 @@ getBitsBE :: -- youngest. getBitsBE x = hmap unsafeToVector - $ symbolicF (pieces x Proxy) (binaryExpansion . V.item) - $ expansion (numberOfBits @(BaseField c)) . V.item + $ symbolicF (pieces x Proxy) + (map fromConstant . padBits n . binaryExpansion . toConstant . V.item) + (expansion n . V.item) + where n = numberOfBits @(BaseField c) bitwiseGE :: forall c f . (Symbolic c, Z.Zip f, Foldable f) => c f -> c f -> Bool c -- ^ Given two lists of bits of equal length, compares them lexicographically. diff --git a/src/ZkFold/Symbolic/MonadCircuit.hs b/src/ZkFold/Symbolic/MonadCircuit.hs index c6a224cd9..45935c733 100644 --- a/src/ZkFold/Symbolic/MonadCircuit.hs +++ b/src/ZkFold/Symbolic/MonadCircuit.hs @@ -17,7 +17,8 @@ import ZkFold.Base.Algebra.Basic.Class -- | A @'WitnessField'@ should support all algebraic operations -- used inside an arithmetic circuit. -type WitnessField a = (FiniteField a, BinaryExpansion a, Bits a ~ [a]) +type WitnessField n a = ( FiniteField a, ToConstant a, Const a ~ n + , FromConstant n a, EuclideanDomain n) -- | A type of witness builders. @i@ is a type of variables, @a@ is a base field. -- @@ -27,7 +28,7 @@ type WitnessField a = (FiniteField a, BinaryExpansion a, Bits a ~ [a]) -- -- NOTE: the property above is correct by construction for each function of a -- suitable type, you don't have to check it yourself. -type Witness i a = forall x . (Algebra a x, WitnessField x) => (i -> x) -> x +type Witness i a = forall x n . (Algebra a x, WitnessField n x) => (i -> x) -> x -- | A type of constraints for new variables. -- @i@ is a type of variables, @a@ is a base field. @@ -109,14 +110,14 @@ class Monad m => MonadCircuit i a m | m -> i, m -> a where -- | Field of witnesses with decidable equality and ordering -- is called an ``arithmetic'' field. -type Arithmetic a = (WitnessField a, ToConstant a, Const a ~ Natural, Eq a, Ord a) +type Arithmetic a = (WitnessField Natural a, Eq a, Ord a) -- | An example implementation of a @'MonadCircuit'@ which computes witnesses -- immediately and drops the constraints. -newtype Witnesses a x = Witnesses { runWitnesses :: x } +newtype Witnesses n a x = Witnesses { runWitnesses :: x } deriving (Functor, Applicative, Monad) via Identity -instance WitnessField a => MonadCircuit a a (Witnesses a) where +instance WitnessField n a => MonadCircuit a a (Witnesses n a) where newRanged _ w = return (w id) newConstrained _ w = return (w id) constraint _ = return () diff --git a/tests/Tests/ArithmeticCircuit.hs b/tests/Tests/ArithmeticCircuit.hs index 7510589dc..113ae853c 100644 --- a/tests/Tests/ArithmeticCircuit.hs +++ b/tests/Tests/ArithmeticCircuit.hs @@ -4,6 +4,7 @@ module Tests.ArithmeticCircuit (exec1, it, specArithmeticCircuit) where import Data.Bool (bool) +import Data.Functor ((<$>)) import GHC.Generics (U1 (..)) import Prelude (IO, Show, String, id, ($)) import qualified Prelude as Haskell @@ -54,7 +55,8 @@ specArithmeticCircuit' = hspec $ do -- in withMaxSuccess 1 $ checkClosedCircuit r .&&. exec1 r === one it "computes binary expansion" $ \(x :: a) -> let rs = binaryExpansion (fromConstant x :: FieldElement (ArithmeticCircuit a U1)) - in checkClosedCircuit rs .&&. V.fromVector (exec rs) === padBits (numberOfBits @a) (binaryExpansion x) + as = padBits (numberOfBits @a) $ fromConstant <$> binaryExpansion (toConstant x) + in checkClosedCircuit rs .&&. V.fromVector (exec rs) === as it "internalizes equality" $ \(x :: a) (y :: a) -> let Bool r = (fromConstant x :: FieldElement (ArithmeticCircuit a U1)) == fromConstant y in checkClosedCircuit @a r .&&. exec1 r === bool zero one (x Haskell.== y) From f653d949605f0cf9becbfb100f040ba5f3512855 Mon Sep 17 00:00:00 2001 From: TurtlePU Date: Fri, 6 Sep 2024 01:50:19 +0300 Subject: [PATCH 2/4] fix endianness bug in getBitsBE --- src/ZkFold/Symbolic/Data/Combinators.hs | 9 +++++---- src/ZkFold/Symbolic/Data/Ord.hs | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/ZkFold/Symbolic/Data/Combinators.hs b/src/ZkFold/Symbolic/Data/Combinators.hs index 5a6854fb5..29e1127d7 100644 --- a/src/ZkFold/Symbolic/Data/Combinators.hs +++ b/src/ZkFold/Symbolic/Data/Combinators.hs @@ -187,11 +187,12 @@ bitsOf :: MonadCircuit i a m => Natural -> i -> m [i] bitsOf n k = for [0 .. n -! 1] $ \j -> newConstrained (\x i -> let xi = x i in xi * (xi - one)) (repr j . ($ k)) where + repr :: WitnessField n x => Natural -> x -> x repr j = - fromConstant - . (`mod` fromConstant @Natural 2) - . (`div` fromConstant @Natural (2 ^ j)) - . toConstant + fromConstant + . (`mod` fromConstant @Natural 2) + . (`div` fromConstant @Natural (2 ^ j)) + . toConstant horner :: MonadCircuit i a m => [i] -> m i -- ^ @horner [b0,...,bn]@ computes the sum @b0 + 2 b1 + ... + 2^n bn@ using diff --git a/src/ZkFold/Symbolic/Data/Ord.hs b/src/ZkFold/Symbolic/Data/Ord.hs index 0574ad53c..7e4f8e300 100644 --- a/src/ZkFold/Symbolic/Data/Ord.hs +++ b/src/ZkFold/Symbolic/Data/Ord.hs @@ -91,9 +91,9 @@ getBitsBE :: -- ^ @getBitsBE x@ returns a list of circuits computing bits of @x@, eldest to -- youngest. getBitsBE x = - hmap unsafeToVector + hmap (V.reverse . unsafeToVector) $ symbolicF (pieces x Proxy) - (map fromConstant . padBits n . binaryExpansion . toConstant . V.item) + (padBits n . map fromConstant . binaryExpansion . toConstant . V.item) (expansion n . V.item) where n = numberOfBits @(BaseField c) From f846fcf526cb60c980ba0037510b276a3e2d7879 Mon Sep 17 00:00:00 2001 From: TurtlePU Date: Fri, 6 Sep 2024 21:36:35 +0300 Subject: [PATCH 3/4] fix Sources instances --- src/ZkFold/Base/Algebra/Basic/Sources.hs | 19 +++++++++---------- tests/Tests/ArithmeticCircuit.hs | 4 ++++ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/ZkFold/Base/Algebra/Basic/Sources.hs b/src/ZkFold/Base/Algebra/Basic/Sources.hs index d0569d3a6..8552f8b89 100644 --- a/src/ZkFold/Base/Algebra/Basic/Sources.hs +++ b/src/ZkFold/Base/Algebra/Basic/Sources.hs @@ -2,7 +2,7 @@ module ZkFold.Base.Algebra.Basic.Sources (Sources (..)) where -import Data.Function (const, id, (.)) +import Data.Function (const, id) import Data.Kind (Type) import Data.Maybe (Maybe (..)) import Data.Monoid (Monoid (..)) @@ -28,13 +28,15 @@ instance {-# OVERLAPPING #-} FromConstant (Sources a i) (Sources a i) where instance {-# OVERLAPPING #-} Ord i => Scale (Sources a i) (Sources a i) where scale = (<>) -instance Ord i => AdditiveSemigroup (Sources a i) where - (+) = (<>) +instance {-# OVERLAPPABLE #-} FromConstant c (Sources a i) where + fromConstant = const empty -instance {-# OVERLAPPABLE #-} - MultiplicativeMonoid c => Scale c (Sources a i) where +instance {-# OVERLAPPABLE #-} MultiplicativeMonoid c => Scale c (Sources a i) where scale = const id +instance Ord i => AdditiveSemigroup (Sources a i) where + (+) = (<>) + instance Ord i => AdditiveMonoid (Sources a i) where zero = mempty @@ -59,9 +61,6 @@ instance Exponent (Sources a i) Integer where instance Ord i => MultiplicativeGroup (Sources a i) where invert = id -instance {-# OVERLAPPABLE #-} FromConstant c (Sources a i) where - fromConstant = const empty - instance Ord i => Semiring (Sources a i) instance Ord i => Ring (Sources a i) @@ -71,8 +70,8 @@ instance Ord i => Field (Sources a i) where rootOfUnity _ = Just mempty instance ToConstant (Sources (a :: Type) i) where - type Const (Sources a i) = Sources (Const a) i - toConstant = Sources . runSources + type Const (Sources a i) = Sources a i + toConstant = id instance (Finite a, Ord i) => BinaryExpansion (Sources a i) where type Bits (Sources a i) = [Sources a i] diff --git a/tests/Tests/ArithmeticCircuit.hs b/tests/Tests/ArithmeticCircuit.hs index 113ae853c..6302bdac8 100644 --- a/tests/Tests/ArithmeticCircuit.hs +++ b/tests/Tests/ArithmeticCircuit.hs @@ -22,6 +22,7 @@ import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.Data.Bool import ZkFold.Symbolic.Data.Eq import ZkFold.Symbolic.Data.FieldElement +import ZkFold.Symbolic.Data.Ord ((<=)) correctHom0 :: forall a . (Arithmetic a, Scale a a, Show a) => (forall b . Field b => b) -> Property correctHom0 f = let r = fromFieldElement f in withMaxSuccess 1 $ checkClosedCircuit r .&&. exec1 r === f @a @@ -63,6 +64,9 @@ specArithmeticCircuit' = hspec $ do it "internal equality is reflexive" $ \(x :: a) -> let Bool r = (fromConstant x :: FieldElement (ArithmeticCircuit a U1)) == fromConstant x in checkClosedCircuit @a r .&&. exec1 r === one + it "<=s correctly" $ withMaxSuccess 10 $ \(x :: a) (y :: a) -> + let Bool r = (fromConstant x :: FieldElement (ArithmeticCircuit a U1)) <= fromConstant y + in checkClosedCircuit @a r .&&. exec1 r === bool zero one (x Haskell.<= y) specArithmeticCircuit :: IO () specArithmeticCircuit = do From 26396af46b0a8a8dd565b0397b830327b2710358 Mon Sep 17 00:00:00 2001 From: TurtlePU Date: Sat, 7 Sep 2024 00:06:46 +0300 Subject: [PATCH 4/4] EuclideanDomain -> SemiEuclidean --- src/ZkFold/Base/Algebra/Basic/Class.hs | 22 +++++++++++-------- src/ZkFold/Base/Algebra/Basic/Field.hs | 2 +- src/ZkFold/Base/Algebra/Basic/Sources.hs | 2 +- .../Symbolic/Algorithms/Hash/Blake2b.hs | 4 ++-- src/ZkFold/Symbolic/Data/UInt.hs | 4 ++-- src/ZkFold/Symbolic/MonadCircuit.hs | 2 +- 6 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/ZkFold/Base/Algebra/Basic/Class.hs b/src/ZkFold/Base/Algebra/Basic/Class.hs index 7f69b0809..953efc451 100644 --- a/src/ZkFold/Base/Algebra/Basic/Class.hs +++ b/src/ZkFold/Base/Algebra/Basic/Class.hs @@ -269,16 +269,20 @@ intScale n a | n < 0 = naturalFromInteger (-n) `scale` negate a -} class (AdditiveMonoid a, MultiplicativeMonoid a, FromConstant Natural a) => Semiring a -{- | A Euclidean domain @R@ is an integral domain which can be endowed -with at least one function @f : R\{0} -> R+@ s.t. -If @a@ and @b@ are in @R@ and @b@ is nonzero, then there exist @q@ and @r@ in @R@ such that -@a = bq + r@ and either @r = 0@ or @f(r) < f(b)@. +{- | A semi-Euclidean-domain @a@ is a semiring without zero divisors which can +be endowed with at least one function @f : a\{0} -> R+@ s.t. if @x@ and @y@ are +in @a@ and @y@ is nonzero, then there exist @q@ and @r@ in @a@ such that +@x = qy + r@ and either @r = 0@ or @f(r) < f(y)@. -@q@ and @r@ are called respectively a quotient and a remainder of the division (or Euclidean division) of @a@ by @b@. +@q@ and @r@ are called respectively a quotient and a remainder of the division +(or Euclidean division) of @x@ by @y@. -The function @divMod@ associated with this class produces @q@ and @r@ given @a@ and @b@. +The function @divMod@ associated with this class produces @q@ and @r@ +given @a@ and @b@. + +This is a generalization of a notion of Euclidean domains to semirings. -} -class Semiring a => EuclideanDomain a where +class Semiring a => SemiEuclidean a where {-# MINIMAL divMod | (div, mod) #-} divMod :: a -> a -> (a, a) @@ -479,7 +483,7 @@ instance AdditiveMonoid Natural where instance Semiring Natural -instance EuclideanDomain Natural where +instance SemiEuclidean Natural where divMod = Haskell.divMod instance BinaryExpansion Natural where @@ -517,7 +521,7 @@ instance FromConstant Natural Integer where instance Semiring Integer -instance EuclideanDomain Integer where +instance SemiEuclidean Integer where divMod = Haskell.divMod instance Ring Integer diff --git a/src/ZkFold/Base/Algebra/Basic/Field.hs b/src/ZkFold/Base/Algebra/Basic/Field.hs index 047ab7eec..5821fc214 100644 --- a/src/ZkFold/Base/Algebra/Basic/Field.hs +++ b/src/ZkFold/Base/Algebra/Basic/Field.hs @@ -95,7 +95,7 @@ instance KnownNat p => FromConstant Natural (Zp p) where instance KnownNat p => Semiring (Zp p) -instance KnownNat p => EuclideanDomain (Zp p) where +instance KnownNat p => SemiEuclidean (Zp p) where divMod a b = let (q, r) = Haskell.divMod (fromZp a) (fromZp b) in (toZp . fromIntegral $ q, toZp . fromIntegral $ r) diff --git a/src/ZkFold/Base/Algebra/Basic/Sources.hs b/src/ZkFold/Base/Algebra/Basic/Sources.hs index 8552f8b89..c52eea141 100644 --- a/src/ZkFold/Base/Algebra/Basic/Sources.hs +++ b/src/ZkFold/Base/Algebra/Basic/Sources.hs @@ -77,6 +77,6 @@ instance (Finite a, Ord i) => BinaryExpansion (Sources a i) where type Bits (Sources a i) = [Sources a i] binaryExpansion = replicate (numberOfBits @a) -instance Ord i => EuclideanDomain (Sources a i) where +instance Ord i => SemiEuclidean (Sources a i) where div = (<>) mod = (<>) diff --git a/src/ZkFold/Symbolic/Algorithms/Hash/Blake2b.hs b/src/ZkFold/Symbolic/Algorithms/Hash/Blake2b.hs index b4c9b1726..78ce6b2fc 100644 --- a/src/ZkFold/Symbolic/Algorithms/Hash/Blake2b.hs +++ b/src/ZkFold/Symbolic/Algorithms/Hash/Blake2b.hs @@ -15,8 +15,8 @@ import Prelude hiding (Num ( replicate, splitAt, truncate, (!!), (&&), (^)) import ZkFold.Base.Algebra.Basic.Class (AdditiveGroup (..), AdditiveSemigroup (..), - EuclideanDomain (..), Exponent (..), - FromConstant (..), MultiplicativeSemigroup (..), + Exponent (..), FromConstant (..), + MultiplicativeSemigroup (..), SemiEuclidean (..), divMod, one, zero, (-!)) import ZkFold.Base.Algebra.Basic.Number import ZkFold.Prelude (length, replicate, splitAt, (!!)) diff --git a/src/ZkFold/Symbolic/Data/UInt.hs b/src/ZkFold/Symbolic/Data/UInt.hs index 062005c16..e497877f1 100644 --- a/src/ZkFold/Symbolic/Data/UInt.hs +++ b/src/ZkFold/Symbolic/Data/UInt.hs @@ -102,7 +102,7 @@ cast n = eea :: forall n c r . Symbolic c - => EuclideanDomain (UInt n r c) + => SemiEuclidean (UInt n r c) => KnownNat n => KnownNat (NumberOfRegisters (BaseField c) n r) => AdditiveGroup (UInt n r c) @@ -200,7 +200,7 @@ instance , KnownRegisterSize rs , r ~ NumberOfRegisters (BaseField c) n rs , NFData (c (Vector r)) - ) => EuclideanDomain (UInt n rs c) where + ) => SemiEuclidean (UInt n rs c) where divMod numerator d = bool @(Bool c) (q, r) (zero, zero) (d == zero) where (q, r) = Haskell.foldl longDivisionStep (zero, zero) [value @n -! 1, value @n -! 2 .. 0] diff --git a/src/ZkFold/Symbolic/MonadCircuit.hs b/src/ZkFold/Symbolic/MonadCircuit.hs index 45935c733..ae79ad640 100644 --- a/src/ZkFold/Symbolic/MonadCircuit.hs +++ b/src/ZkFold/Symbolic/MonadCircuit.hs @@ -18,7 +18,7 @@ import ZkFold.Base.Algebra.Basic.Class -- | A @'WitnessField'@ should support all algebraic operations -- used inside an arithmetic circuit. type WitnessField n a = ( FiniteField a, ToConstant a, Const a ~ n - , FromConstant n a, EuclideanDomain n) + , FromConstant n a, SemiEuclidean n) -- | A type of witness builders. @i@ is a type of variables, @a@ is a base field. --