Skip to content

Commit

Permalink
Merge pull request #325 from zkFold/vlasin-protostar-fix
Browse files Browse the repository at this point in the history
Protostar fix
  • Loading branch information
vlasin authored Nov 12, 2024
2 parents 3b71afc + 19c08f1 commit 34cba62
Show file tree
Hide file tree
Showing 23 changed files with 705 additions and 1,018 deletions.
14 changes: 12 additions & 2 deletions symbolic-base/src/ZkFold/Base/Algebra/Polynomials/Univariate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,12 @@ instance (KnownNat size, Ring c) => IsList (PolyVec c size) where
instance Scale c' c => Scale c' (PolyVec c size) where
scale c (PV p) = PV (scale c <$> p)

instance (FromConstant Natural c, AdditiveMonoid c, KnownNat size) => FromConstant Natural (PolyVec c size) where
fromConstant n = PV $ V.singleton (fromConstant n) V.++ V.replicate (fromIntegral (value @size -! 1)) zero

instance (FromConstant Integer c, AdditiveMonoid c, KnownNat size) => FromConstant Integer (PolyVec c size) where
fromConstant n = PV $ V.singleton (fromConstant n) V.++ V.replicate (fromIntegral (value @size -! 1)) zero

instance Ring c => AdditiveSemigroup (PolyVec c size) where
PV l + PV r = PV $ V.zipWith (+) l r

Expand All @@ -330,6 +336,10 @@ instance (Field c, KnownNat size) => MultiplicativeSemigroup (PolyVec c size) wh
instance (Field c, KnownNat size) => MultiplicativeMonoid (PolyVec c size) where
one = PV $ V.singleton one V.++ V.replicate (fromIntegral (value @size -! 1)) zero

instance (Field c, KnownNat size) => Semiring (PolyVec c size)

instance (Field c, KnownNat size) => Ring (PolyVec c size)

instance (Ring c, Arbitrary c, KnownNat size) => Arbitrary (PolyVec c size) where
arbitrary = toPolyVec <$> V.replicateM (fromIntegral $ value @size) arbitrary

Expand Down Expand Up @@ -361,11 +371,11 @@ a +. (PV cs) = PV $ fmap (+ a) cs
polyVecConstant :: forall c size . (Ring c, KnownNat size) => c -> PolyVec c size
polyVecConstant a0 = PV $ V.singleton a0 V.++ V.replicate (fromIntegral $ value @size -! 1) zero

-- p(x) = a0 + a1 * x
-- p(x) = a1 * x + a0
polyVecLinear :: forall c size . (Ring c, KnownNat size) => c -> c -> PolyVec c size
polyVecLinear a1 a0 = PV $ V.fromList [a0, a1] V.++ V.replicate (fromIntegral $ value @size -! 2) zero

-- p(x) = a0 + a1 * x + a2 * x^2
-- p(x) = a2 * x^2 + a1 * x + a0
polyVecQuadratic :: forall c size . (Ring c, KnownNat size) => c -> c -> c -> PolyVec c size
polyVecQuadratic a2 a1 a0 = PV $ V.fromList [a0, a1, a2] V.++ V.replicate (fromIntegral $ value @size -! 3) zero

Expand Down
4 changes: 1 addition & 3 deletions symbolic-base/src/ZkFold/Base/Protocol/Protostar.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
module ZkFold.Base.Protocol.Protostar (module P) where

import ZkFold.Base.Protocol.Protostar.AccumulatorScheme as P
import ZkFold.Base.Protocol.Protostar.ArithmeticCircuit as P
import ZkFold.Base.Protocol.Protostar.Fold as P
import ZkFold.Base.Protocol.Protostar.RecursiveCircuit as P
35 changes: 4 additions & 31 deletions symbolic-base/src/ZkFold/Base/Protocol/Protostar/Accumulator.hs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}

module ZkFold.Base.Protocol.Protostar.Accumulator where

import Control.DeepSeq (NFData (..))
import Control.Lens.Combinators (makeLenses)
import GHC.Generics
import Prelude hiding (length)
import Prelude hiding (length, pi)

-- Page 19, Accumulator instance
data AccumulatorInstance pi f c
Expand All @@ -25,39 +26,11 @@ makeLenses ''AccumulatorInstance
-- Page 19, Accumulator
-- @acc.x@ (accumulator instance) from the paper corresponds to _x
-- @acc.w@ (accumulator witness) from the paper corresponds to _w
data Accumulator i f c m
data Accumulator pi f c m
= Accumulator
{ _x :: AccumulatorInstance i f c
{ _x :: AccumulatorInstance pi f c
, _w :: [m]
}
deriving (Show, Generic, NFData)

makeLenses ''Accumulator

-- Page 18, section 3.4, The accumulation predicate
--
data NARKProof c m
= NARKProof
{ narkCommits :: [c] -- Commits [C_i] ∈ C^k
, narkWitness :: [m] -- prover messages in the special-sound protocol [m_i]
}
deriving (Show, Generic, NFData)

data InstanceProofPair pi c m = InstanceProofPair pi (NARKProof c m)
deriving (Show, Generic, NFData)

{--
toAccumulatorInstance :: (FiniteField f, AdditiveGroup c) => (f -> c -> f) -> NARKInstance f c -> AccumulatorInstance f c
toAccumulatorInstance oracle (NARKInstance i cs) =
let r0 = oracle i zero
f acc@(r:_) c = oracle r c : acc
f [] _ = error "Invalid accumulator instance"
rs = init $ reverse $ foldl f [r0] cs
in AccumulatorInstance i cs rs zero one
toAccumulatorWitness :: NARKWitness m -> AccumulatorWitness m
toAccumulatorWitness (NARKWitness ms) = AccumulatorWitness ms
toAccumulator :: (FiniteField f, AdditiveGroup c) => (f -> c -> f) -> NARKPair pi f c m -> Accumulator f c m
toAccumulator oracle (NARKPair i w) = Accumulator (toAccumulatorInstance oracle i) (toAccumulatorWitness w)
--}
170 changes: 68 additions & 102 deletions symbolic-base/src/ZkFold/Base/Protocol/Protostar/AccumulatorScheme.hs
Original file line number Diff line number Diff line change
@@ -1,184 +1,150 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Redundant ^." #-}

module ZkFold.Base.Protocol.Protostar.AccumulatorScheme where

import Control.DeepSeq (NFData)
import Control.Lens ((^.))
import Data.List (transpose)
import qualified Data.Vector as DV
import GHC.Generics (Generic)
import Prelude (type (~), ($), (.), (<$>))
import GHC.IsList (IsList (..))
import Prelude (concatMap, ($), (.), (<$>))
import qualified Prelude as P

import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import qualified ZkFold.Base.Algebra.Polynomials.Univariate as PU
import ZkFold.Base.Protocol.Protostar.Accumulator
import ZkFold.Base.Protocol.Protostar.AlgebraicMap (AlgebraicMap (..))
import ZkFold.Base.Protocol.Protostar.Commit (HomomorphicCommit (..))
import ZkFold.Base.Protocol.Protostar.CommitOpen (CommitOpen (..), CommitOpenProverMessage (..))
import ZkFold.Base.Protocol.Protostar.CommitOpen (CommitOpen (..))
import ZkFold.Base.Protocol.Protostar.FiatShamir (FiatShamir (..))
import ZkFold.Base.Protocol.Protostar.NARK (InstanceProofPair (..), NARKProof (..))
import ZkFold.Base.Protocol.Protostar.Oracle (RandomOracle (..))
import ZkFold.Base.Protocol.Protostar.SpecialSound (AlgebraicMap (..), MapInput, SpecialSoundProtocol (..))
import ZkFold.Symbolic.Class
import ZkFold.Symbolic.Data.Bool
import ZkFold.Symbolic.Data.Eq


-- | Accumulator scheme for V_NARK as described in Chapter 3.4 of the Protostar paper
--
class AccumulatorScheme i f c m ctx a where
prover :: a
-> f -- Commitment key ck
-> Accumulator i f c m -- accumulator
-> InstanceProofPair i c m -- instance-proof pair (pi, π)
-> (Accumulator i f c m, [c]) -- updated accumulator and accumulation proof

verifier :: i -- Public input
-> [c] -- NARK proof π.x
-> AccumulatorInstance i f c -- accumulator instance acc.x
-> AccumulatorInstance i f c -- updated accumulator instance acc'.x
-> [c] -- accumulation proof E_j
-> Bool ctx

decider :: a
-> (f, KeyScale f) -- Commitment key ck and scaling factor
-> Accumulator i f c m -- final accumulator
-> Bool ctx

data KeyScale f = KeyScale f f
deriving (P.Show, Generic, NFData)

-- | Class describing types which can form a polynomial linear combination:
-- linearCombination a1 a2 -> a1 * X + a2
-- TODO: define the initial accumulator
--
class LinearCombination a b where
linearCombination :: a -> a -> b
class AccumulatorScheme pi f c m a where
prover :: a
-> Accumulator pi f c m -- accumulator
-> InstanceProofPair pi c m -- instance-proof pair (pi, π)
-> (Accumulator pi f c m, [c]) -- updated accumulator and accumulation proof

-- | Same as above, but with a coefficient known at runtime
-- linearCombination coeff b1 b2 -> b1 * coeff + b2
--
class LinearCombinationWith a b where
linearCombinationWith :: a -> b -> b -> b
verifier :: pi -- Public input
-> [c] -- NARK proof π.x
-> AccumulatorInstance pi f c -- accumulator instance acc.x
-> AccumulatorInstance pi f c -- updated accumulator instance acc'.x
-> [c] -- accumulation proof E_j
-> (f, pi, [f], [c], c) -- returns zeros if the accumulation proof is correct

instance (Scale f a, AdditiveSemigroup a) => LinearCombinationWith f [a] where
linearCombinationWith f = P.zipWith (\a b -> scale f a + b)
decider :: a
-> Accumulator pi f c m -- final accumulator
-> ([c], c) -- returns zeros if the final accumulator is valid

instance
( Symbolic ctx
, Eq (Bool ctx) c
, Eq (Bool ctx) i
, Eq (Bool ctx) f
, Eq (Bool ctx) [f]
, Eq (Bool ctx) [c]
, AdditiveGroup c
, AdditiveSemigroup m
, Ring f
, Scale f c
, Scale f m
, MapInput f a ~ i
, deg ~ Degree (CommitOpen m c a) + 1
, KnownNat deg
, LinearCombination (MapMessage f a) (MapMessage (PU.PolyVec f deg) a)
, LinearCombination (MapInput f a) (MapInput (PU.PolyVec f deg) a)
, LinearCombinationWith f (MapInput f a)
, MapMessage f a ~ m
, AlgebraicMap f (CommitOpen m c a)
, AlgebraicMap (PU.PolyVec f deg) a
, RandomOracle c f -- Random oracle ρ_NARK
, RandomOracle i f -- Random oracle for compressing public input
, HomomorphicCommit f [m] c
, HomomorphicCommit f [f] c
) => AccumulatorScheme i f c m ctx (FiatShamir f (CommitOpen m c a)) where
prover (FiatShamir (CommitOpen _ sps) _) ck acc (InstanceProofPair pubi (NARKProof pi_x pi_w)) =
(Accumulator (AccumulatorInstance pi'' ci'' ri'' eCapital' mu') mi'', eCapital_j)
( Scale f c
, RandomOracle pi f -- Random oracle for compressing public input
, RandomOracle c f -- Random oracle ρ_NARK
, HomomorphicCommit [f] c
, HomomorphicCommit m c
, AlgebraicMap f pi m a
, AlgebraicMap (PU.PolyVec f (Degree a + 1)) [PU.PolyVec f (Degree a + 1)] [PU.PolyVec f (Degree a + 1)] a
, KnownNat (Degree a + 1)
) => AccumulatorScheme pi f c m (FiatShamir f (CommitOpen m c a)) where
prover (FiatShamir (CommitOpen sps)) acc (InstanceProofPair pubi (NARKProof pi_x pi_w)) =
(Accumulator (AccumulatorInstance pi'' ci'' ri'' eCapital' mu') m_i'', pf)
where
-- Fig. 3, step 1
r_i :: [f]
r_i = P.tail $ P.scanl (P.curry oracle) (oracle pubi) (zero : pi_x)
r_i = P.tail $ P.scanl (P.curry oracle) (oracle pubi) pi_x

-- Fig. 3, step 2

-- X + mu as a univariate polynomial
polyMu :: PU.PolyVec f deg
polyMu = PU.polyVecLinear (acc^.x^.mu) one
polyMu :: PU.PolyVec f (Degree a + 1)
polyMu = PU.polyVecLinear one (acc^.x^.mu)

-- X * pi + pi' as a list of univariate polynomials
polyPi = linearCombination pubi (acc^.x^.pi)
polyPi :: [PU.PolyVec f (Degree a + 1)]
polyPi = P.zipWith (PU.polyVecLinear @f) (toList pubi) (toList (acc^.x^.pi))

-- X * mi + mi'
polyW = P.zipWith linearCombination pi_w (acc^.w)
polyW :: [PU.PolyVec f (Degree a + 1)]
polyW = P.zipWith (PU.polyVecLinear @f) (concatMap toList pi_w) (concatMap toList (acc^.w))

-- X * ri + ri'
polyR :: [PU.PolyVec f deg]
polyR :: [PU.PolyVec f (Degree a + 1)]
polyR = P.zipWith (P.flip PU.polyVecLinear) (acc^.x^.r) r_i

-- The @l x d+1@ matrix of coefficients as a vector of @l@ univariate degree-@d@ polynomials
--
e_uni :: [PU.PolyVec f deg]
e_uni = algebraicMap @(PU.PolyVec f deg) sps polyPi polyW polyR polyMu
e_uni :: [PU.PolyVec f (Degree a + 1)]
e_uni = algebraicMap sps polyPi [polyW] polyR polyMu

-- e_all are coefficients of degree-j homogenous polynomials where j is from the range [0, d]
e_all = transpose $ (DV.toList . PU.fromPolyVec) <$> e_uni
e_all = transpose $ DV.toList . PU.fromPolyVec <$> e_uni

-- e_j are coefficients of degree-j homogenous polynomials where j is from the range [1, d - 1]
e_j :: [[f]]
e_j = P.tail $ P.init $ e_all
e_j = P.tail . P.init $ e_all

-- Fig. 3, step 3
eCapital_j = hcommit ck <$> e_j
pf = hcommit <$> e_j

-- Fig. 3, step 4
alpha :: f
alpha = oracle (acc^.x, pubi, pi_x, eCapital_j)
alpha = oracle (acc^.x, pubi, pi_x, pf)

-- Fig. 3, steps 5, 6
mu' = alpha + acc^.x^.mu
pi'' = linearCombinationWith alpha pubi $ acc^.x^.pi
ri'' = linearCombinationWith alpha r_i $ acc^.x^.r
ci'' = linearCombinationWith alpha pi_x $ acc^.x^.c
mi'' = linearCombinationWith alpha pi_w $ acc^.w
mu' = alpha + acc^.x^.mu
pi'' = scale alpha pubi + acc^.x^.pi
ri'' = scale alpha r_i + acc^.x^.r
ci'' = scale alpha pi_x + acc^.x^.c
m_i'' = scale alpha pi_w + acc^.w

-- Fig. 3, step 7
eCapital' = acc^.x^.e + sum (P.zipWith (\e' p -> scale (alpha ^ p) e') eCapital_j [1::Natural ..])
eCapital' = acc^.x^.e + sum (P.zipWith (\e' p -> scale (alpha ^ p) e') pf [1::Natural ..])


verifier pubi c_i acc acc' pf = and [muEq, piEq, riEq, ciEq, eEq]
verifier pubi c_i acc acc' pf = (muDiff, piDiff, riDiff, ciDiff, eDiff)
where
-- Fig. 4, step 1
r_i :: [f]
r_i = P.tail $ P.scanl (P.curry oracle) (oracle pubi) (zero : c_i)
r_i = P.tail $ P.scanl (P.curry oracle) (oracle pubi) c_i

-- Fig. 4, step 2
alpha :: f
alpha = oracle (acc, pubi, c_i, pf)

-- Fig. 4, step 3
mu' = alpha + acc^.mu
pi'' = linearCombinationWith alpha pubi $ acc^.pi
ri'' = linearCombinationWith alpha r_i $ acc^.r
ci'' = linearCombinationWith alpha c_i $ acc^.c
pi'' = scale alpha pubi + acc^.pi
ri'' = scale alpha r_i + acc^.r
ci'' = scale alpha c_i + acc^.c

-- Fig 4, step 4
muEq = acc'^.mu == mu'
piEq = acc'^.pi == pi''
riEq = acc'^.r == ri''
ciEq = acc'^.c == ci''
muDiff = acc'^.mu - mu'
piDiff = acc'^.pi - pi''
riDiff = acc'^.r - ri''
ciDiff = acc'^.c - ci''

-- Fig 4, step 5
eEq = acc'^.e == acc^.e + sum (P.zipWith scale ((\p -> alpha^p) <$> [1 :: Natural ..]) pf)
eDiff = acc'^.e - (acc^.e + sum (P.zipWith scale ((alpha ^) <$> [1 :: Natural ..]) pf))

decider (FiatShamir sps _) (ck, KeyScale ef _) acc = commitsEq && eEq
decider (FiatShamir (CommitOpen sps)) acc = (commitsDiff, eDiff)
where
-- Fig. 5, step 1
commitsEq = and $ P.zipWith (\cm m_acc -> cm == hcommit (scale (acc^.x^.mu) ck) [m_acc]) (acc^.x^.c) (acc^.w)
commitsDiff = P.zipWith (\cm m_acc -> cm - hcommit m_acc) (acc^.x^.c) (acc^.w)

-- Fig. 5, step 2
err :: [f]
err = algebraicMap @f sps (acc^.x^.pi) [Open $ acc^.w] (acc^.x^.r) (acc^.x^.mu)
err = algebraicMap sps (acc^.x^.pi) (acc^.w) (acc^.x^.r) (acc^.x^.mu)


-- Fig. 5, step 3
eEq = (acc^.x^.e) == hcommit (scale ef ck) err
eDiff = (acc^.x^.e) - hcommit err
Loading

0 comments on commit 34cba62

Please sign in to comment.