Skip to content

Commit

Permalink
initial refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Okm165 committed Oct 16, 2024
1 parent 23da65c commit fdde4de
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 106 deletions.
2 changes: 1 addition & 1 deletion lib/block_header.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func extract_parent_hash_little{range_check_ptr}(rlp: felt*) -> (res: Uint256) {
//
// Reference: https://ethereum.org/en/developers/docs/data-structures-and-encoding/rlp/#definition
func get_bigint_byte_size{range_check_ptr}(byte: felt) -> felt {
%{ memory[ap]=1 if ids.byte<=127 else 0 %}
%{ memory[ap] = 1 if ids.byte <= 127 else 0 %}
ap += 1;
let is_single_byte = [ap - 1];

Expand Down
15 changes: 6 additions & 9 deletions lib/mmr.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,8 @@ func is_valid_mmr_size_inner{range_check_ptr, pow2_array: felt*}(n: felt, prev_p
func compute_height_pre_alloc_pow2{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
alloc_locals;
local bit_length;
%{
x = ids.x
ids.bit_length = x.bit_length()
%}
%{ ids.bit_length = ids.x.bit_length() %}

// Computes N=2^bit_length and n=2^(bit_length-1)
// x is supposed to verify n = 2^(b-1) <= x < N = 2^bit_length <=> x has bit_length bits

Expand Down Expand Up @@ -143,10 +141,8 @@ func compute_height_pre_alloc_pow2{range_check_ptr, pow2_array: felt*}(x: felt)
func compute_first_peak_pos{range_check_ptr, pow2_array: felt*}(mmr_len: felt) -> felt {
alloc_locals;
local bit_length;
%{
mmr_len = ids.mmr_len
ids.bit_length = mmr_len.bit_length()
%}
%{ ids.bit_length = ids.mmr_len.bit_length() %}

// Computes N=2^bit_length and n=2^(bit_length-1)
// x is supposed to verify n = 2^(b-1) <= x < N = 2^bit_length <=> x has bit_length bits

Expand Down Expand Up @@ -238,8 +234,8 @@ func left_child_jump_until_inside_mmr{range_check_ptr, pow2_array: felt*, mmr_le
) -> felt {
alloc_locals;
local in_mmr;

%{ ids.in_mmr = 1 if ids.left_child<=ids.mmr_len else 0 %}

if (in_mmr != 0) {
// Ensure left_child <= mmr_len
assert [range_check_ptr] = mmr_len - left_child;
Expand Down Expand Up @@ -284,6 +280,7 @@ func get_full_mmr_peak_values{
// %{ print(f"Asked position : {ids.position}, mmr_offset : {ids.mmr_offset}") %}
local is_position_in_mmr_array: felt;
%{ ids.is_position_in_mmr_array= 1 if ids.position > ids.mmr_offset else 0 %}

if (is_position_in_mmr_array != 0) {
// %{ print(f'getting from mmr_array at index {ids.position-ids.mmr_offset -1}') %}
// ensure position > mmr_offset
Expand Down
48 changes: 24 additions & 24 deletions lib/mpt.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,15 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
let list_prefix = extract_byte_at_pos(rlp[0], 0, pow2_array);
local long_short_list: felt; // 0 for short, !=0 for long.
%{
if 0xc0 <= ids.list_prefix <= 0xf7:
from tools.py.hints import is_short_list, is_long_list
if is_short_list(ids.list_prefix):
ids.long_short_list = 0
#print("List type : short")
elif 0xf8 <= ids.list_prefix <= 0xff:
elif is_long_list(ids.list_prefix):
ids.long_short_list = 1
#print("List type: long")
else:
print("Not a list.")
raise ValueError(f"Invalid list prefix: {hex(ids.list_prefix)}. Not a recognized list type.")
%}

local first_item_start_offset: felt;
local list_len: felt; // Bytes length of the list. (not including the prefix)

Expand Down Expand Up @@ -211,16 +211,19 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
// (range [0x80, 0xb7] (dec. [128, 183])).

local first_item_type;
local first_item_len;
local second_item_starts_at_byte;
%{
if 0 <= ids.first_item_prefix <= 0x7f:
ids.first_item_type = 0 # Single byte
elif 0x80 <= ids.first_item_prefix <= 0xb7:
ids.first_item_type = 1 # Short string
from tools.py.hints import is_single_byte, is_short_string
if is_single_byte(ids.first_item_prefix):
ids.first_item_type = 0
elif is_short_string(ids.first_item_prefix):
ids.first_item_type = 1
else:
print(f"Unsupported first item type for prefix {ids.first_item_prefix=}")
raise ValueError(f"Unsupported first item prefix: {hex(ids.first_item_prefix)}.")
%}

local first_item_len;
local second_item_starts_at_byte;

if (first_item_type != 0) {
// Short string
assert [range_check_ptr + 3] = first_item_prefix - 0x80;
Expand Down Expand Up @@ -251,17 +254,15 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
// %{ print("second_item_prefix", hex(ids.second_item_prefix)) %}
local second_item_type: felt;
%{
if 0x00 <= ids.second_item_prefix <= 0x7f:
from tools.py.hints import is_single_byte, is_short_string, is_long_string
if is_single_byte(ids.second_item_prefix):
ids.second_item_type = 0
#print(f"2nd item : single byte")
elif 0x80 <= ids.second_item_prefix <= 0xb7:
elif is_short_string(ids.second_item_prefix):
ids.second_item_type = 1
#print(f"2nd item : short string {ids.second_item_prefix - 0x80} bytes")
elif 0xb8 <= ids.second_item_prefix <= 0xbf:
elif is_long_string(ids.second_item_prefix):
ids.second_item_type = 2
#print(f"2nd item : long string (len_len {ids.second_item_prefix - 0xb7} bytes)")
else:
print(f"2nd item : unknown type {ids.second_item_prefix}")
raise ValueError(f"Unsupported second item prefix: {hex(ids.second_item_prefix)}.")
%}

local second_item_bytes_len;
Expand Down Expand Up @@ -670,14 +671,13 @@ func jump_branch_node_till_element_at_index{range_check_ptr, bitwise_ptr: Bitwis
let item_prefix = extract_byte_at_pos(rlp[prefix_start_word], prefix_start_offset, pow2_array);
local item_type: felt;
%{
if 0x00 <= ids.item_prefix <= 0x7f:
from tools.py.hints import is_single_byte, is_short_string
if is_single_byte(ids.item_prefix):
ids.item_type = 0
#print(f"item : single byte")
elif 0x80 <= ids.item_prefix <= 0xb7:
elif is_short_string(ids.item_prefix):
ids.item_type = 1
#print(f"item : short string at item {ids.item_start_index} {ids.item_prefix - 0x80} bytes")
else:
print(f"item : unknown type {ids.item_prefix} for a branch node. Should be single byte or short string only.")
raise ValueError(f"Unsupported item prefix: {hex(ids.item_prefix)} for a branch node. Should be single byte or short string only.")
%}

if (item_type == 0) {
Expand Down
42 changes: 15 additions & 27 deletions lib/rlp_little.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,8 @@ func count_leading_zeroes_from_uint256_le_before_reversion{bitwise_ptr: BitwiseB
alloc_locals;
%{
from tools.py.utils import parse_int_to_bytes, count_leading_zero_nibbles_from_hex
input_ = ids.x.low + 2**128*ids.x.high
input_bytes = parse_int_to_bytes(input_)
#print(f"input hex {input_bytes.hex()}")
reversed_bytes = input_bytes[::-1]
#print("reversed bytes", reversed_bytes)
reversed_hex = reversed_bytes.hex()
#print("reversed hex", reversed_hex)
if ids.cut_nibble == 1:
reversed_hex = reversed_hex[1:]
#print(f"Reversed hex final : {reversed_hex}")
expected_leading_zeroes = count_leading_zero_nibbles_from_hex(reversed_hex)
#print(f"Expected leading zeroes {expected_leading_zeroes}")
reversed_hex = parse_int_to_bytes(ids.x.low + (2 ** 128) * ids.x.high)[::-1].hex()
expected_leading_zeroes = count_leading_zero_nibbles_from_hex(reversed_hex[1:] if ids.cut_nibble == 1 else reversed_hex)
%}
local x_f: Uint256;
local first_nibble_is_zero;
Expand Down Expand Up @@ -424,18 +414,16 @@ func extract_nibble_from_key_be{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
// Consenquently, we get the nibble from high part of the key only if :
// - nibble_index is in [0, 31] and key_nibbles > 32
%{ ids.get_nibble_from_low = 1 if (0 <= ids.nibble_index <= 31 and ids.key_nibbles <= 32) or (32 <= ids.nibble_index <= 63 and ids.key_nibbles > 32) else 0 %}
if (key.high != 0) {
// key_nibbles > 32
}
// %{
// print(f"Key low: {hex(ids.key.low)}")
// print(f"Key high: {hex(ids.key.high)}")
// print(f"nibble_index: {ids.nibble_index}")
// print(f"key_nibbles: {ids.key_nibbles}")
// print(f"key_leading_zeroes_nibbles: {ids.key_leading_zeroes_nibbles}")
// %}
%{
#print(f"Key low: {hex(ids.key.low)}")
#print(f"Key high: {hex(ids.key.high)}")
#print(f"nibble_index: {ids.nibble_index}")
#print(f"key_nibbles: {ids.key_nibbles}")
#print(f"key_leading_zeroes_nibbles: {ids.key_leading_zeroes_nibbles}")
key_hex = ids.key_leading_zeroes_nibbles*'0'+hex(ids.key.low + 2**128*ids.key.high)[2:]
#print(f"Key hex: {key_hex}")
expected_nibble = int(key_hex[ids.nibble_index+ids.key_leading_zeroes_nibbles], 16)
key_hex = ids.key_leading_zeroes_nibbles * '0' + hex(ids.key.low + (2 ** 128) * ids.key.high)[2:]
expected_nibble = int(key_hex[ids.nibble_index + ids.key_leading_zeroes_nibbles], 16)
%}
if (get_nibble_from_low != 0) {
local offset;
Expand Down Expand Up @@ -619,10 +607,10 @@ func extract_n_bytes_from_le_64_chunks_array{range_check_ptr}(
// Inlined felt_divmod (unsigned_div_rem).
let q = [ap];
let r = [ap + 1];
%{
ids.q, ids.r = divmod(memory[ids.array + ids.start_word + ids.i], ids.pow_cut)
#print(f"val={memory[ids.array + ids.start_word + ids.i]} q={ids.q} r={ids.r}")
%}
%{ ids.q, ids.r = divmod(memory[ids.array + ids.start_word + ids.i], ids.pow_cut) %}
// %{
// print(f"val={memory[ids.array + ids.start_word + ids.i]} q={ids.q} r={ids.r}")
// %}
ap += 2;
tempvar offset = 3 * n_words_handled;
assert [range_check_ptr + offset] = q;
Expand Down
65 changes: 20 additions & 45 deletions lib/utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func count_trailing_zeroes_128{bitwise_ptr: BitwiseBuiltin*}(x: felt, pow2_array
%{
from tools.py.utils import count_trailing_zero_bytes_from_int
ids.trailing_zeroes_bytes = count_trailing_zero_bytes_from_int(ids.x)
#print(f"Input: {hex(ids.x)}_{ids.trailing_zeroes_bytes}Tr_Zerobytes")
%}
// Verify.
if (trailing_zeroes_bytes == 0) {
Expand Down Expand Up @@ -303,10 +302,7 @@ func get_felt_bitlength{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
}
alloc_locals;
local bit_length;
%{
x = ids.x
ids.bit_length = x.bit_length()
%}
%{ ids.bit_length = ids.x.bit_length() %}
// Computes N=2^bit_length and n=2^(bit_length-1)
// x is supposed to verify n = 2^(b-1) <= x < N = 2^bit_length <=> x has bit_length bits
tempvar N = pow2_array[bit_length];
Expand Down Expand Up @@ -334,11 +330,8 @@ func get_felt_bitlength_128{range_check_ptr, pow2_array: felt*}(x: felt) -> felt
}
alloc_locals;
local bit_length;
%{ ids.bit_length = ids.x.bit_length() %}

%{
x = ids.x
ids.bit_length = x.bit_length()
%}
if (bit_length == 128) {
assert [range_check_ptr] = x - 2 ** 127;
tempvar range_check_ptr = range_check_ptr + 1;
Expand Down Expand Up @@ -393,10 +386,10 @@ func felt_divmod_2pow32{range_check_ptr}(value: felt) -> (q: felt, r: felt) {
%{
from starkware.cairo.common.math_utils import assert_integer
assert_integer(ids.DIV_32)
assert 0 < ids.DIV_32 <= PRIME // range_check_builtin.bound, \
f'div={hex(ids.DIV_32)} is out of the valid range.'
ids.q, ids.r = divmod(ids.value, ids.DIV_32)
if not (0 < ids.DIV_32 <= PRIME):
raise ValueError(f'div={hex(ids.DIV_32)} is out of the valid range.')
%}
%{ ids.q, ids.r = divmod(ids.value, ids.DIV_32) %}
assert [range_check_ptr + 2] = DIV_32_MINUS_1 - r;
let range_check_ptr = range_check_ptr + 3;

Expand Down Expand Up @@ -435,10 +428,10 @@ func felt_divmod{range_check_ptr}(value, div) -> (q: felt, r: felt) {
%{
from starkware.cairo.common.math_utils import assert_integer
assert_integer(ids.div)
assert 0 < ids.div <= PRIME // range_check_builtin.bound, \
f'div={hex(ids.div)} is out of the valid range.'
ids.q, ids.r = divmod(ids.value, ids.div)
if not (0 < ids.div <= PRIME):
raise ValueError(f'div={hex(ids.div)} is out of the valid range.')
%}
%{ ids.q, ids.r = divmod(ids.value, ids.div) %}
assert [range_check_ptr + 2] = div - 1 - r;
let range_check_ptr = range_check_ptr + 3;

Expand Down Expand Up @@ -480,11 +473,8 @@ func word_reverse_endian_64{bitwise_ptr: BitwiseBuiltin*}(word: felt) -> (res: f
// res: the byte-reversed integer.
func word_reverse_endian_16_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**16
word_bytes=word.to_bytes(2, byteorder='big')
for i in range(2):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 2, memory, ap)
%}
ap += 2;

Expand All @@ -510,11 +500,8 @@ func word_reverse_endian_16_RC{range_check_ptr}(word: felt) -> felt {
// res: the byte-reversed integer.
func word_reverse_endian_24_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**24
word_bytes=word.to_bytes(3, byteorder='big')
for i in range(3):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 3, memory, ap)
%}
ap += 3;

Expand Down Expand Up @@ -543,11 +530,8 @@ func word_reverse_endian_24_RC{range_check_ptr}(word: felt) -> felt {
// res: the byte-reversed integer.
func word_reverse_endian_32_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**32
word_bytes=word.to_bytes(4, byteorder='big')
for i in range(4):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 4, memory, ap)
%}
ap += 4;

Expand Down Expand Up @@ -579,11 +563,8 @@ func word_reverse_endian_32_RC{range_check_ptr}(word: felt) -> felt {
// res: the byte-reversed integer.
func word_reverse_endian_40_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**40
word_bytes=word.to_bytes(5, byteorder='big')
for i in range(5):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 5, memory, ap)
%}
ap += 5;

Expand Down Expand Up @@ -618,11 +599,8 @@ func word_reverse_endian_40_RC{range_check_ptr}(word: felt) -> felt {
// res: the byte-reversed integer.
func word_reverse_endian_48_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**48
word_bytes=word.to_bytes(6, byteorder='big')
for i in range(6):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 6, memory, ap)
%}
ap += 6;

Expand Down Expand Up @@ -660,11 +638,8 @@ func word_reverse_endian_48_RC{range_check_ptr}(word: felt) -> felt {
// res: the byte-reversed integer.
func word_reverse_endian_56_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**56
word_bytes=word.to_bytes(7, byteorder='big')
for i in range(7):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 7, memory, ap)
%}
ap += 7;

Expand Down
30 changes: 30 additions & 0 deletions tools/py/hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
def is_single_byte(prefix):
"""Check if the prefix indicates a single byte (0x00 to 0x7f)."""
return 0x00 <= prefix <= 0x7F


def is_short_string(prefix):
"""Check if the prefix indicates a short string (0x80 to 0xb7)."""
return 0x80 <= prefix <= 0xB7


def is_long_string(prefix):
"""Check if the prefix indicates a long string (0xb8 to 0xbf)."""
return 0xB8 <= prefix <= 0xBF


def is_short_list(prefix):
"""Check if the prefix indicates a short list (0xc0 to 0xf7)."""
return 0xC0 <= prefix <= 0xF7


def is_long_list(prefix):
"""Check if the prefix indicates a long list (0xf8 to 0xff)."""
return 0xF8 <= prefix <= 0xFF


def write_word_to_memory(word: int, n: int, memory, ap) -> None:
assert word < 2 ** (8 * n), f"Word value {word} exceeds {8 * n} bits."
word_bytes = word.to_bytes(n, byteorder="big")
for i in range(n):
memory[ap + i] = word_bytes[i]

0 comments on commit fdde4de

Please sign in to comment.