Skip to content

Commit

Permalink
rollTree added
Browse files Browse the repository at this point in the history
  • Loading branch information
hovanja2011 committed Jan 14, 2025
1 parent 4be95a4 commit 1cf312c
Showing 1 changed file with 69 additions and 9 deletions.
78 changes: 69 additions & 9 deletions symbolic-base/src/ZkFold/Symbolic/Data/MerkleTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@ import GHC.Generics hiding (Rep)
import Data.Functor.Rep (Representable, pureRep)
import GHC.TypeNats
import Data.Type.Equality (type (~))
import Prelude (const, ($), undefined, Traversable, Integer, (==))
import Prelude (const, ($), undefined, Traversable, Integer)
import qualified Prelude as P
import ZkFold.Symbolic.Data.Bool (Bool (..), BoolType (..))
import ZkFold.Symbolic.Data.Bool (Bool (..))
import Data.Foldable (Foldable (..))
import ZkFold.Symbolic.Data.Input (SymbolicInput)
import ZkFold.Symbolic.Data.Conditional (bool, Conditional)
import ZkFold.Symbolic.Class
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Symbolic.Data.List (List, uncons, emptyList, (.:))
import ZkFold.Symbolic.Data.List (List (..), uncons, emptyList, (.:), ListItem (..))
import ZkFold.Base.Data.Vector (knownNat)
import ZkFold.Symbolic.Data.Maybe (Maybe, just, nothing, maybe)
import ZkFold.Base.Protocol.IVC.Accumulator (x)
import ZkFold.Symbolic.Data.Hash (Hashable (hasher))
import ZkFold.Symbolic.Data.Payloaded
import Data.List.Infinite (Infinite(..))
import ZkFold.Symbolic.MonadCircuit

data MerkleTree (d :: Natural) a = MerkleTree {
mHash :: (Context a) (Layout a)
Expand Down Expand Up @@ -89,7 +91,8 @@ findPath p (MerkleTree _ ml) = just @c $ MerkleTreePath (bool path (emptyList @c
lookup :: forall x c d.
( KnownNat d
, SymbolicOutput x
, Context x ~ c, Conditional (Bool c) Integer
, Context x ~ c
, Conditional (Bool c) Integer
) => MerkleTree d x -> MerkleTreePath d (Bool c) -> x
lookup (MerkleTree _ ml) (MerkleTreePath p) = val ml $ ind d 0 p
where
Expand All @@ -101,8 +104,30 @@ lookup (MerkleTree _ ml) (MerkleTreePath p) = val ml $ ind d 0 p
in bool (ind (iter-1) (2*i) ls) (ind (iter-1) (2*i+1) ls) l

val :: List c x -> Integer -> x
val mt 0 = let (l, _) = uncons @c @x mt in l
val mt i = let (_, ls) = uncons @c @x mt in val ls (i-1)
val mt i = let (l, ls) = uncons @c @x mt in
case i of
0 -> l
_ -> val ls (i-1)


rollTree :: forall d a c.
( c ~ Context a
, SymbolicOutput a
, KnownNat d
, Hashable a a
, AdditiveSemigroup a
) => MerkleTree d a -> MerkleTree (d - 1) a
rollTree (MerkleTree h l) = MerkleTree h (solve d l)
where
d = 2 P.^ (knownNat @d :: Integer)

solve :: Integer -> List c a -> List c a
solve 0 _ = emptyList @c
solve i lst =
let (x1, list1) = uncons lst
(x2, olist) = uncons list1
in hasher (x1 + x2) .: solve (i - 2) olist



-- | Inserts an element at a specified position in a tree
Expand All @@ -116,5 +141,40 @@ insert (MerkleTree h ls) p x = MerkleTree (embed $ pureRep zero) ls
-- replace :: (x -> Bool (Context x)) -> MerkleTree d x -> x -> MerkleTree d x

-- | Returns the next path in a tree
incrementPath :: MerkleTreePath d x -> MerkleTreePath d x
incrementPath = undefined
incrementPath :: forall c d.
( KnownNat d
, Symbolic c
, Conditional (Bool c) Integer
) => MerkleTreePath d (Bool c) -> MerkleTreePath d (Bool c)
incrementPath (MerkleTreePath p) = MerkleTreePath (path $ ind d 0 p + 1)
where
d = knownNat @d :: Integer

ind :: Integer -> Integer -> List c (Bool c) -> Integer
ind 0 i _ = i
ind iter i ps = let (l, ls) = uncons @c ps
in bool (ind (iter-1) (2*i) ls) (ind (iter-1) (2*i+1) ls) l

path :: Integer -> List c (Bool c)
path val = foldl (\nl ni -> Bool (embed (Par1 $ fromConstant ni)) .: nl) (emptyList @c)
$ P.map (\i -> mod (div val (2 P.^ i)) 2) [0 .. d]


-- | Returns the previous path in a tree
-- decrementPath :: forall c d.
-- ( KnownNat d
-- , Symbolic c
-- , Conditional (Bool c) Integer
-- ) => MerkleTreePath d (Bool c) -> MerkleTreePath d (Bool c)
-- decrementPath (MerkleTreePath p) = MerkleTreePath (path $ ind d 0 p - 1)
-- where
-- d = knownNat @d :: Integer

-- ind :: Integer -> Integer -> List c (Bool c) -> Integer
-- ind 0 i _ = i
-- ind iter i ps = let (l, ls) = uncons @c ps
-- in bool (ind (iter-1) (2*i) ls) (ind (iter-1) (2*i+1) ls) l

-- path :: Integer -> List c (Bool c)
-- path val = foldl (\nl ni -> Bool (embed (Par1 $ fromConstant ni)) .: nl) (emptyList @c)
-- $ P.map (\i -> mod (div val (2 P.^ i)) 2) [0 .. d]

0 comments on commit 1cf312c

Please sign in to comment.