Skip to content

Commit

Permalink
Merge pull request #246 from zkFold/TurtlePU/cleanup-FromConstant-Scale
Browse files Browse the repository at this point in the history
Cleanup `FromConstant` & `Scale`
  • Loading branch information
vlasin authored Sep 11, 2024
2 parents 420da6a + 4a60f19 commit 1ee3266
Show file tree
Hide file tree
Showing 19 changed files with 114 additions and 80 deletions.
105 changes: 58 additions & 47 deletions src/ZkFold/Base/Algebra/Basic/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ class FromConstant a b where
--
-- [Homomorphism] @fromConstant (c + d) == fromConstant c + fromConstant d@
fromConstant :: a -> b

instance FromConstant a a where
default fromConstant :: a ~ b => a -> b
fromConstant = id

instance FromConstant a a

-- | A class of algebraic structures which can be converted to "constant type"
-- related with it: natural numbers, integers, rationals etc. Subject to the
-- following law:
Expand All @@ -60,10 +61,40 @@ instance ToConstant Void where

--------------------------------------------------------------------------------

{- | A class of types with a binary associative operation with a multiplicative
feel to it. Not necessarily commutative.
-}
class MultiplicativeSemigroup a where
-- | A class for actions where multiplicative notation is the most natural
-- (including multiplication by constant itself).
class Scale b a where
-- | A left monoid action on a type. Should satisfy the following:
--
-- [Compatibility] @scale (c * d) a == scale c (scale d a)@
-- [Left identity] @scale one a == a@
--
-- If, in addition, a cast from constant is defined, they should agree:
--
-- [Scale agrees] @scale c a == fromConstant c * a@
-- [Cast agrees] @fromConstant c == scale c one@
--
-- If the action is on an abelian structure, scaling should respect it:
--
-- [Left distributivity] @scale c (a + b) == scale c a + scale c b@
-- [Right absorption] @scale c zero == zero@
--
-- If, in addition, the scaling itself is abelian, this structure should
-- propagate:
--
-- [Right distributivity] @scale (c + d) a == scale c a + scale d a@
-- [Left absorption] @scale zero a == zero@
--
-- The default implementation is the multiplication by a constant.
scale :: b -> a -> a
default scale :: (FromConstant b a, MultiplicativeSemigroup a) => b -> a -> a
scale = (*) . fromConstant

instance MultiplicativeSemigroup a => Scale a a

-- | A class of types with a binary associative operation with a multiplicative
-- feel to it. Not necessarily commutative.
class (FromConstant a a, Scale a a) => MultiplicativeSemigroup a where
-- | A binary associative operation. The following should hold:
--
-- [Associativity] @x * (y * z) == (x * y) * z@
Expand All @@ -72,11 +103,12 @@ class MultiplicativeSemigroup a where
product1 :: (Foldable t, MultiplicativeSemigroup a) => t a -> a
product1 = foldl1 (*)

{- | A class for semigroup (and monoid) actions on types where exponential
notation is the most natural (including an exponentiation itself).
-}
class MultiplicativeSemigroup b => Exponent a b where
-- | A right semigroup action on a type. The following should hold:
-- | A class for actions on types where exponential notation is the most natural
-- (including an exponentiation itself).
class Exponent a b where
-- | A right action on a type.
--
-- If exponents form a semigroup, the following should hold:
--
-- [Compatibility] @a ^ (m * n) == (a ^ m) ^ n@
--
Expand Down Expand Up @@ -130,41 +162,6 @@ product = foldl' (*) one
multiExp :: (MultiplicativeMonoid a, Exponent a b, Foldable t) => a -> t b -> a
multiExp a = foldl' (\x y -> x * (a ^ y)) one

{- | A class for monoid actions where multiplicative notation is the most
natural (including multiplication by constant itself).
-}
class MultiplicativeMonoid b => Scale b a where
-- | A left monoid action on a type. Should satisfy the following:
--
-- [Compatibility] @scale (c * d) a == scale c (scale d a)@
-- [Left identity] @scale one a == a@
--
-- If, in addition, a cast from constant is defined, they should agree:
--
-- [Scale agrees] @scale c a == fromConstant c * a@
-- [Cast agrees] @fromConstant c == scale c one@
--
-- If the action is on an abelian structure, scaling should respect it:
--
-- [Left distributivity] @scale c (a + b) == scale c a + scale c b@
-- [Right absorption] @scale c zero == zero@
--
-- If, in addition, the scaling itself is abelian, this structure should
-- propagate:
--
-- [Right distributivity] @scale (c + d) a == scale c a + scale d a@
-- [Left absorption] @scale zero a == zero@
--
-- The default implementation is the multiplication by a constant.
scale :: b -> a -> a
default scale :: (FromConstant b a, MultiplicativeSemigroup a) => b -> a -> a
scale = (*) . fromConstant

instance MultiplicativeMonoid a => Scale a a

instance {-# OVERLAPPABLE #-} (Scale b a, Functor f) => Scale b (f a) where
scale = fmap . scale

{- | A class of groups in a multiplicative notation.
While exponentiation by an integer is specified in a constraint, a default
Expand Down Expand Up @@ -201,7 +198,7 @@ intPow a n | n < 0 = invert a ^ naturalFromInteger (-n)
--------------------------------------------------------------------------------

-- | A class of types with a binary associative, commutative operation.
class AdditiveSemigroup a where
class FromConstant a a => AdditiveSemigroup a where
-- | A binary associative commutative operation. The following should hold:
--
-- [Associativity] @x + (y + z) == (x + y) + z@
Expand Down Expand Up @@ -628,6 +625,10 @@ instance MultiplicativeMonoid a => Exponent a Bool where

--------------------------------------------------------------------------------

instance {-# OVERLAPPING #-} FromConstant [a] [a]

instance {-# OVERLAPPING #-} MultiplicativeSemigroup a => Scale [a] [a]

instance MultiplicativeSemigroup a => MultiplicativeSemigroup [a] where
(*) = zipWith (*)

Expand All @@ -643,6 +644,9 @@ instance MultiplicativeGroup a => MultiplicativeGroup [a] where
instance AdditiveSemigroup a => AdditiveSemigroup [a] where
(+) = zipWith (+)

instance Scale b a => Scale b [a] where
scale = map . scale

instance AdditiveMonoid a => AdditiveMonoid [a] where
zero = repeat zero

Expand All @@ -658,6 +662,10 @@ instance Ring a => Ring [a]

--------------------------------------------------------------------------------

instance {-# OVERLAPPING #-} FromConstant (p -> a) (p -> a)

instance {-# OVERLAPPING #-} MultiplicativeSemigroup a => Scale (p -> a) (p -> a)

instance MultiplicativeSemigroup a => MultiplicativeSemigroup (p -> a) where
p1 * p2 = \x -> p1 x * p2 x

Expand All @@ -673,6 +681,9 @@ instance MultiplicativeGroup a => MultiplicativeGroup (p -> a) where
instance AdditiveSemigroup a => AdditiveSemigroup (p -> a) where
p1 + p2 = \x -> p1 x + p2 x

instance Scale b a => Scale b (p -> a) where
scale = (.) . scale

instance AdditiveMonoid a => AdditiveMonoid (p -> a) where
zero = const zero

Expand Down
10 changes: 9 additions & 1 deletion src/ZkFold/Base/Algebra/Basic/Field.hs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ instance KnownNat p => Random (Zp p) where
--
-- Note that left distributivity is satisfied, meaning
-- @a ^ (m + n) = (a ^ m) * (a ^ n)@.
instance (KnownNat p, MultiplicativeGroup a, Order a ~ p) => Exponent a (Zp p) where
instance (MultiplicativeGroup a, Order a ~ p) => Exponent a (Zp p) where
a ^ n = a ^ fromZp n

----------------------------- Field Extensions --------------------------------
Expand All @@ -211,6 +211,8 @@ instance Ord f => Ord (Ext2 f e) where
instance (KnownNat (Order (Ext2 f e)), KnownNat (NumberOfBits (Ext2 f e))) => Finite (Ext2 f e) where
type Order (Ext2 f e) = Order f ^ 2

instance {-# OVERLAPPING #-} FromConstant (Ext2 f e) (Ext2 f e)

instance Field f => AdditiveSemigroup (Ext2 f e) where
Ext2 a b + Ext2 c d = Ext2 (a + c) (b + d)

Expand All @@ -224,6 +226,8 @@ instance Field f => AdditiveGroup (Ext2 f e) where
negate (Ext2 a b) = Ext2 (negate a) (negate b)
Ext2 a b - Ext2 c d = Ext2 (a - c) (b - d)

instance {-# OVERLAPPING #-} (Field f, Eq f, IrreduciblePoly f e) => Scale (Ext2 f e) (Ext2 f e)

instance (Field f, Eq f, IrreduciblePoly f e) => MultiplicativeSemigroup (Ext2 f e) where
Ext2 a b * Ext2 c d = fromConstant (toPoly [a, b] * toPoly [c, d])

Expand Down Expand Up @@ -275,6 +279,8 @@ instance Ord f => Ord (Ext3 f e) where
instance (KnownNat (Order (Ext3 f e)), KnownNat (NumberOfBits (Ext3 f e))) => Finite (Ext3 f e) where
type Order (Ext3 f e) = Order f ^ 3

instance {-# OVERLAPPING #-} FromConstant (Ext3 f e) (Ext3 f e)

instance Field f => AdditiveSemigroup (Ext3 f e) where
Ext3 a b c + Ext3 d e f = Ext3 (a + d) (b + e) (c + f)

Expand All @@ -288,6 +294,8 @@ instance Field f => AdditiveGroup (Ext3 f e) where
negate (Ext3 a b c) = Ext3 (negate a) (negate b) (negate c)
Ext3 a b c - Ext3 d e f = Ext3 (a - d) (b - e) (c - f)

instance {-# OVERLAPPING #-} (Field f, Eq f, IrreduciblePoly f e) => Scale (Ext3 f e) (Ext3 f e)

instance (Field f, Eq f, IrreduciblePoly f e) => MultiplicativeSemigroup (Ext3 f e) where
Ext3 a b c * Ext3 d e f = fromConstant (toPoly [a, b, c] * toPoly [d, e, f])

Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Base/Algebra/Basic/Sources.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ instance {-# OVERLAPPING #-} Ord i => Scale (Sources a i) (Sources a i) where
instance {-# OVERLAPPABLE #-} FromConstant c (Sources a i) where
fromConstant = const empty

instance {-# OVERLAPPABLE #-} MultiplicativeMonoid c => Scale c (Sources a i) where
instance {-# OVERLAPPABLE #-} Scale c (Sources a i) where
scale = const id

instance Ord i => AdditiveSemigroup (Sources a i) where
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ evalPolynomial
-> b
evalPolynomial e f (P p) = foldr (\(c, m) x -> x + scale c (e f m)) zero p

variables :: forall c v .
(Ord v, MultiplicativeMonoid c) =>
Poly c v Natural -> Set v
variables :: forall c v . Ord v => Poly c v Natural -> Set v
variables = runSources . evalPolynomial evalMonomial (Sources @c . singleton)

mapVarPolynomial :: Variable i => Map i i-> Poly c i j -> Poly c i j
Expand Down Expand Up @@ -83,6 +81,8 @@ instance Polynomial c i j => Ord (Poly c i j) where
instance (Arbitrary c, Arbitrary (Mono i j)) => Arbitrary (Poly c i j) where
arbitrary = P <$> arbitrary

instance {-# OVERLAPPING #-} FromConstant (Poly c i j) (Poly c i j)

instance Polynomial c i j => AdditiveSemigroup (Poly c i j) where
P l + P r = P $ go l r
where
Expand All @@ -106,6 +106,8 @@ instance Polynomial c i j => AdditiveMonoid (Poly c i j) where
instance Polynomial c i j => AdditiveGroup (Poly c i j) where
negate (P p) = P $ map (first negate) p

instance {-# OVERLAPPING #-} Polynomial c i j => Scale (Poly c i j) (Poly c i j)

instance Polynomial c i j => MultiplicativeSemigroup (Poly c i j) where
P l * r = foldl' (+) (P []) $ map (`scaleM` r) l

Expand Down
11 changes: 10 additions & 1 deletion src/ZkFold/Base/Algebra/Polynomials/Univariate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ toPoly = removeZeros . P
fromPoly :: Poly c -> V.Vector c
fromPoly (P cs) = cs

instance {-# OVERLAPPING #-} FromConstant (Poly c) (Poly c)

instance FromConstant c c' => FromConstant c (Poly c') where
fromConstant = P . V.singleton . fromConstant

Expand All @@ -73,6 +75,11 @@ instance (Ring c, Eq c) => AdditiveSemigroup (Poly c) where
lPadded = l V.++ V.replicate (len P.- V.length l) zero
rPadded = r V.++ V.replicate (len P.- V.length r) zero

instance {-# OVERLAPPING #-} (Field c, Eq c) => Scale (Poly c) (Poly c)

instance Scale k c => Scale k (Poly c) where
scale = fmap . scale

instance (Ring c, Eq c) => AdditiveMonoid (Poly c) where
zero = P V.empty

Expand Down Expand Up @@ -269,7 +276,7 @@ vec2poly :: (Ring c, Eq c) => PolyVec c size -> Poly c
vec2poly (PV cs) = removeZeros $ P cs

instance Scale c' c => Scale c' (PolyVec c size) where
scale c (PV p) = PV (scale c p)
scale c (PV p) = PV (scale c <$> p)

instance Ring c => AdditiveSemigroup (PolyVec c size) where
PV l + PV r = PV $ V.zipWith (+) l r
Expand All @@ -283,6 +290,8 @@ instance (Ring c, KnownNat size) => AdditiveGroup (PolyVec c size) where
instance (Field c, KnownNat size, Eq c) => Exponent (PolyVec c size) Natural where
(^) = natPow

instance {-# OVERLAPPING #-} (Field c, KnownNat size, Eq c) => Scale (PolyVec c size) (PolyVec c size)

-- TODO (Issue #18): check for overflow
instance (Field c, KnownNat size, Eq c) => MultiplicativeSemigroup (PolyVec c size) where
l * r = poly2vec $ vec2poly l * vec2poly r
Expand Down
4 changes: 2 additions & 2 deletions src/ZkFold/Base/Protocol/Plonkup/PlonkConstraint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ instance (Arbitrary a, Finite a, ToConstant a, Const a ~ Natural, KnownNat i) =>
_ -> error "impossible"
return $ PlonkConstraint qm ql qr qo qc x y z

toPlonkConstraint :: forall a i . (Eq a, FiniteField a, Scale a a, KnownNat i) => Poly a (Var (Vector i)) Natural -> PlonkConstraint i a
toPlonkConstraint :: forall a i . (Eq a, FiniteField a, KnownNat i) => Poly a (Var (Vector i)) Natural -> PlonkConstraint i a
toPlonkConstraint p =
let xs = map Just $ toList (variables p)
perms = nubOrd $ map (take 3) $ permutations $ case length xs of
Expand Down Expand Up @@ -86,7 +86,7 @@ toPlonkConstraint p =

in head $ mapMaybe getCoefs perms

fromPlonkConstraint :: (Eq a, Scale a a, FromConstant a a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var (Vector i)) Natural
fromPlonkConstraint :: (Eq a, Field a, KnownNat i) => PlonkConstraint i a -> Poly a (Var (Vector i)) Natural
fromPlonkConstraint (PlonkConstraint qm ql qr qo qc a b c) =
let xvar = maybe zero var
xa = xvar a
Expand Down
1 change: 0 additions & 1 deletion src/ZkFold/Base/Protocol/Plonkup/Relation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ toPlonkRelation :: forall i n l a .
=> KnownNat (3 * n)
=> KnownNat l
=> Arithmetic a
=> Scale a a
=> Vector l (Var (Vector i))
-> ArithmeticCircuit a (Vector i) Par1
-> Maybe (PlonkRelation n i a)
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Base/Protocol/Protostar/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module ZkFold.Base.Protocol.Protostar.Internal where
import Numeric.Natural (Natural)
import Prelude (Eq, Integer, Ord, Show)

import ZkFold.Base.Algebra.Basic.Class (AdditiveGroup, AdditiveMonoid, AdditiveSemigroup, Scale)
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Field (Zp)
import ZkFold.Base.Algebra.Polynomials.Multivariate

Expand Down
2 changes: 0 additions & 2 deletions src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ acPrint ac = do
checkClosedCircuit
:: forall a n
. Arithmetic a
=> Scale a a
=> Show a
=> ArithmeticCircuit a U1 n
-> Property
Expand All @@ -149,7 +148,6 @@ checkClosedCircuit c = withMaxSuccess 1 $ conjoin [ testPoly p | p <- elems (acS
checkCircuit
:: Arbitrary (i a)
=> Arithmetic a
=> Scale a a
=> Show a
=> Representable i
=> ArithmeticCircuit a i n
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ instance

arbitrary' ::
forall a i .
(Arithmetic a, Arbitrary a, FromConstant a a) =>
(Arithmetic a, Arbitrary a) =>
(Haskell.Ord (Rep i), Representable i, Haskell.Foldable i) =>
(ToConstant (Rep i), Const (Rep i) ~ Natural) =>
FieldElement (ArithmeticCircuit a i) -> Natural ->
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ exec :: Functor o => ArithmeticCircuit a U1 o -> o a
exec ac = eval ac U1

-- | Applies the values of the first couple of inputs to the arithmetic circuit.
apply :: (Eq a, Field a, Ord (Rep j), Scale a a, FromConstant a a, Representable i) => i a -> ArithmeticCircuit a (i :*: j) U1 -> ArithmeticCircuit a j U1
apply :: (Eq a, Field a, Ord (Rep j), Representable i) => i a -> ArithmeticCircuit a (i :*: j) U1 -> ArithmeticCircuit a j U1
apply xs ac = ac
{ acSystem = fmap (evalPolynomial evalMonomial varF) (acSystem ac)
, acWitness = fmap witF (acWitness ac)
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Map.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ instance (Arithmetic a, Arbitrary (i a), Arbitrary (ArithmeticCircuit a i Par1),
, witnessInput = wi
}

mapVarArithmeticCircuit :: (Field a, Scale a a, Eq a, Functor o, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuitTest a i o -> ArithmeticCircuitTest a i o
mapVarArithmeticCircuit :: (Field a, Eq a, Functor o, Ord (Rep i), Representable i, Foldable i) => ArithmeticCircuitTest a i o -> ArithmeticCircuitTest a i o
mapVarArithmeticCircuit (ArithmeticCircuitTest ac wi) =
let vars = [v | NewVar v <- getAllVars ac]
forward = Map.fromAscList $ zip vars [0..]
Expand Down
4 changes: 4 additions & 0 deletions src/ZkFold/Symbolic/Data/FFA.hs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ instance (FromConstant a (Zp p), Symbolic c) => FromConstant a (FFA p c) where
impl :: Natural -> Vector Size (BaseField c)
impl x = fromConstant . (x `mod`) <$> coprimes @(BaseField c)

instance {-# OVERLAPPING #-} FromConstant (FFA p c) (FFA p c)

instance {-# OVERLAPPING #-} (KnownNat p, Symbolic c) => Scale (FFA p c) (FFA p c)

instance (KnownNat p, Symbolic c) => MultiplicativeSemigroup (FFA p c) where
FFA x * FFA y =
FFA $ symbolic2F x y (\u v -> fromZp (toZp u * toZp v :: Zp p)) (mul @p)
Expand Down
7 changes: 5 additions & 2 deletions src/ZkFold/Symbolic/Data/FieldElement.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,14 @@ instance Symbolic c => Exponent (FieldElement c) Natural where
instance Symbolic c => Exponent (FieldElement c) Integer where
(^) = intPowF

instance (Symbolic c, MultiplicativeMonoid k, Scale k (BaseField c)) =>
Scale k (FieldElement c) where
instance (Symbolic c, Scale k (BaseField c)) => Scale k (FieldElement c) where
scale k (FieldElement c) = FieldElement $ fromCircuitF c $ \(Par1 i) ->
Par1 <$> newAssigned (\x -> fromConstant (scale k one :: BaseField c) * x i)

instance {-# OVERLAPPING #-} FromConstant (FieldElement c) (FieldElement c)

instance {-# OVERLAPPING #-} Symbolic c => Scale (FieldElement c) (FieldElement c)

instance Symbolic c => MultiplicativeSemigroup (FieldElement c) where
FieldElement x * FieldElement y = FieldElement $ fromCircuit2F x y
$ \(Par1 i) (Par1 j) -> Par1 <$> newAssigned (\w -> w i * w j)
Expand Down
Loading

0 comments on commit 1ee3266

Please sign in to comment.