Skip to content

Commit

Permalink
Merge pull request #245 from zkFold/TurtlePU/update-witnessfield
Browse files Browse the repository at this point in the history
Update `WitnessField`
  • Loading branch information
vlasin authored Sep 8, 2024
2 parents b48dbac + 26396af commit 420da6a
Show file tree
Hide file tree
Showing 13 changed files with 106 additions and 62 deletions.
25 changes: 15 additions & 10 deletions src/ZkFold/Base/Algebra/Basic/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -269,19 +269,24 @@ intScale n a | n < 0 = naturalFromInteger (-n) `scale` negate a
-}
class (AdditiveMonoid a, MultiplicativeMonoid a, FromConstant Natural a) => Semiring a

{- | A Euclidean domain @R@ is an integral domain which can be endowed
with at least one function @f : R\{0} -> R+@ s.t.
If @a@ and @b@ are in @R@ and @b@ is nonzero, then there exist @q@ and @r@ in @R@ such that
@a = bq + r@ and either @r = 0@ or @f(r) < f(b)@.
{- | A semi-Euclidean-domain @a@ is a semiring without zero divisors which can
be endowed with at least one function @f : a\{0} -> R+@ s.t. if @x@ and @y@ are
in @a@ and @y@ is nonzero, then there exist @q@ and @r@ in @a@ such that
@x = qy + r@ and either @r = 0@ or @f(r) < f(y)@.
@q@ and @r@ are called respectively a quotient and a remainder of the division (or Euclidean division) of @a@ by @b@.
@q@ and @r@ are called respectively a quotient and a remainder of the division
(or Euclidean division) of @x@ by @y@.
The function @divMod@ associated with this class produces @q@ and @r@ given @a@ and @b@.
The function @divMod@ associated with this class produces @q@ and @r@
given @a@ and @b@.
This is a generalization of a notion of Euclidean domains to semirings.
-}
class Semiring a => EuclideanDomain a where
{-# MINIMAL divMod #-}
class Semiring a => SemiEuclidean a where
{-# MINIMAL divMod | (div, mod) #-}

divMod :: a -> a -> (a, a)
divMod n d = (n `div` d, n `mod` d)

div :: a -> a -> a
div n d = Haskell.fst $ divMod n d
Expand Down Expand Up @@ -478,7 +483,7 @@ instance AdditiveMonoid Natural where

instance Semiring Natural

instance EuclideanDomain Natural where
instance SemiEuclidean Natural where
divMod = Haskell.divMod

instance BinaryExpansion Natural where
Expand Down Expand Up @@ -516,7 +521,7 @@ instance FromConstant Natural Integer where

instance Semiring Integer

instance EuclideanDomain Integer where
instance SemiEuclidean Integer where
divMod = Haskell.divMod

instance Ring Integer
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Base/Algebra/Basic/Field.hs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ instance KnownNat p => FromConstant Natural (Zp p) where

instance KnownNat p => Semiring (Zp p)

instance KnownNat p => EuclideanDomain (Zp p) where
instance KnownNat p => SemiEuclidean (Zp p) where
divMod a b = let (q, r) = Haskell.divMod (fromZp a) (fromZp b)
in (toZp . fromIntegral $ q, toZp . fromIntegral $ r)

Expand Down
48 changes: 38 additions & 10 deletions src/ZkFold/Base/Algebra/Basic/Sources.hs
Original file line number Diff line number Diff line change
@@ -1,20 +1,37 @@
{-# LANGUAGE DerivingStrategies #-}

module ZkFold.Base.Algebra.Basic.Sources where

module ZkFold.Base.Algebra.Basic.Sources (Sources (..)) where

import Data.Function (const, id)
import Data.Kind (Type)
import Data.Maybe (Maybe (..))
import Data.Monoid (Monoid (..))
import Data.Ord (Ord)
import Data.Semigroup (Semigroup (..))
import Data.Set (Set)
import Prelude hiding (replicate)
import qualified Data.Set as Set
import Numeric.Natural (Natural)
import Prelude (Integer)

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Prelude

newtype Sources a i = Sources { runSources :: Set i }
deriving newtype (Semigroup, Monoid)

instance MultiplicativeSemigroup c => Exponent (Sources a i) c where
(^) = const
empty :: Sources a i
empty = Sources Set.empty

instance {-# OVERLAPPING #-} FromConstant (Sources a i) (Sources a i) where
fromConstant = id

instance MultiplicativeMonoid c => Scale c (Sources a i) where
instance {-# OVERLAPPING #-} Ord i => Scale (Sources a i) (Sources a i) where
scale = (<>)

instance {-# OVERLAPPABLE #-} FromConstant c (Sources a i) where
fromConstant = const empty

instance {-# OVERLAPPABLE #-} MultiplicativeMonoid c => Scale c (Sources a i) where
scale = const id

instance Ord i => AdditiveSemigroup (Sources a i) where
Expand All @@ -32,23 +49,34 @@ instance Finite a => Finite (Sources a i) where
instance Ord i => MultiplicativeSemigroup (Sources a i) where
(*) = (<>)

instance Exponent (Sources a i) Natural where
(^) = const

instance Ord i => MultiplicativeMonoid (Sources a i) where
one = mempty

instance Exponent (Sources a i) Integer where
(^) = const

instance Ord i => MultiplicativeGroup (Sources a i) where
invert = id

instance Ord i => FromConstant c (Sources a i) where
fromConstant _ = mempty

instance Ord i => Semiring (Sources a i)

instance Ord i => Ring (Sources a i)

instance Ord i => Field (Sources a i) where
finv = id
rootOfUnity _ = Just (Sources mempty)
rootOfUnity _ = Just mempty

instance ToConstant (Sources (a :: Type) i) where
type Const (Sources a i) = Sources a i
toConstant = id

instance (Finite a, Ord i) => BinaryExpansion (Sources a i) where
type Bits (Sources a i) = [Sources a i]
binaryExpansion = replicate (numberOfBits @a)

instance Ord i => SemiEuclidean (Sources a i) where
div = (<>)
mod = (<>)
4 changes: 2 additions & 2 deletions src/ZkFold/Symbolic/Algorithms/Hash/Blake2b.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import Prelude hiding (Num (
replicate, splitAt, truncate, (!!), (&&), (^))

import ZkFold.Base.Algebra.Basic.Class (AdditiveGroup (..), AdditiveSemigroup (..),
EuclideanDomain (..), Exponent (..),
FromConstant (..), MultiplicativeSemigroup (..),
Exponent (..), FromConstant (..),
MultiplicativeSemigroup (..), SemiEuclidean (..),
divMod, one, zero, (-!))
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Prelude (length, replicate, splitAt, (!!))
Expand Down
3 changes: 2 additions & 1 deletion src/ZkFold/Symbolic/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Data.Functor (Functor (fmap), (<$>))
import Data.Kind (Type)
import Data.Type.Equality (type (~))
import GHC.Generics (Par1 (Par1), type (:.:) (unComp1))
import Numeric.Natural (Natural)

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Control.HApplicative (HApplicative (hpair, hunit))
Expand Down Expand Up @@ -49,7 +50,7 @@ class (HApplicative c, Package c, Arithmetic (BaseField c)) => Symbolic c where
-- | A wrapper around @'symbolicF'@ which extracts the pure computation
-- from the circuit computation using the @'Witnesses'@ newtype.
fromCircuitF :: c f -> CircuitFun f g (BaseField c) -> c g
fromCircuitF x f = symbolicF x (runWitnesses @(BaseField c) . f) f
fromCircuitF x f = symbolicF x (runWitnesses @Natural @(BaseField c) . f) f

-- | Embeds the pure value(s) into generic context @c@.
embed :: (Symbolic c, Foldable f, Functor f) => f (BaseField c) -> c f
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ desugarRange :: (Arithmetic a, MonadCircuit i a m) => i -> a -> m ()
desugarRange i b
| b == negate one = return ()
| otherwise = do
let bs = binaryExpansion b
let bs = binaryExpansion (toConstant b)
is <- expansion (length bs) i
case dropWhile ((== one) . fst) (zip bs is) of
[] -> return ()
Expand Down
2 changes: 1 addition & 1 deletion src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ instance o ~ U1 => Monoid (ArithmeticCircuit a i o) where
type VarField = Zp BLS12_381_Scalar

toField :: Arithmetic a => a -> VarField
toField = toZp . fromConstant . fromBinary @Natural . castBits . binaryExpansion
toField = fromConstant . toConstant

-- TODO: Remove the hardcoded constant.
toVar ::
Expand Down
45 changes: 22 additions & 23 deletions src/ZkFold/Symbolic/Data/Combinators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
module ZkFold.Symbolic.Data.Combinators where

import Control.Monad (mapM)
import Data.Foldable (Foldable (..), foldlM)
import Data.Foldable (foldlM)
import Data.Kind (Type)
import Data.List (find, splitAt)
import Data.List.Split (chunksOf)
Expand All @@ -16,7 +16,6 @@ import Data.Proxy (Proxy (..))
import Data.Ratio ((%))
import Data.Traversable (Traversable, for)
import Data.Type.Bool (If)
import Data.Type.Equality (type (~))
import Data.Type.Ord
import qualified Data.Zip as Z
import GHC.Base (const, return)
Expand All @@ -28,7 +27,6 @@ import Type.Errors

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number (value)
import ZkFold.Prelude (drop, take, (!!))
import ZkFold.Symbolic.MonadCircuit

-- | A class for isomorphic types.
Expand All @@ -47,8 +45,6 @@ class Extend a b where
class Shrink a b where
shrink :: a -> b



-- | Convert an @ArithmeticCircuit@ to bits and return their corresponding variables.
--
toBits
Expand All @@ -61,13 +57,10 @@ toBits
toBits regs hiBits loBits = do
let lows = tail regs
high = head regs

bitsLow <- Haskell.concatMap Haskell.reverse <$> mapM (expansion loBits) lows
bitsHigh <- Haskell.reverse <$> expansion hiBits high

pure $ bitsHigh <> bitsLow


-- | The inverse of @toBits@.
--
fromBits
Expand All @@ -78,10 +71,8 @@ fromBits
fromBits hiBits loBits bits = do
let (bitsHighNew, bitsLowNew) = splitAt (Haskell.fromIntegral hiBits) bits
let lowVarsNew = chunksOf (Haskell.fromIntegral loBits) bitsLowNew

lowsNew <- mapM (horner . Haskell.reverse) lowVarsNew
highNew <- horner . Haskell.reverse $ bitsHighNew

pure $ highNew : lowsNew

data RegisterSize = Auto | Fixed Natural
Expand Down Expand Up @@ -135,9 +126,7 @@ type family MaxRegisterSize (a :: Type) (regCount :: Natural) :: Natural where

type family ListRange (from :: Natural) (to :: Natural) :: [Natural] where
ListRange from from = '[from]
ListRange from to = from ': (ListRange (from + 1) to)


ListRange from to = from ': ListRange (from + 1) to

numberOfRegisters :: forall a n r . ( Finite a, KnownNat n, KnownRegisterSize r) => Natural
numberOfRegisters = case regSize @r of
Expand Down Expand Up @@ -175,13 +164,13 @@ highRegisterBits :: forall p n. (Finite p, KnownNat n) => Natural
highRegisterBits = case getNatural @n `mod` maxBitsPerFieldElement @p of
0 -> maxBitsPerFieldElement @p
m -> m

-- | The lowest possible number of registers to encode @n@ bits using Field elements from @p@
-- assuming that each register storest the largest possible number of bits.
--
minNumberOfRegisters :: forall p n. (Finite p, KnownNat n) => Natural
minNumberOfRegisters = (getNatural @n + maxBitsPerRegister @p @n -! 1) `div` maxBitsPerRegister @p @n


---------------------------------------------------------------

expansion :: MonadCircuit i a m => Natural -> i -> m [i]
Expand All @@ -196,10 +185,14 @@ bitsOf :: MonadCircuit i a m => Natural -> i -> m [i]
-- ^ @bitsOf n k@ creates @n@ bits and sets their witnesses equal to @n@ smaller
-- bits of @k@.
bitsOf n k = for [0 .. n -! 1] $ \j ->
newConstrained (\x i -> let xi = x i in xi * (xi - one)) ((!! j) . repr . ($ k))
newConstrained (\x i -> let xi = x i in xi * (xi - one)) (repr j . ($ k))
where
repr :: forall b . (BinaryExpansion b, Bits b ~ [b], Finite b) => b -> [b]
repr = padBits (numberOfBits @b) . binaryExpansion
repr :: WitnessField n x => Natural -> x -> x
repr j =
fromConstant
. (`mod` fromConstant @Natural 2)
. (`div` fromConstant @Natural (2 ^ j))
. toConstant

horner :: MonadCircuit i a m => [i] -> m i
-- ^ @horner [b0,...,bn]@ computes the sum @b0 + 2 b1 + ... + 2^n bn@ using
Expand All @@ -213,15 +206,21 @@ splitExpansion :: (MonadCircuit i a m, Arithmetic a) => Natural -> Natural -> i
-- @k = 2^n1 h + l@, @l@ fits in @n1@ bits and @h@ fits in n2 bits (if such
-- values exist).
splitExpansion n1 n2 k = do
let f x y = x + y + y
l <- newRanged (fromConstant $ (2 :: Natural) ^ n1 -! 1) $ foldr f zero . take n1 . repr . ($ k)
h <- newRanged (fromConstant $ (2 :: Natural) ^ n2 -! 1) $ foldr f zero . take n2 . drop n1 . repr . ($ k)
l <- newRanged (fromConstant @Natural $ 2 ^ n1 -! 1) $ lower . ($ k)
h <- newRanged (fromConstant @Natural $ 2 ^ n2 -! 1) $ upper . ($ k)
constraint (\x -> x k - x l - scale (2 ^ n1 :: Natural) (x h))
return (l, h)
where
repr :: forall b . (BinaryExpansion b, Bits b ~ [b]) => b -> [b]
repr = padBits (n1 + n2) . binaryExpansion

lower :: WitnessField n a => a -> a
lower =
fromConstant . (`mod` fromConstant @Natural (2 ^ n1)) . toConstant

upper :: WitnessField n a => a -> a
upper =
fromConstant
. (`mod` fromConstant @Natural (2 ^ n2))
. (`div` fromConstant @Natural (2 ^ n1))
. toConstant

runInvert :: (MonadCircuit i a m, Z.Zip f, Traversable f) => f i -> m (f i, f i)
runInvert is = do
Expand Down
5 changes: 3 additions & 2 deletions src/ZkFold/Symbolic/Data/FieldElement.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ instance
instance Symbolic c => BinaryExpansion (FieldElement c) where
type Bits (FieldElement c) = c (Vector (NumberOfBits (BaseField c)))
binaryExpansion (FieldElement c) = hmap unsafeToVector $ symbolicF c
(\(Par1 v) -> padBits (numberOfBits @(BaseField c)) $ binaryExpansion v)
(\(Par1 i) -> expansion (numberOfBits @(BaseField c)) i)
(padBits n . fmap fromConstant . binaryExpansion . toConstant . unPar1)
(expansion n . unPar1)
where n = numberOfBits @(BaseField c)
fromBinary bits =
FieldElement $ symbolicF bits (Par1 . foldr (\x y -> x + y + y) zero)
$ fmap Par1 . horner . fromVector
9 changes: 6 additions & 3 deletions src/ZkFold/Symbolic/Data/Ord.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Data.Data (Proxy (..))
import Data.Foldable (Foldable, toList)
import Data.Function ((.))
import Data.Functor ((<$>))
import Data.List (map)
import qualified Data.Zip as Z
import GHC.Generics (Par1 (..))
import Prelude (type (~), ($))
Expand Down Expand Up @@ -90,9 +91,11 @@ getBitsBE ::
-- ^ @getBitsBE x@ returns a list of circuits computing bits of @x@, eldest to
-- youngest.
getBitsBE x =
hmap unsafeToVector
$ symbolicF (pieces x Proxy) (binaryExpansion . V.item)
$ expansion (numberOfBits @(BaseField c)) . V.item
hmap (V.reverse . unsafeToVector)
$ symbolicF (pieces x Proxy)
(padBits n . map fromConstant . binaryExpansion . toConstant . V.item)
(expansion n . V.item)
where n = numberOfBits @(BaseField c)

bitwiseGE :: forall c f . (Symbolic c, Z.Zip f, Foldable f) => c f -> c f -> Bool c
-- ^ Given two lists of bits of equal length, compares them lexicographically.
Expand Down
4 changes: 2 additions & 2 deletions src/ZkFold/Symbolic/Data/UInt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ cast n =
eea
:: forall n c r
. Symbolic c
=> EuclideanDomain (UInt n r c)
=> SemiEuclidean (UInt n r c)
=> KnownNat n
=> KnownNat (NumberOfRegisters (BaseField c) n r)
=> AdditiveGroup (UInt n r c)
Expand Down Expand Up @@ -200,7 +200,7 @@ instance
, KnownRegisterSize rs
, r ~ NumberOfRegisters (BaseField c) n rs
, NFData (c (Vector r))
) => EuclideanDomain (UInt n rs c) where
) => SemiEuclidean (UInt n rs c) where
divMod numerator d = bool @(Bool c) (q, r) (zero, zero) (d == zero)
where
(q, r) = Haskell.foldl longDivisionStep (zero, zero) [value @n -! 1, value @n -! 2 .. 0]
Expand Down
Loading

0 comments on commit 420da6a

Please sign in to comment.