-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmmr.py
258 lines (213 loc) · 7.19 KB
/
mmr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
"""
Merkle Mountain Range
Adapted from https://github.com/jjyr/mmr.py (MIT License)
Replicates the MMR behavior of the chunk processor, i.e.:
- leaves are inserted into the tree as-is, without any hashing (they are expected to already be hashes of block headers)
- merging is done by hashing the two child hashes together, without prepending any indexes
- the root is computed by bagging the peaks, then hashing the MMR size together with the bagged result
"""
from typing import List, Tuple, Union
import sha3
from starkware.cairo.common.poseidon_hash import (
poseidon_hash,
poseidon_hash_single,
poseidon_hash_many,
)
class PoseidonHasher:
    """
    Buffer integer inputs and hash them with Poseidon when ``digest`` is called.

    ``update`` appends one item to the buffer (``bytes`` are decoded as a
    big-endian integer). ``digest`` hashes the buffered items, choosing
    ``poseidon_hash_single`` / ``poseidon_hash`` / ``poseidon_hash_many``
    according to how many items are buffered, then empties the buffer.
    """

    def __init__(self):
        # items waiting for the next digest() call
        self.items = []

    def update(self, item: Union[int, bytes]):
        """
        Buffer ``item``.

        Raises:
            TypeError: if ``item`` is neither an ``int`` nor ``bytes``.
        """
        if isinstance(item, bytes):
            item = int.from_bytes(item, "big")
        elif not isinstance(item, int):
            raise TypeError(f"Unsupported type: {type(item)}, {item}")
        self.items.append(item)

    def digest(self) -> int:
        """
        Hash every buffered item and reset the buffer.

        Raises:
            ValueError: if nothing has been buffered.
        """
        count = len(self.items)
        if count == 0:
            raise ValueError("No item to digest")
        if count == 1:
            result = poseidon_hash_single(self.items[0])
        elif count == 2:
            result = poseidon_hash(*self.items)
        else:
            result = poseidon_hash_many(self.items)
        self.items.clear()
        return result
class KeccakHasher:
    """
    Streaming Keccak-256 hasher that returns digests as big-endian integers.

    ``int`` inputs are encoded as 32-byte big-endian values before being fed
    to the hash; ``bytes`` are fed unchanged. ``digest`` returns the hash of
    everything fed since the previous digest and then starts a fresh hash.
    """

    def __init__(self):
        self.keccak = sha3.keccak_256()

    def update(self, item: Union[int, bytes]):
        """
        Feed ``item`` into the running hash.

        Raises:
            TypeError: if ``item`` is neither an ``int`` nor ``bytes``.
            OverflowError: (from ``int.to_bytes``) if an int does not fit in
                32 unsigned bytes.
        """
        if isinstance(item, int):
            chunk = item.to_bytes(32, "big")
        elif isinstance(item, bytes):
            chunk = item
        else:
            raise TypeError(f"Unsupported type: {type(item)}, {item}")
        self.keccak.update(chunk)

    def digest(self) -> int:
        """Return the current hash as an int and reset the hasher."""
        raw = self.keccak.digest()
        self.keccak = sha3.keccak_256()
        return int.from_bytes(raw, "big")
class MockedHasher:
    """
    Stand-in hasher that ignores its input, always digests to 0, and counts
    how many digests were requested (in ``hash_count``).
    """

    def __init__(self):
        # number of digest() calls made so far
        self.hash_count = 0

    def update(self, _):
        """Accept and discard one item."""

    def digest(self) -> int:
        """Record one hash operation and return 0."""
        self.hash_count = self.hash_count + 1
        return 0
def is_valid_mmr_size(n):
    """
    Return True if ``n`` is the node count of some MMR.

    A valid size is a sum of perfect-binary-tree sizes (2^k - 1) with pairwise
    distinct k. The function greedily subtracts the largest such size that
    fits. If the same size would be used twice, or something other than 0 is
    left over, ``n`` is rejected. ``n == 0`` is accepted; negative ``n`` is not.
    """
    last_peak_size = 0
    remaining = n
    while remaining > 0:
        bits = remaining.bit_length()
        # (1 << bits) - 1 is >= remaining; step down one height unless it is exact
        peak_size = (1 << bits) - 1
        if peak_size > remaining:
            peak_size = (1 << (bits - 1)) - 1
        if peak_size == last_peak_size:
            return False
        last_peak_size = peak_size
        remaining -= peak_size
    return remaining == 0
def tree_pos_height(pos: int) -> int:
    """
    Return the height of the node at 0-based position ``pos`` in an MMR
    (leaves have height 0).

    Explanation:
    https://github.com/mimblewimble/grin/blob/0ff6763ee64e5a14e70ddd4642b99789a1648a32/core/src/core/pmmr.rs#L606

    In 1-based numbering, a position whose binary form is all ones
    (1, 3, 7, 15, ...) has height ``bit_length - 1``. Any other position is
    reduced by ``2^(bit_length-1) - 1`` until it becomes all ones.

    Raises:
        ValueError: if ``pos`` is negative. Without this check, ``pos < -1``
            would loop forever, because the amount subtracted becomes 0.
    """
    if pos < 0:
        raise ValueError(f"MMR position must be non-negative, got {pos}")
    # convert from 0-based to 1-based position, see document
    one_based = pos + 1
    bit_length = one_based.bit_length()
    while (1 << bit_length) - 1 != one_based:
        one_based -= (1 << (bit_length - 1)) - 1
        bit_length = one_based.bit_length()
    return bit_length - 1
def sibling_offset(height) -> int:
    """
    Return the distance between a node at ``height`` and its sibling.

    This equals the node count of a perfect binary tree of that height,
    2^(height + 1) - 1. ``MMR.add`` uses it as ``right = left + offset``.
    """
    power = 2 << height  # 2^(height + 1)
    return power - 1
def get_peaks(mmr_size) -> List[int]:
    """
    Return the peak positions of an MMR with ``mmr_size`` nodes, from left to
    right, 0-index based.

    An empty MMR (``mmr_size <= 0``) has no peaks, so ``[]`` is returned.
    The result is only meaningful when ``mmr_size`` is a valid MMR size
    (see ``is_valid_mmr_size``).
    """
    if mmr_size <= 0:
        return []

    def get_right_peak(height, pos, mmr_size):
        """
        Find the next peak to the right of the peak at (height, pos).

        Returns (height, pos) of that peak, or (-1, None) when no peak exists
        to the right.
        """
        # jump to right sibling
        pos += sibling_offset(height)
        # while that position is past the end of the MMR, descend to its left child
        while pos > mmr_size - 1:
            height -= 1
            if height < 0:
                # no right peak exists
                return (height, None)
            pos -= 2 << height
        return (height, pos)

    height, pos = left_peak_height_pos(mmr_size)
    peaks = [pos]
    # a peak of height 0 is necessarily the last one
    while height > 0:
        height, pos = get_right_peak(height, pos, mmr_size)
        if height >= 0:
            peaks.append(pos)
    return peaks
def left_peak_height_pos(mmr_size: int) -> Tuple[int, int]:
    """
    Locate the left-most peak of an MMR with ``mmr_size`` nodes.

    Returns:
        ``(height, pos)`` of that peak, ``pos`` being 0-based. When
        ``mmr_size <= 0`` there is no peak and ``(-1, 0)`` is returned.

    The root of the left-most perfect subtree of height h is at 0-based
    position 2^(h+1) - 2. In one-based encoding it is all one-bits: 0b1 is
    pos 0, 0b11 is pos 2, 0b111 is pos 6. See
    https://github.com/mimblewimble/grin/blob/master/doc/mmr.md#structure
    The tallest such subtree whose root still lies inside the MMR is the
    left peak.
    """
    left_pos = 0
    height = 0
    while True:
        candidate = (1 << (height + 1)) - 2
        if candidate >= mmr_size:
            # this subtree does not fit; the previous one is the left peak
            return (height - 1, left_pos)
        left_pos = candidate
        height += 1
class MMR:
    """
    Merkle Mountain Range over already-hashed leaves.

    ``pos_hash`` maps each 0-based node position to its value. Leaves are
    stored exactly as given. Each internal node holds the hasher digest of
    (left child value, right child value). ``last_pos`` is the position of
    the last node written (-1 while the MMR is empty).
    """

    def __init__(
        self,
        hasher: "Union[PoseidonHasher, KeccakHasher, MockedHasher, None]" = None,
    ):
        """
        Create an empty MMR.

        Args:
            hasher: object exposing ``update(item)`` / ``digest() -> int``.
                When omitted (or None), a fresh ``PoseidonHasher`` is
                created for this instance. The hashers buffer state between
                ``update`` and ``digest``, so a default instance shared by
                every MMR could leak that state across trees.
        """
        self.last_pos = -1
        self.pos_hash = {}
        self._hasher = PoseidonHasher() if hasher is None else hasher

    def add(self, elem: Union[bytes, int]) -> int:
        """
        Insert a new leaf ``elem`` (stored as-is, without hashing) and merge
        every subtree it completes.

        Returns:
            The 0-based position of the inserted leaf.
        """
        self.last_pos += 1
        self.pos_hash[self.last_pos] = elem
        leaf_pos = self.last_pos
        height = 0
        # If the next position is higher than the current node, the current
        # node is a right child: create its parent and keep climbing.
        while tree_pos_height(self.last_pos + 1) > height:
            self.last_pos += 1
            # children of the new parent (see sibling_offset)
            left_pos = self.last_pos - (2 << height)
            right_pos = left_pos + sibling_offset(height)
            self._hasher.update(self.pos_hash[left_pos])
            self._hasher.update(self.pos_hash[right_pos])
            self.pos_hash[self.last_pos] = self._hasher.digest()
            height += 1
        return leaf_pos

    def get_root(self) -> int:
        """
        Return the MMR root: the hash of (MMR size, bagged peaks), size first.

        Raises:
            ValueError: if the MMR is empty.
        """
        if self.last_pos < 0:
            raise ValueError("Cannot compute the root of an empty MMR")
        bagged = self.bag_peaks(self.get_peaks())
        self._hasher.update(self.last_pos + 1)
        self._hasher.update(bagged)
        return self._hasher.digest()

    def get_peaks(self) -> List[Union[bytes, int]]:
        """Return the peak values from left to right ([] for an empty MMR)."""
        if self.last_pos < 0:
            return []
        # `get_peaks` here is the module-level function computing positions
        return [self.pos_hash[p] for p in get_peaks(self.last_pos + 1)]

    def bag_peaks(self, peaks: List[int]) -> int:
        """
        Fold ``peaks`` right to left: start from the last peak, then for each
        earlier peak compute ``hash(peak, bags)``.

        A single peak is returned unchanged. An empty list raises IndexError.
        """
        bags = peaks[-1]
        for peak in reversed(peaks[:-1]):
            self._hasher.update(peak)
            self._hasher.update(bags)
            bags = self._hasher.digest()
        return bags
if __name__ == "__main__":
    # Demo: build a 3-leaf MMR with each real hasher and print its root.
    for hasher_cls in (PoseidonHasher, KeccakHasher):
        demo_mmr = MMR(hasher_cls())
        for leaf in range(3):
            demo_mmr.add(leaf)
        print(demo_mmr.get_root())