From 4a60f19f29847ed19f7291901299ab0f639326e5 Mon Sep 17 00:00:00 2001 From: TurtlePU Date: Tue, 10 Sep 2024 10:21:14 +0300 Subject: [PATCH] removed problematic Scale and FromConstant instances --- src/ZkFold/Base/Algebra/Basic/Class.hs | 105 ++++++++++-------- src/ZkFold/Base/Algebra/Basic/Field.hs | 10 +- src/ZkFold/Base/Algebra/Basic/Sources.hs | 2 +- .../Polynomials/Multivariate/Polynomial.hs | 8 +- .../Base/Algebra/Polynomials/Univariate.hs | 11 +- .../Base/Protocol/Plonkup/PlonkConstraint.hs | 4 +- src/ZkFold/Base/Protocol/Plonkup/Relation.hs | 1 - .../Base/Protocol/Protostar/Internal.hs | 2 +- .../Symbolic/Compiler/ArithmeticCircuit.hs | 2 - .../Compiler/ArithmeticCircuit/Instance.hs | 2 +- .../Compiler/ArithmeticCircuit/Internal.hs | 2 +- .../Compiler/ArithmeticCircuit/Map.hs | 2 +- src/ZkFold/Symbolic/Data/FFA.hs | 4 + src/ZkFold/Symbolic/Data/FieldElement.hs | 7 +- src/ZkFold/Symbolic/Data/UInt.hs | 5 +- tests/Tests/ArithmeticCircuit.hs | 8 +- tests/Tests/Arithmetization.hs | 6 +- tests/Tests/Arithmetization/Test1.hs | 6 +- tests/Tests/NonInteractiveProof/Plonkup.hs | 7 +- 19 files changed, 114 insertions(+), 80 deletions(-) diff --git a/src/ZkFold/Base/Algebra/Basic/Class.hs b/src/ZkFold/Base/Algebra/Basic/Class.hs index 953efc451..401e5c729 100644 --- a/src/ZkFold/Base/Algebra/Basic/Class.hs +++ b/src/ZkFold/Base/Algebra/Basic/Class.hs @@ -34,10 +34,11 @@ class FromConstant a b where -- -- [Homomorphism] @fromConstant (c + d) == fromConstant c + fromConstant d@ fromConstant :: a -> b - -instance FromConstant a a where + default fromConstant :: a ~ b => a -> b fromConstant = id +instance FromConstant a a + -- | A class of algebraic structures which can be converted to "constant type" -- related with it: natural numbers, integers, rationals etc. Subject to the -- following law: @@ -60,10 +61,40 @@ instance ToConstant Void where -------------------------------------------------------------------------------- -{- | A class of types with a binary associative operation with a multiplicative -feel to it. Not necessarily commutative. --} -class MultiplicativeSemigroup a where +-- | A class for actions where multiplicative notation is the most natural +-- (including multiplication by constant itself). +class Scale b a where + -- | A left monoid action on a type. Should satisfy the following: + -- + -- [Compatibility] @scale (c * d) a == scale c (scale d a)@ + -- [Left identity] @scale one a == a@ + -- + -- If, in addition, a cast from constant is defined, they should agree: + -- + -- [Scale agrees] @scale c a == fromConstant c * a@ + -- [Cast agrees] @fromConstant c == scale c one@ + -- + -- If the action is on an abelian structure, scaling should respect it: + -- + -- [Left distributivity] @scale c (a + b) == scale c a + scale c b@ + -- [Right absorption] @scale c zero == zero@ + -- + -- If, in addition, the scaling itself is abelian, this structure should + -- propagate: + -- + -- [Right distributivity] @scale (c + d) a == scale c a + scale d a@ + -- [Left absorption] @scale zero a == zero@ + -- + -- The default implementation is the multiplication by a constant. + scale :: b -> a -> a + default scale :: (FromConstant b a, MultiplicativeSemigroup a) => b -> a -> a + scale = (*) . fromConstant + +instance MultiplicativeSemigroup a => Scale a a + +-- | A class of types with a binary associative operation with a multiplicative +-- feel to it. Not necessarily commutative. +class (FromConstant a a, Scale a a) => MultiplicativeSemigroup a where -- | A binary associative operation. The following should hold: -- -- [Associativity] @x * (y * z) == (x * y) * z@ @@ -72,11 +103,12 @@ class MultiplicativeSemigroup a where product1 :: (Foldable t, MultiplicativeSemigroup a) => t a -> a product1 = foldl1 (*) -{- | A class for semigroup (and monoid) actions on types where exponential -notation is the most natural (including an exponentiation itself). --} -class MultiplicativeSemigroup b => Exponent a b where - -- | A right semigroup action on a type. The following should hold: +-- | A class for actions on types where exponential notation is the most natural +-- (including an exponentiation itself). +class Exponent a b where + -- | A right action on a type. + -- + -- If exponents form a semigroup, the following should hold: -- -- [Compatibility] @a ^ (m * n) == (a ^ m) ^ n@ -- @@ -130,41 +162,6 @@ product = foldl' (*) one multiExp :: (MultiplicativeMonoid a, Exponent a b, Foldable t) => a -> t b -> a multiExp a = foldl' (\x y -> x * (a ^ y)) one -{- | A class for monoid actions where multiplicative notation is the most -natural (including multiplication by constant itself). --} -class MultiplicativeMonoid b => Scale b a where - -- | A left monoid action on a type. Should satisfy the following: - -- - -- [Compatibility] @scale (c * d) a == scale c (scale d a)@ - -- [Left identity] @scale one a == a@ - -- - -- If, in addition, a cast from constant is defined, they should agree: - -- - -- [Scale agrees] @scale c a == fromConstant c * a@ - -- [Cast agrees] @fromConstant c == scale c one@ - -- - -- If the action is on an abelian structure, scaling should respect it: - -- - -- [Left distributivity] @scale c (a + b) == scale c a + scale c b@ - -- [Right absorption] @scale c zero == zero@ - -- - -- If, in addition, the scaling itself is abelian, this structure should - -- propagate: - -- - -- [Right distributivity] @scale (c + d) a == scale c a + scale d a@ - -- [Left absorption] @scale zero a == zero@ - -- - -- The default implementation is the multiplication by a constant. - scale :: b -> a -> a - default scale :: (FromConstant b a, MultiplicativeSemigroup a) => b -> a -> a - scale = (*) . fromConstant - -instance MultiplicativeMonoid a => Scale a a - -instance {-# OVERLAPPABLE #-} (Scale b a, Functor f) => Scale b (f a) where - scale = fmap . scale - {- | A class of groups in a multiplicative notation. While exponentiation by an integer is specified in a constraint, a default @@ -201,7 +198,7 @@ intPow a n | n < 0 = invert a ^ naturalFromInteger (-n) -------------------------------------------------------------------------------- -- | A class of types with a binary associative, commutative operation. -class AdditiveSemigroup a where +class FromConstant a a => AdditiveSemigroup a where -- | A binary associative commutative operation. The following should hold: -- -- [Associativity] @x + (y + z) == (x + y) + z@ @@ -628,6 +625,10 @@ instance MultiplicativeMonoid a => Exponent a Bool where -------------------------------------------------------------------------------- +instance {-# OVERLAPPING #-} FromConstant [a] [a] + +instance {-# OVERLAPPING #-} MultiplicativeSemigroup a => Scale [a] [a] + instance MultiplicativeSemigroup a => MultiplicativeSemigroup [a] where (*) = zipWith (*) @@ -643,6 +644,9 @@ instance MultiplicativeGroup a => MultiplicativeGroup [a] where instance AdditiveSemigroup a => AdditiveSemigroup [a] where (+) = zipWith (+) +instance Scale b a => Scale b [a] where + scale = map . scale + instance AdditiveMonoid a => AdditiveMonoid [a] where zero = repeat zero @@ -658,6 +662,10 @@ instance Ring a => Ring [a] -------------------------------------------------------------------------------- +instance {-# OVERLAPPING #-} FromConstant (p -> a) (p -> a) + +instance {-# OVERLAPPING #-} MultiplicativeSemigroup a => Scale (p -> a) (p -> a) + instance MultiplicativeSemigroup a => MultiplicativeSemigroup (p -> a) where p1 * p2 = \x -> p1 x * p2 x @@ -673,6 +681,9 @@ instance MultiplicativeGroup a => MultiplicativeGroup (p -> a) where instance AdditiveSemigroup a => AdditiveSemigroup (p -> a) where p1 + p2 = \x -> p1 x + p2 x +instance Scale b a => Scale b (p -> a) where + scale = (.) . scale + instance AdditiveMonoid a => AdditiveMonoid (p -> a) where zero = const zero diff --git a/src/ZkFold/Base/Algebra/Basic/Field.hs b/src/ZkFold/Base/Algebra/Basic/Field.hs index 5821fc214..3de4df7c1 100644 --- a/src/ZkFold/Base/Algebra/Basic/Field.hs +++ b/src/ZkFold/Base/Algebra/Basic/Field.hs @@ -194,7 +194,7 @@ instance KnownNat p => Random (Zp p) where -- -- Note that left distributivity is satisfied, meaning -- @a ^ (m + n) = (a ^ m) * (a ^ n)@. -instance (KnownNat p, MultiplicativeGroup a, Order a ~ p) => Exponent a (Zp p) where +instance (MultiplicativeGroup a, Order a ~ p) => Exponent a (Zp p) where a ^ n = a ^ fromZp n ----------------------------- Field Extensions -------------------------------- @@ -211,6 +211,8 @@ instance Ord f => Ord (Ext2 f e) where instance (KnownNat (Order (Ext2 f e)), KnownNat (NumberOfBits (Ext2 f e))) => Finite (Ext2 f e) where type Order (Ext2 f e) = Order f ^ 2 +instance {-# OVERLAPPING #-} FromConstant (Ext2 f e) (Ext2 f e) + instance Field f => AdditiveSemigroup (Ext2 f e) where Ext2 a b + Ext2 c d = Ext2 (a + c) (b + d) @@ -224,6 +226,8 @@ instance Field f => AdditiveGroup (Ext2 f e) where negate (Ext2 a b) = Ext2 (negate a) (negate b) Ext2 a b - Ext2 c d = Ext2 (a - c) (b - d) +instance {-# OVERLAPPING #-} (Field f, Eq f, IrreduciblePoly f e) => Scale (Ext2 f e) (Ext2 f e) + instance (Field f, Eq f, IrreduciblePoly f e) => MultiplicativeSemigroup (Ext2 f e) where Ext2 a b * Ext2 c d = fromConstant (toPoly [a, b] * toPoly [c, d]) @@ -275,6 +279,8 @@ instance Ord f => Ord (Ext3 f e) where instance (KnownNat (Order (Ext3 f e)), KnownNat (NumberOfBits (Ext3 f e))) => Finite (Ext3 f e) where type Order (Ext3 f e) = Order f ^ 3 +instance {-# OVERLAPPING #-} FromConstant (Ext3 f e) (Ext3 f e) + instance Field f => AdditiveSemigroup (Ext3 f e) where Ext3 a b c + Ext3 d e f = Ext3 (a + d) (b + e) (c + f) @@ -288,6 +294,8 @@ instance Field f => AdditiveGroup (Ext3 f e) where negate (Ext3 a b c) = Ext3 (negate a) (negate b) (negate c) Ext3 a b c - Ext3 d e f = Ext3 (a - d) (b - e) (c - f) +instance {-# OVERLAPPING #-} (Field f, Eq f, IrreduciblePoly f e) => Scale (Ext3 f e) (Ext3 f e) + instance (Field f, Eq f, IrreduciblePoly f e) => MultiplicativeSemigroup (Ext3 f e) where Ext3 a b c * Ext3 d e f = fromConstant (toPoly [a, b, c] * toPoly [d, e, f]) diff --git a/src/ZkFold/Base/Algebra/Basic/Sources.hs b/src/ZkFold/Base/Algebra/Basic/Sources.hs index c52eea141..8bc0666a4 100644 --- a/src/ZkFold/Base/Algebra/Basic/Sources.hs +++ b/src/ZkFold/Base/Algebra/Basic/Sources.hs @@ -31,7 +31,7 @@ instance {-# OVERLAPPING #-} Ord i => Scale (Sources a i) (Sources a i) where instance {-# OVERLAPPABLE #-} FromConstant c (Sources a i) where fromConstant = const empty -instance {-# OVERLAPPABLE #-} MultiplicativeMonoid c => Scale c (Sources a i) where +instance {-# OVERLAPPABLE #-} Scale c (Sources a i) where scale = const id instance Ord i => AdditiveSemigroup (Sources a i) where diff --git a/src/ZkFold/Base/Algebra/Polynomials/Multivariate/Polynomial.hs b/src/ZkFold/Base/Algebra/Polynomials/Multivariate/Polynomial.hs index 4d1dec0df..aa14d84b0 100644 --- a/src/ZkFold/Base/Algebra/Polynomials/Multivariate/Polynomial.hs +++ b/src/ZkFold/Base/Algebra/Polynomials/Multivariate/Polynomial.hs @@ -48,9 +48,7 @@ evalPolynomial -> b evalPolynomial e f (P p) = foldr (\(c, m) x -> x + scale c (e f m)) zero p -variables :: forall c v . - (Ord v, MultiplicativeMonoid c) => - Poly c v Natural -> Set v +variables :: forall c v . Ord v => Poly c v Natural -> Set v variables = runSources . evalPolynomial evalMonomial (Sources @c . singleton) mapVarPolynomial :: Variable i => Map i i-> Poly c i j -> Poly c i j @@ -83,6 +81,8 @@ instance Polynomial c i j => Ord (Poly c i j) where instance (Arbitrary c, Arbitrary (Mono i j)) => Arbitrary (Poly c i j) where arbitrary = P <$> arbitrary +instance {-# OVERLAPPING #-} FromConstant (Poly c i j) (Poly c i j) + instance Polynomial c i j => AdditiveSemigroup (Poly c i j) where P l + P r = P $ go l r where @@ -106,6 +106,8 @@ instance Polynomial c i j => AdditiveMonoid (Poly c i j) where instance Polynomial c i j => AdditiveGroup (Poly c i j) where negate (P p) = P $ map (first negate) p +instance {-# OVERLAPPING #-} Polynomial c i j => Scale (Poly c i j) (Poly c i j) + instance Polynomial c i j => MultiplicativeSemigroup (Poly c i j) where P l * r = foldl' (+) (P []) $ map (`scaleM` r) l diff --git a/src/ZkFold/Base/Algebra/Polynomials/Univariate.hs b/src/ZkFold/Base/Algebra/Polynomials/Univariate.hs index 2998fbd9d..59de483d0 100644 --- a/src/ZkFold/Base/Algebra/Polynomials/Univariate.hs +++ b/src/ZkFold/Base/Algebra/Polynomials/Univariate.hs @@ -62,6 +62,8 @@ toPoly = removeZeros . P fromPoly :: Poly c -> V.Vector c fromPoly (P cs) = cs +instance {-# OVERLAPPING #-} FromConstant (Poly c) (Poly c) + instance FromConstant c c' => FromConstant c (Poly c') where fromConstant = P . V.singleton . fromConstant @@ -73,6 +75,11 @@ instance (Ring c, Eq c) => AdditiveSemigroup (Poly c) where lPadded = l V.++ V.replicate (len P.- V.length l) zero rPadded = r V.++ V.replicate (len P.- V.length r) zero +instance {-# OVERLAPPING #-} (Field c, Eq c) => Scale (Poly c) (Poly c) + +instance Scale k c => Scale k (Poly c) where + scale = fmap . scale + instance (Ring c, Eq c) => AdditiveMonoid (Poly c) where zero = P V.empty @@ -269,7 +276,7 @@ vec2poly :: (Ring c, Eq c) => PolyVec c size -> Poly c vec2poly (PV cs) = removeZeros $ P cs instance Scale c' c => Scale c' (PolyVec c size) where - scale c (PV p) = PV (scale c p) + scale c (PV p) = PV (scale c <$> p) instance Ring c => AdditiveSemigroup (PolyVec c size) where PV l + PV r = PV $ V.zipWith (+) l r @@ -283,6 +290,8 @@ instance (Ring c, KnownNat size) => AdditiveGroup (PolyVec c size) where instance (Field c, KnownNat size, Eq c) => Exponent (PolyVec c size) Natural where (^) = natPow +instance {-# OVERLAPPING #-} (Field c, KnownNat size, Eq c) => Scale (PolyVec c size) (PolyVec c size) + -- TODO (Issue #18): check for overflow instance (Field c, KnownNat size, Eq c) => MultiplicativeSemigroup (PolyVec c size) where l * r = poly2vec $ vec2poly l * vec2poly r diff --git a/src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs b/src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs index 7d501928c..85b4107d0 100644 --- a/src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs +++ b/src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs @@ -54,7 +54,7 @@ instance (Arbitrary a, Finite a, ToConstant a, Const a ~ Natural, KnownNat i) => _ -> error "impossible" return $ PlonkConstraint qm ql qr qo qc x y z -toPlonkConstraint :: forall a i . (Eq a, FiniteField a, Scale a a, KnownNat i) => Poly a (Var (Vector i)) Natural -> PlonkConstraint i a +toPlonkConstraint :: forall a i . (Eq a, FiniteField a, KnownNat i) => Poly a (Var (Vector i)) Natural -> PlonkConstraint i a toPlonkConstraint p = let xs = map Just $ toList (variables p) perms = nubOrd $ map (take 3) $ permutations $ case length xs of @@ -86,7 +86,7 @@ toPlonkConstraint p = in head $ mapMaybe getCoefs perms -fromPlonkConstraint :: (Eq a, Scale a a, FromConstant a a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var (Vector i)) Natural +fromPlonkConstraint :: (Eq a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var (Vector i)) Natural fromPlonkConstraint (PlonkConstraint qm ql qr qo qc a b c) = let xvar = maybe zero var xa = xvar a diff --git a/src/ZkFold/Base/Protocol/Plonkup/Relation.hs b/src/ZkFold/Base/Protocol/Plonkup/Relation.hs index 70f7323b0..635faa23c 100644 --- a/src/ZkFold/Base/Protocol/Plonkup/Relation.hs +++ b/src/ZkFold/Base/Protocol/Plonkup/Relation.hs @@ -38,7 +38,6 @@ toPlonkRelation :: forall i n l a . => KnownNat (3 * n) => KnownNat l => Arithmetic a - => Scale a a => Vector l (Var (Vector i)) -> ArithmeticCircuit a (Vector i) Par1 -> Maybe (PlonkRelation n i a) diff --git a/src/ZkFold/Base/Protocol/Protostar/Internal.hs b/src/ZkFold/Base/Protocol/Protostar/Internal.hs index 4b66d2cbe..0bdc0332d 100644 --- a/src/ZkFold/Base/Protocol/Protostar/Internal.hs +++ b/src/ZkFold/Base/Protocol/Protostar/Internal.hs @@ -3,7 +3,7 @@ module ZkFold.Base.Protocol.Protostar.Internal where import Numeric.Natural (Natural) import Prelude (Eq, Integer, Ord, Show) -import ZkFold.Base.Algebra.Basic.Class (AdditiveGroup, AdditiveMonoid, AdditiveSemigroup, Scale) +import ZkFold.Base.Algebra.Basic.Class import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.Polynomials.Multivariate diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index 0c4a5594f..c6cf17c1c 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -135,7 +135,6 @@ acPrint ac = do checkClosedCircuit :: forall a n . Arithmetic a - => Scale a a => Show a => ArithmeticCircuit a U1 n -> Property @@ -149,7 +148,6 @@ checkClosedCircuit c = withMaxSuccess 1 $ conjoin [ testPoly p | p <- elems (acS checkCircuit :: Arbitrary (i a) => Arithmetic a - => Scale a a => Show a => Representable i => ArithmeticCircuit a i n diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs index 15c606129..9fddd91ce 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs @@ -39,7 +39,7 @@ instance arbitrary' :: forall a i . - (Arithmetic a, Arbitrary a, FromConstant a a) => + (Arithmetic a, Arbitrary a) => (Haskell.Ord (Rep i), Representable i, Haskell.Foldable i) => (ToConstant (Rep i), Const (Rep i) ~ Natural) => FieldElement (ArithmeticCircuit a i) -> Natural -> diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index c106f144c..bb8b65939 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -265,7 +265,7 @@ exec :: Functor o => ArithmeticCircuit a U1 o -> o a exec ac = eval ac U1 -- | Applies the values of the first couple of inputs to the arithmetic circuit. -apply :: (Eq a, Field a, Ord (Rep j), Scale a a, FromConstant a a, Representable i) => i a -> ArithmeticCircuit a (i :*: j) U1 -> ArithmeticCircuit a j U1 +apply :: (Eq a, Field a, Ord (Rep j), Representable i) => i a -> ArithmeticCircuit a (i :*: j) U1 -> ArithmeticCircuit a j U1 apply xs ac = ac { acSystem = fmap (evalPolynomial evalMonomial varF) (acSystem ac) , acWitness = fmap witF (acWitness ac) diff --git a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs index 2fa95ed7a..60f18ef6a 100644 --- a/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs +++ b/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs @@ -42,7 +42,7 @@ instance (Arithmetic a, Arbitrary (i a), Arbitrary (ArithmeticCircuit a i Par1), , witnessInput = wi } -mapVarArithmeticCircuit :: (Field a, Scale a a, Eq a, Functor o, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuitTest a i o -> ArithmeticCircuitTest a i o +mapVarArithmeticCircuit :: (Field a, Eq a, Functor o, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuitTest a i o -> ArithmeticCircuitTest a i o mapVarArithmeticCircuit (ArithmeticCircuitTest ac wi) = let vars = [v | NewVar v <- getAllVars ac] forward = Map.fromAscList $ zip vars [0..] diff --git a/src/ZkFold/Symbolic/Data/FFA.hs b/src/ZkFold/Symbolic/Data/FFA.hs index debc0c087..110b9fe8d 100644 --- a/src/ZkFold/Symbolic/Data/FFA.hs +++ b/src/ZkFold/Symbolic/Data/FFA.hs @@ -150,6 +150,10 @@ instance (FromConstant a (Zp p), Symbolic c) => FromConstant a (FFA p c) where impl :: Natural -> Vector Size (BaseField c) impl x = fromConstant . (x `mod`) <$> coprimes @(BaseField c) +instance {-# OVERLAPPING #-} FromConstant (FFA p c) (FFA p c) + +instance {-# OVERLAPPING #-} (KnownNat p, Symbolic c) => Scale (FFA p c) (FFA p c) + instance (KnownNat p, Symbolic c) => MultiplicativeSemigroup (FFA p c) where FFA x * FFA y = FFA $ symbolic2F x y (\u v -> fromZp (toZp u * toZp v :: Zp p)) (mul @p) diff --git a/src/ZkFold/Symbolic/Data/FieldElement.hs b/src/ZkFold/Symbolic/Data/FieldElement.hs index ad423bb90..d8fc7740a 100644 --- a/src/ZkFold/Symbolic/Data/FieldElement.hs +++ b/src/ZkFold/Symbolic/Data/FieldElement.hs @@ -49,11 +49,14 @@ instance Symbolic c => Exponent (FieldElement c) Natural where instance Symbolic c => Exponent (FieldElement c) Integer where (^) = intPowF -instance (Symbolic c, MultiplicativeMonoid k, Scale k (BaseField c)) => - Scale k (FieldElement c) where +instance (Symbolic c, Scale k (BaseField c)) => Scale k (FieldElement c) where scale k (FieldElement c) = FieldElement $ fromCircuitF c $ \(Par1 i) -> Par1 <$> newAssigned (\x -> fromConstant (scale k one :: BaseField c) * x i) +instance {-# OVERLAPPING #-} FromConstant (FieldElement c) (FieldElement c) + +instance {-# OVERLAPPING #-} Symbolic c => Scale (FieldElement c) (FieldElement c) + instance Symbolic c => MultiplicativeSemigroup (FieldElement c) where FieldElement x * FieldElement y = FieldElement $ fromCircuit2F x y $ \(Par1 i) (Par1 j) -> Par1 <$> newAssigned (\w -> w i * w j) diff --git a/src/ZkFold/Symbolic/Data/UInt.hs b/src/ZkFold/Symbolic/Data/UInt.hs index e497877f1..de9b01fc0 100644 --- a/src/ZkFold/Symbolic/Data/UInt.hs +++ b/src/ZkFold/Symbolic/Data/UInt.hs @@ -65,7 +65,9 @@ instance (Symbolic c, KnownNat n, KnownRegisterSize r) => FromConstant Natural ( instance (Symbolic c, KnownNat n, KnownRegisterSize r) => FromConstant Integer (UInt n r c) where fromConstant = fromConstant . naturalFromInteger . (`Haskell.mod` (2 ^ getNatural @n)) -instance (Symbolic c, KnownNat n, KnownRegisterSize r, FromConstant a (UInt n r c), MultiplicativeMonoid a) => Scale a (UInt n r c) +instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Scale Natural (UInt n r c) + +instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Scale Integer (UInt n r c) instance MultiplicativeMonoid (UInt n r c) => Exponent (UInt n r c) Natural where (^) = natPow @@ -238,7 +240,6 @@ instance (Symbolic c, KnownNat n, KnownRegisterSize r) => Ord (Bool c) (UInt n r min x y = bool @(Bool c) x y $ x > y - instance (Symbolic c, KnownNat n, KnownRegisterSize r) => AdditiveSemigroup (UInt n r c) where UInt xc + UInt yc = UInt $ symbolic2F xc yc (\u v -> naturalToVector @c @n @r $ vectorToNatural u (registerSize @(BaseField c) @n @r) + vectorToNatural v (registerSize @(BaseField c) @n @r)) solve where diff --git a/tests/Tests/ArithmeticCircuit.hs b/tests/Tests/ArithmeticCircuit.hs index 6302bdac8..62400cf20 100644 --- a/tests/Tests/ArithmeticCircuit.hs +++ b/tests/Tests/ArithmeticCircuit.hs @@ -24,20 +24,20 @@ import ZkFold.Symbolic.Data.Eq import ZkFold.Symbolic.Data.FieldElement import ZkFold.Symbolic.Data.Ord ((<=)) -correctHom0 :: forall a . (Arithmetic a, Scale a a, Show a) => (forall b . Field b => b) -> Property +correctHom0 :: forall a . (Arithmetic a, Show a) => (forall b . Field b => b) -> Property correctHom0 f = let r = fromFieldElement f in withMaxSuccess 1 $ checkClosedCircuit r .&&. exec1 r === f @a -correctHom1 :: forall a . (Arithmetic a, Scale a a, Show a, FromConstant a (FieldElement (ArithmeticCircuit a U1))) => (forall b . Field b => b -> b) -> a -> Property +correctHom1 :: forall a . (Arithmetic a, Show a) => (forall b . Field b => b -> b) -> a -> Property correctHom1 f x = let r = fromFieldElement $ f (fromConstant x) in checkClosedCircuit r .&&. exec1 r === f x -correctHom2 :: forall a . (Arithmetic a, Scale a a, Show a, FromConstant a (FieldElement (ArithmeticCircuit a U1))) => (forall b . Field b => b -> b -> b) -> a -> a -> Property +correctHom2 :: forall a . (Arithmetic a, Show a) => (forall b . Field b => b -> b -> b) -> a -> a -> Property correctHom2 f x y = let r = fromFieldElement $ f (fromConstant x) (fromConstant y) in checkClosedCircuit r .&&. exec1 r === f x y it :: Testable prop => String -> prop -> Spec it desc prop = Test.Hspec.it desc (property prop) -specArithmeticCircuit' :: forall a . (Arbitrary a, Arithmetic a, Scale a a, Show a, FromConstant a (FieldElement (ArithmeticCircuit a U1))) => IO () +specArithmeticCircuit' :: forall a . (Arbitrary a, Arithmetic a, Show a) => IO () specArithmeticCircuit' = hspec $ do describe "ArithmeticCircuit specification" $ do it "embeds constants" $ correctHom1 @a id diff --git a/tests/Tests/Arithmetization.hs b/tests/Tests/Arithmetization.hs index eddbec02f..73501d4ab 100644 --- a/tests/Tests/Arithmetization.hs +++ b/tests/Tests/Arithmetization.hs @@ -14,7 +14,7 @@ import Tests.Arithmetization.Test2 (specArithmetizatio import Tests.Arithmetization.Test3 (specArithmetization3) import Tests.Arithmetization.Test4 (specArithmetization4) -import ZkFold.Base.Algebra.Basic.Class (FromConstant, Scale, ToConstant (..)) +import ZkFold.Base.Algebra.Basic.Class (ToConstant (..)) import ZkFold.Base.Algebra.Basic.Field (Zp) import ZkFold.Base.Algebra.Basic.Number (Natural) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 @@ -22,7 +22,7 @@ import ZkFold.Base.Data.Vector (Vector) import ZkFold.Symbolic.Compiler import ZkFold.Symbolic.MonadCircuit (Arithmetic) -propCircuitInvariance :: (Arithmetic a, Scale a a, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuitTest a i Par1 -> Bool +propCircuitInvariance :: (Arithmetic a, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuitTest a i Par1 -> Bool propCircuitInvariance act@(ArithmeticCircuitTest ac wi) = let ArithmeticCircuitTest ac' wi' = mapVarArithmeticCircuit act v = ac `eval` wi @@ -31,7 +31,7 @@ propCircuitInvariance act@(ArithmeticCircuitTest ac wi) = specArithmetization' :: forall a i . - (FromConstant a a, Scale a a, Arithmetic a, Arbitrary a, Arbitrary (i a)) => + (Arithmetic a, Arbitrary a, Arbitrary (i a)) => (Show a, Show (ArithmeticCircuitTest a i Par1)) => (Arbitrary (Rep i), Ord (Rep i), Representable i, Traversable i) => (ToConstant (Rep i), Const (Rep i) ~ Natural) => IO () diff --git a/tests/Tests/Arithmetization/Test1.hs b/tests/Tests/Arithmetization/Test1.hs index f99663cd9..45e153eb5 100644 --- a/tests/Tests/Arithmetization/Test1.hs +++ b/tests/Tests/Arithmetization/Test1.hs @@ -22,7 +22,7 @@ import ZkFold.Symbolic.Interpreter (Interpreter) import ZkFold.Symbolic.MonadCircuit (Arithmetic) -- f x y = if (2 / x > y) then (x ^ 2 + 3 * x + 5) else (4 * x ^ 3) -testFunc :: forall c . (Symbolic c, Field (FieldElement c)) => FieldElement c -> FieldElement c -> FieldElement c +testFunc :: forall c . Symbolic c => FieldElement c -> FieldElement c -> FieldElement c testFunc x y = let c = fromConstant @Integer @(FieldElement c) g1 = x ^ (2 :: Natural) + c 3 * x + c 5 @@ -30,11 +30,11 @@ testFunc x y = g3 = c 2 // x in (g3 == y :: Bool c) ? g1 $ g2 -testResult :: forall a . (FromConstant a a, Arithmetic a) => ArithmeticCircuit a (Vector 2) Par1 -> a -> a -> Haskell.Bool +testResult :: forall a . Arithmetic a => ArithmeticCircuit a (Vector 2) Par1 -> a -> a -> Haskell.Bool testResult r x y = fromConstant (unPar1 $ eval r (unsafeToVector [x, y])) Haskell.== testFunc @(Interpreter a) (fromConstant x) (fromConstant y) -specArithmetization1 :: forall a . (FromConstant a a, Arithmetic a, Arbitrary a, Show a) => Spec +specArithmetization1 :: forall a . (Arithmetic a, Arbitrary a, Show a) => Spec specArithmetization1 = do describe "Arithmetization test 1" $ do it "should pass" $ do diff --git a/tests/Tests/NonInteractiveProof/Plonkup.hs b/tests/Tests/NonInteractiveProof/Plonkup.hs index 48ab329cd..97c324876 100644 --- a/tests/Tests/NonInteractiveProof/Plonkup.hs +++ b/tests/Tests/NonInteractiveProof/Plonkup.hs @@ -17,9 +17,8 @@ import Test.Hspec import Test.QuickCheck import ZkFold.Base.Algebra.Basic.Class (AdditiveGroup (..), AdditiveSemigroup (..), - FiniteField, FromConstant (..), - MultiplicativeSemigroup (..), Scale (..), negate, - zero, (-!)) + FiniteField, MultiplicativeSemigroup (..), zero, + (-!)) import ZkFold.Base.Algebra.Basic.Number (KnownNat, value) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS12_381_G2) import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (..)) @@ -37,7 +36,7 @@ type PlonkPolyLengthBS = 32 type PlonkBS n = Plonk 1 PlonkPolyLengthBS n BLS12_381_G1 BLS12_381_G2 ByteString type PlonkPolyExtendedLengthBS = PlonkPolyExtendedLength PlonkPolyLengthBS -propPlonkConstraintConversion :: (Eq a, Scale a a, FromConstant a a, FiniteField a) => PlonkConstraint 1 a -> Bool +propPlonkConstraintConversion :: (Eq a, FiniteField a) => PlonkConstraint 1 a -> Bool propPlonkConstraintConversion p = toPlonkConstraint (fromPlonkConstraint p) == p