Skip to content

Commit

Permalink
Merge pull request #279 from zkFold/eitan-vector-vector
Browse files Browse the repository at this point in the history
Use Data.Vector.Vector in Vector n
  • Loading branch information
echatav authored Sep 30, 2024
2 parents 78134a8 + 7a25de1 commit 84b572d
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 118 deletions.
23 changes: 13 additions & 10 deletions src/ZkFold/Base/Algebra/Basic/Permutations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module ZkFold.Base.Algebra.Basic.Permutations (
fromCycles
) where

import Data.Functor.Rep (Representable (index))
import Data.Map (Map, elems, empty, singleton, union)
import Data.Maybe (fromJust)
import qualified Data.Vector as V
Expand All @@ -19,8 +20,9 @@ import qualified Prelude as P
import Test.QuickCheck (Arbitrary (..))

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Field
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Data.Vector (Vector (..), fromVector, toVector)
import ZkFold.Base.Data.Vector (Vector (..), fromVector, toVector, unsafeToVector)
import ZkFold.Prelude (chooseNatural, drop, length, (!!))

-- TODO (Issue #18): make the code safer
Expand All @@ -37,7 +39,7 @@ mkIndexPartition vs =

------------------------------------- Permutations -------------------------------------------

newtype Permutation n = Permutation (Vector n Natural)
newtype Permutation n = Permutation (Vector n (Zp n))
deriving (Show, Eq)

instance KnownNat n => Arbitrary (Permutation n) where
Expand All @@ -48,16 +50,17 @@ instance KnownNat n => Arbitrary (Permutation n) where
let as' = (bs !! i) : as
bs' = drop i bs
f as' bs'
in Permutation . Vector <$> f [] [1..value @n]
in Permutation . unsafeToVector <$>
f [] [fromConstant x | x <- [1..value @n]]

fromPermutation :: Permutation n -> [Natural]
fromPermutation :: Permutation n -> [Zp n]
fromPermutation (Permutation perm) = fromVector perm

applyPermutation :: Permutation n -> Vector n a -> Vector n a
applyPermutation (Permutation (Vector ps)) (Vector as) = Vector $ map (as !!) ps
applyPermutation :: KnownNat n => Permutation n -> Vector n a -> Vector n a
applyPermutation (Permutation ps) as = fmap (index as) ps

applyCycle :: IndexSet -> Permutation n -> Permutation n
applyCycle c (Permutation perm) = Permutation $ fmap f perm
applyCycle :: KnownNat n => V.Vector Natural -> Permutation n -> Permutation n
applyCycle c (Permutation perm) = Permutation $ fmap (fromConstant . f . toConstant) perm
where
f :: Natural -> Natural
f i = case i `V.elemIndex` c of
Expand All @@ -66,6 +69,6 @@ applyCycle c (Permutation perm) = Permutation $ fmap f perm

fromCycles :: KnownNat n => IndexPartition a -> Permutation n
fromCycles p =
let n = fromIntegral $ V.length $ V.concat $ elems p
in foldr applyCycle (Permutation $ fromJust $ toVector [1 .. n]) $ elems p
let n = toInteger $ V.length $ V.concat $ elems p
in foldr applyCycle (Permutation $ fromJust $ toVector [fromConstant x | x <- [1 .. n]]) $ elems p

26 changes: 6 additions & 20 deletions src/ZkFold/Base/Data/Matrix.hs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TypeApplications #-}

module ZkFold.Base.Data.Matrix where

import Data.Bifunctor (first)
import qualified Data.List as List
import Data.Maybe (fromJust)
import Data.These
Expand Down Expand Up @@ -40,7 +40,7 @@ outer f a b = Matrix $ fmap (\x -> fmap (f x) b) a
(.*) = zipWith (*)

sum1 :: (Semiring a) => Matrix m n a -> Vector n a
sum1 (Matrix as) = Vector (sum <$> fromVector as)
sum1 (Matrix as) = Vector (sum <$> toV as)

sum2 :: (KnownNat m, KnownNat n, Semiring a) => Matrix m n a -> Vector m a
sum2 (Matrix as) = sum1 $ transpose $ Matrix as
Expand Down Expand Up @@ -73,20 +73,6 @@ instance Zip (Matrix m n) where

zipWith f (Matrix as) (Matrix bs) = Matrix $ zipWith (zipWith f) as bs

instance (Arbitrary a, KnownNat m, KnownNat n) => Arbitrary (Matrix m n a) where
arbitrary = Matrix <$> arbitrary

instance (Random a, KnownNat m, KnownNat n) => Random (Matrix m n a) where
random g =
let as = foldl (\(as', g') _ ->
let (a, g'') = random g'
in (as' ++ [a], g''))
([], g) [1..value @m]
in first (Matrix . Vector) as

randomR (Matrix xs, Matrix ys) g =
let as = fst $ foldl (\((as', g'), (xs', ys')) _ ->
let (a, g'') = randomR (head xs', head ys') g'
in ((as' ++ [a], g''), (tail xs', tail ys')))
(([], g), (fromVector xs, fromVector ys)) [1..value @m]
in first (Matrix . Vector) as
deriving newtype instance (Arbitrary a, KnownNat m, KnownNat n) => Arbitrary (Matrix m n a)

deriving newtype instance (Random a, KnownNat m, KnownNat n) => Random (Matrix m n a)
137 changes: 58 additions & 79 deletions src/ZkFold/Base/Data/Vector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,119 +5,109 @@
module ZkFold.Base.Data.Vector where

import Control.DeepSeq (NFData)
import qualified Control.Monad as M
import Control.Parallel.Strategies (parMap, rpar)
import Control.Monad.State.Strict (runState, state)
import Data.Aeson (ToJSON (..))
import Data.Bifunctor (first)
import Data.Distributive (Distributive (..))
import Data.Functor.Rep (Representable (..), collectRep, distributeRep)
import qualified Data.List as List
import Data.List.Split (chunksOf)
import Data.Foldable (fold)
import Data.Functor.Rep (Representable (..), collectRep, distributeRep, mzipRep, pureRep)
import Data.These (These (..))
import qualified Data.Vector as V
import Data.Vector.Binary ()
import qualified Data.Vector.Split as V
import Data.Zip (Semialign (..), Zip (..))
import GHC.Generics (Generic)
import GHC.IsList (IsList (..))
import Prelude hiding (drop, head, length, mod, replicate, sum, tail, take, zip,
zipWith, (*))
import qualified Prelude as P
import Prelude hiding (concat, drop, head, length, mod, replicate, sum, tail, take,
zip, zipWith, (*))
import System.Random (Random (..))
import Test.QuickCheck (Arbitrary (..))

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Field
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Data.ByteString (Binary (..))
import qualified ZkFold.Prelude as ZP
import ZkFold.Prelude (length, replicate)
import ZkFold.Prelude (length)

newtype Vector (size :: Natural) a = Vector [a]
newtype Vector (size :: Natural) a = Vector {toV :: V.Vector a}
deriving (Show, Eq, Functor, Foldable, Traversable, Generic, NFData, Ord)

-- helper
knownNat :: forall size n . (KnownNat size, Integral n) => n
knownNat = fromIntegral (value @size)

instance KnownNat size => Representable (Vector size) where
type Rep (Vector size) = Zp size
index (Vector v) ix = v Prelude.!! (fromIntegral (fromZp ix))
tabulate f = Vector [f (toZp ix) | ix <- [0 .. fromIntegral (value @size) Prelude.- 1]]
index (Vector v) ix = v V.! (fromIntegral (fromZp ix))
tabulate f = Vector (V.generate (knownNat @size) (f . fromIntegral))

instance KnownNat size => Distributive (Vector size) where
distribute = distributeRep
collect = collectRep


vtoVector :: forall size a . KnownNat size => V.Vector a -> Maybe (Vector size a)
vtoVector as
| V.length as == knownNat @size = Just $ Vector as
| otherwise = Nothing

instance IsList (Vector n a) where
type Item (Vector n a) = a
toList = fromVector
fromList = unsafeToVector

parFmap :: (a -> b) -> Vector size a -> Vector size b
parFmap f (Vector lst) = Vector $ parMap rpar f lst

toVector :: forall size a . KnownNat size => [a] -> Maybe (Vector size a)
toVector as
| length as == value @size = Just $ Vector as
| length as == value @size = Just $ Vector (V.fromList as)
| otherwise = Nothing

unsafeToVector :: forall size a . [a] -> Vector size a
unsafeToVector = Vector

generate :: forall size a . KnownNat size => (Natural -> a) -> Vector size a
generate f = Vector $
case value @size of
0 -> [] -- avoid arithmetic underflow
n -> f <$> [0 .. n -! 1]
unsafeToVector = Vector . V.fromList

unfold :: forall size a b. KnownNat size => (b -> (a, b)) -> b -> Vector size a
unfold f = Vector . ZP.take (value @size) . List.unfoldr (Just . f)
unfold f = Vector . V.take (knownNat @size) . V.unfoldr (Just . f)

fromVector :: Vector size a -> [a]
fromVector (Vector as) = as
fromVector (Vector as) = V.toList as

(!!) :: Vector size a -> Natural -> a
(Vector as) !! i = as List.!! fromIntegral i
(Vector as) !! i = as V.! fromIntegral i

uncons :: Vector size a -> (a, Vector (size - 1) a)
uncons (Vector lst) = (P.head lst, Vector $ P.tail lst)
uncons (Vector lst) = (V.head lst, Vector $ V.tail lst)

reverse :: Vector size a -> Vector size a
reverse (Vector lst) = Vector (P.reverse lst)
reverse (Vector lst) = Vector (V.reverse lst)

head :: Vector size a -> a
head (Vector as) = P.head as
head (Vector as) = V.head as

tail :: Vector size a -> Vector (size - 1) a
tail (Vector as) = Vector $ P.tail as
tail (Vector as) = Vector $ V.tail as

singleton :: a -> Vector 1 a
singleton = Vector . pure

item :: Vector 1 a -> a
item (Vector [a]) = a
item _ = error "Unreachable"
item = head

mapWithIx :: forall n a b . KnownNat n => (Natural -> a -> b) -> Vector n a -> Vector n b
mapWithIx f (Vector l) = Vector $ zipWith f [0 .. (value @n -! 1)] l
mapWithIx f (Vector l) = Vector $ V.zipWith f (V.enumFromTo 0 (value @n -! 1)) l

mapMWithIx :: forall n m a b . (KnownNat n, Monad m) => (Natural -> a -> m b) -> Vector n a -> m (Vector n b)
mapMWithIx f (Vector l) = Vector <$> M.zipWithM f [0 .. (value @n -! 1)] l
mapMWithIx f (Vector l) = Vector <$> V.zipWithM f (V.enumFromTo 0 (value @n -! 1)) l

zipWithM :: forall n m a b c . Applicative m => (a -> b -> m c) -> Vector n a -> Vector n b -> m (Vector n c)
zipWithM f (Vector l) (Vector r) = Vector <$> M.zipWithM f l r
zipWithM f (Vector l) (Vector r) = sequenceA . Vector $ V.zipWith f l r

-- TODO: Check that n <= size?
take :: forall n size a. KnownNat n => Vector size a -> Vector n a
take (Vector lst) = Vector (ZP.take (value @n) lst)
take (Vector lst) = Vector (V.take (knownNat @n) lst)

drop :: forall n m a. KnownNat n => Vector (n + m) a -> Vector m a
drop (Vector lst) = Vector (ZP.drop (value @n) lst)
drop (Vector lst) = Vector (V.drop (knownNat @n) lst)

splitAt :: forall n m a. KnownNat n => Vector (n + m) a -> (Vector n a, Vector m a)
splitAt (Vector lst) = (Vector (ZP.take (value @n) lst), Vector (ZP.drop (value @n) lst))

-- | The sole purpose of this function is to get rid of annoying constraints in ZkFols.Symbolic.Compiler.Arithmetizable
--
splitAt3 :: forall n m k a. (KnownNat n, KnownNat m) => Vector (n + m + k) a -> (Vector n a, Vector m a, Vector k a)
splitAt3 (Vector lst) = (Vector ln, Vector lm, Vector lk)
where
(ln, lmk) = (ZP.take (value @n) lst, ZP.drop (value @n) lst)
(lm, lk) = (ZP.take (value @m) lmk, ZP.drop (value @m) lmk)
splitAt (Vector lst) = (Vector (V.take (knownNat @n) lst), Vector (V.drop (knownNat @n) lst))

rotate :: forall size a. KnownNat size => Vector size a -> Integer -> Vector size a
rotate (Vector lst) n = Vector (r <> l)
Expand All @@ -128,71 +118,60 @@ rotate (Vector lst) n = Vector (r <> l)
lshift :: Int
lshift = fromIntegral $ n `mod` len

(l, r) = P.splitAt lshift lst
(l, r) = V.splitAt lshift lst

shift :: forall size a. KnownNat size => Vector size a -> Integer -> a -> Vector size a
shift (Vector lst) n pad
| n < 0 = Vector $ ZP.take (value @size) (padList <> lst)
| otherwise = Vector $ ZP.drop (fromIntegral n) (lst <> padList)
| n < 0 = Vector $ V.take (knownNat @size) (padList <> lst)
| otherwise = Vector $ V.drop (fromIntegral n) (lst <> padList)
where
padList = replicate (fromIntegral $ abs n) pad

padList = V.replicate (fromIntegral $ abs n) pad

vectorDotProduct :: forall size a . Semiring a => Vector size a -> Vector size a -> a
vectorDotProduct (Vector as) (Vector bs) = sum $ zipWith (*) as bs

empty :: Vector 0 a
empty = Vector []
empty = Vector V.empty

infixr 5 .:
(.:) :: a -> Vector n a -> Vector (n + 1) a
a .: (Vector lst) = Vector (a : lst)
a .: (Vector lst) = Vector (a `V.cons` lst)

append :: Vector m a -> Vector n a -> Vector (m + n) a
append (Vector l) (Vector r) = Vector (l <> r)

concat :: Vector m (Vector n a) -> Vector (m * n) a
concat = Vector . concatMap fromVector
concat = Vector . V.concatMap toV . toV

unsafeConcat :: forall m n a . [Vector n a] -> Vector (m * n) a
unsafeConcat = Vector . concatMap fromVector
unsafeConcat = concat . unsafeToVector @m

chunks :: forall m n a . KnownNat n => Vector (m * n) a -> Vector m (Vector n a)
chunks (Vector lists) = Vector (Vector <$> chunksOf (fromIntegral $ value @n) lists)
chunks (Vector vectors) = unsafeToVector (Vector <$> V.chunksOf (fromIntegral $ value @n) vectors)

instance Binary a => Binary (Vector n a) where
put = put . fromVector
get = Vector <$> get
instance (KnownNat n, Binary a) => Binary (Vector n a) where
put = fold . V.map put . toV
get = Vector <$> V.replicateM (knownNat @n) get

instance KnownNat size => Applicative (Vector size) where
pure a = Vector $ replicate (value @size) a
pure a = Vector $ V.replicate (knownNat @size) a

(Vector fs) <*> (Vector as) = Vector $ zipWith ($) fs as
(Vector fs) <*> (Vector as) = Vector $ V.zipWith ($) fs as

instance Semialign (Vector size) where
align (Vector as) (Vector bs) = Vector $ zipWith These as bs
align (Vector as) (Vector bs) = Vector $ V.zipWith These as bs

instance Zip (Vector size) where
zip (Vector as) (Vector bs) = Vector $ zip as bs
zip (Vector as) (Vector bs) = Vector $ V.zip as bs

zipWith f (Vector as) (Vector bs) = Vector $ zipWith f as bs
zipWith f (Vector as) (Vector bs) = Vector $ V.zipWith f as bs

instance (Arbitrary a, KnownNat size) => Arbitrary (Vector size a) where
arbitrary = Vector <$> mapM (const arbitrary) [1..value @size]
arbitrary = sequenceA (pureRep arbitrary)

instance (Random a, KnownNat size) => Random (Vector size a) where
random g =
let as = foldl (\(as', g') _ ->
let (a, g'') = random g'
in (as' ++ [a], g''))
([], g) [1..value @size]
in first Vector as

randomR (Vector xs, Vector ys) g =
let as = fst $ foldl (\((as', g'), (xs', ys')) _ ->
let (a, g'') = randomR (P.head xs', P.head ys') g'
in ((as' ++ [a], g''), (P.tail xs', P.tail ys'))) (([], g), (xs, ys)) [1..value @size]
in first Vector as
random = runState (sequenceA (pureRep (state random)))
randomR = runState . traverse (state . randomR) . uncurry mzipRep

instance ToJSON a => ToJSON (Vector n a) where
toJSON (Vector xs) = toJSON xs
2 changes: 1 addition & 1 deletion src/ZkFold/Base/Protocol/Plonkup/Setup.hs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ plonkupSetup Plonkup {..} =
1 -> k1 * (omega^i)
2 -> k2 * (omega^i)
_ -> error "setup: invalid index"
s = fromList $ map f $ fromPermutation @(PlonkupPermutationSize n) $ sigma
s = fromList $ map (f . toConstant) $ fromPermutation @(PlonkupPermutationSize n) $ sigma
sigma1s = toPolyVec $ V.take (fromIntegral $ value @n) s
sigma2s = toPolyVec $ V.take (fromIntegral $ value @n) $ V.drop (fromIntegral $ value @n) s
sigma3s = toPolyVec $ V.take (fromIntegral $ value @n) $ V.drop (fromIntegral $ 2 * value @n) s
Expand Down
3 changes: 2 additions & 1 deletion src/ZkFold/Base/Protocol/Protostar/ArithmeticCircuit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ module ZkFold.Base.Protocol.Protostar.ArithmeticCircuit where


import Data.ByteString (ByteString)
import Data.Functor.Rep (tabulate)
import Data.List (foldl')
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
Expand Down Expand Up @@ -108,7 +109,7 @@ padDecomposition pad = foldl' (P.zipWith (+)) (P.repeat zero) . V.mapWithIx (\j
-- | Decomposes an algebraic map into homogenous degree-j maps for j from 0 to @n@
--
degreeDecomposition :: forall n f v . KnownNat (n + 1) => [Poly f v Natural] -> V.Vector (n + 1) [Poly f v Natural]
degreeDecomposition lmap = V.generate degree_j
degreeDecomposition lmap = tabulate (degree_j . toConstant)
where
degree_j :: Natural -> [Poly f v Natural]
degree_j j = P.fmap (leaveDeg j) lmap
Expand Down
4 changes: 4 additions & 0 deletions src/ZkFold/Base/Protocol/Protostar/Oracle.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Data.Char (ord)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Proxy (Proxy (..))
import qualified Data.Vector as V
import GHC.Generics
import GHC.TypeLits
import Prelude (($), (.), (<$>))
Expand All @@ -35,6 +36,9 @@ instance Ring a => RandomOracle a a where
instance (AdditiveMonoid b, RandomOracle a b) => RandomOracle [a] b where
oracle as = sum $ oracle <$> as

instance (AdditiveMonoid b, RandomOracle a b) => RandomOracle (V.Vector a) b where
oracle as = sum $ oracle <$> as

instance {-# OVERLAPPABLE #-} (Generic a, RandomOracle' (Rep a) b) => RandomOracle a b where

class RandomOracle' f b where
Expand Down
Loading

0 comments on commit 84b572d

Please sign in to comment.