diff --git a/symbolic-base/src/ZkFold/Base/Protocol/Plonkup/Relation.hs b/symbolic-base/src/ZkFold/Base/Protocol/Plonkup/Relation.hs index c09694d34..2bd1fc65b 100644 --- a/symbolic-base/src/ZkFold/Base/Protocol/Plonkup/Relation.hs +++ b/symbolic-base/src/ZkFold/Base/Protocol/Plonkup/Relation.hs @@ -10,6 +10,7 @@ import Data.Constraint (withDict) import Data.Constraint.Nat (timesNat) import Data.Map (elems, keys) import Data.Maybe (fromJust) +import qualified Data.Set as S import GHC.IsList (IsList (..)) import Prelude hiding (Num (..), drop, length, replicate, sum, take, (!!), (/), (^)) @@ -77,13 +78,13 @@ toPlonkupRelation ac = let xPub = acOutput ac pubInputConstraints = map var (fromVector xPub) plonkConstraints = map (evalPolynomial evalMonomial (var . SysVar)) (elems (acSystem ac)) - rs = map toConstant $ elems $ acRange ac + rs = map toConstant $ keys $ acRange ac -- TODO: We are expecting at most one range. t = toPolyVec $ fromList $ map fromConstant $ bool [] (replicate (value @n -! length rs + 1) 0 ++ [ 0 .. head rs ]) (not $ null rs) -- Number of elements in the set `t`. nLookup = bool 0 (head rs + 1) (not $ null rs) -- Lookup queries. - xLookup = keys (acRange ac) + xLookup = concatMap S.toList $ elems (acRange ac) -- The total number of constraints in the relation. n' = acSizeN ac + value @l + length xLookup diff --git a/symbolic-base/src/ZkFold/Symbolic/Compiler.hs b/symbolic-base/src/ZkFold/Symbolic/Compiler.hs index 4496dbb0f..cdde1caa0 100644 --- a/symbolic-base/src/ZkFold/Symbolic/Compiler.hs +++ b/symbolic-base/src/ZkFold/Symbolic/Compiler.hs @@ -10,7 +10,7 @@ module ZkFold.Symbolic.Compiler ( solder, ) where -import Data.Aeson (FromJSON, ToJSON) +import Data.Aeson (FromJSON, ToJSON, ToJSONKey) import Data.Binary (Binary) import Data.Function (const, (.)) import Data.Functor (($>)) @@ -109,6 +109,7 @@ compileIO :: ( c ~ ArithmeticCircuit a l , FromJSON a , ToJSON a + , ToJSONKey a , SymbolicData f , Context f ~ c , Support f ~ s diff --git a/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs b/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs index 2c3a32293..228f15f32 100644 --- a/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs +++ b/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit.hs @@ -37,6 +37,7 @@ import Data.Binary (Binary) import Data.Functor.Rep (Representable (..)) import Data.Map hiding (drop, foldl, foldr, map, null, splitAt, take) +import qualified Data.Set as S import Data.Void (absurd) import GHC.Generics (U1 (..)) import Numeric.Natural (Natural) @@ -86,7 +87,7 @@ desugarRanges :: (Arithmetic a, Binary a, Binary (Rep i), Ord (Rep i), Representable i) => ArithmeticCircuit a i o -> ArithmeticCircuit a i o desugarRanges c = - let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ [(SysVar k, v) | (k,v) <- toList (acRange c)] + let r' = flip execState c {acOutput = U1} . traverse (uncurry desugarRange) $ [(SysVar v, fromConstant k) | (k, s) <- toList (acRange c), v <- S.toList s] in r' { acRange = mempty, acOutput = acOutput c } ----------------------------------- Information ----------------------------------- diff --git a/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs b/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs index cbac4d31e..331fd09a6 100644 --- a/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs +++ b/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Instance.hs @@ -101,7 +101,7 @@ instance (FiniteField a, Haskell.Eq a, Show a, Show (o (Var a i)), Haskell.Ord ( ++ " }" -- TODO: add witness generation info to the JSON object -instance (ToJSON a, ToJSON (o (Var a i)), ToJSONKey (Var a i), FromJSONKey (Var a i), ToJSON (Rep i)) => ToJSON (ArithmeticCircuit a i o) where +instance (ToJSON a, ToJSON (o (Var a i)), ToJSONKey a, FromJSONKey (Var a i), ToJSON (Rep i)) => ToJSON (ArithmeticCircuit a i o) where toJSON r = object [ "system" .= acSystem r, @@ -110,7 +110,7 @@ instance (ToJSON a, ToJSON (o (Var a i)), ToJSONKey (Var a i), FromJSONKey (Var ] -- TODO: properly restore the witness generation function -instance (FromJSON a, FromJSON (o (Var a i)), ToJSONKey (Var a i), FromJSONKey (Var a i), Haskell.Ord (Rep i), FromJSON (Rep i)) => FromJSON (ArithmeticCircuit a i o) where +instance (FromJSON a, FromJSON (o (Var a i)), ToJSONKey (Var a i), FromJSONKey a, Haskell.Ord a, Haskell.Ord (Rep i), FromJSON (Rep i)) => FromJSON (ArithmeticCircuit a i o) where parseJSON = withObject "ArithmeticCircuit" $ \v -> do acSystem <- v .: "system" diff --git a/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs b/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs index 55386f500..b7a3e33a7 100644 --- a/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs +++ b/symbolic-base/src/ZkFold/Symbolic/Compiler/ArithmeticCircuit/Internal.hs @@ -37,9 +37,10 @@ import Data.Foldable (fold, to import Data.Functor.Rep import Data.Map.Strict hiding (drop, foldl, foldr, map, null, splitAt, take, toList) -import Data.Maybe (fromMaybe) +import Data.Maybe (catMaybes, fromMaybe) import Data.Semialign (unzipDefault) import Data.Semigroup.Generic (GenericSemigroupMonoid (..)) +import qualified Data.Set as S import GHC.Generics (Generic, Par1 (..), U1 (..), (:*:) (..)) import Optics import Prelude hiding (Num (..), drop, length, product, splitAt, @@ -65,7 +66,7 @@ data ArithmeticCircuit a i o = ArithmeticCircuit { acSystem :: Map ByteString (Constraint a i), -- ^ The system of polynomial constraints - acRange :: Map (SysVar i) a, + acRange :: Map a (S.Set (SysVar i)), -- ^ The range constraints [0, a] for the selected variables acWitness :: Map ByteString (i a -> Map ByteString a -> a), -- ^ The witness generation functions @@ -74,10 +75,10 @@ data ArithmeticCircuit a i o = ArithmeticCircuit } deriving (Generic) deriving via (GenericSemigroupMonoid (ArithmeticCircuit a i o)) - instance (Ord (Rep i), o ~ U1) => Semigroup (ArithmeticCircuit a i o) + instance (Ord a, o ~ U1) => Semigroup (ArithmeticCircuit a i o) deriving via (GenericSemigroupMonoid (ArithmeticCircuit a i o)) - instance (Ord (Rep i), o ~ U1) => Monoid (ArithmeticCircuit a i o) + instance (Ord a, o ~ U1) => Monoid (ArithmeticCircuit a i o) instance (NFData a, NFData (o (Var a i)), NFData (Rep i)) => NFData (ArithmeticCircuit a i o) @@ -151,7 +152,7 @@ hlmap :: (forall x . j x -> i x) -> ArithmeticCircuit a i o -> ArithmeticCircuit a j o hlmap f (ArithmeticCircuit s r w o) = ArithmeticCircuit { acSystem = mapVars (imapSysVar f) <$> s - , acRange = mapKeys (imapSysVar f) r + , acRange = S.map (imapSysVar f) <$> r , acWitness = (\g j p -> g (f j) p) <$> w , acOutput = imapVar f <$> o } @@ -167,11 +168,11 @@ behead = liftA2 (,) (set #acOutput U1) acOutput instance HFunctor (ArithmeticCircuit a i) where hmap = over #acOutput -instance Ord (Rep i) => HApplicative (ArithmeticCircuit a i) where +instance (Ord (Rep i), Ord a) => HApplicative (ArithmeticCircuit a i) where hpure = crown mempty hliftA2 f (behead -> (c, o)) (behead -> (d, p)) = crown (c <> d) (f o p) -instance Ord (Rep i) => Package (ArithmeticCircuit a i) where +instance (Ord (Rep i), Ord a) => Package (ArithmeticCircuit a i) where unpackWith f (behead -> (c, o)) = crown c <$> f o packWith f (unzipDefault . fmap behead -> (cs, os)) = crown (fold cs) (f os) @@ -205,7 +206,7 @@ instance zoom #acSystem . modify $ insert (toVar @a p) (p evalConstVar) rangeConstraint (SysVar v) upperBound = - zoom #acRange . modify $ insert v upperBound + zoom #acRange . modify $ insertWith S.union upperBound (S.singleton v) -- FIXME range-constrain other variable types rangeConstraint _ _ = error "Cannot range-constrain this variable" @@ -267,7 +268,7 @@ apply :: i a -> ArithmeticCircuit a (i :*: j) U1 -> ArithmeticCircuit a j U1 apply xs ac = ac { acSystem = fmap (evalPolynomial evalMonomial varF) (acSystem ac) - , acRange = mapKeys' (acRange ac) + , acRange = S.fromList . catMaybes . toList . filterSet <$> acRange ac , acWitness = fmap witF (acWitness ac) , acOutput = U1 } @@ -277,12 +278,12 @@ apply xs ac = ac varF (NewVar v) = var (NewVar v) witF f j = f (xs :*: j) - mapKeys' :: Ord (SysVar j) => Map (SysVar (i :*: j)) a -> Map (SysVar j) a - mapKeys' m = fromList $ - foldrWithKey (\k x ms -> case k of - NewVar v -> (NewVar v, x) : ms - InVar (Right v) -> (InVar v, x) : ms - _ -> ms) [] m + filterSet :: Ord (Rep j) => S.Set (SysVar (i :*: j)) -> S.Set (Maybe (SysVar j)) + filterSet = S.map (\case + NewVar v -> Just (NewVar v) + InVar (Right v) -> Just (InVar v) + _ -> Nothing) + -- TODO: Add proper symbolic application functions