Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1 - python hints refactored #8

Merged
merged 27 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/block_header.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func extract_parent_hash_little{range_check_ptr}(rlp: felt*) -> (res: Uint256) {
//
// Reference: https://ethereum.org/en/developers/docs/data-structures-and-encoding/rlp/#definition
func get_bigint_byte_size{range_check_ptr}(byte: felt) -> felt {
%{ memory[ap]=1 if ids.byte<=127 else 0 %}
%{ memory[ap] = 1 if ids.byte <= 127 else 0 %}
ap += 1;
let is_single_byte = [ap - 1];

Expand Down
15 changes: 6 additions & 9 deletions lib/mmr.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,8 @@ func is_valid_mmr_size_inner{range_check_ptr, pow2_array: felt*}(n: felt, prev_p
func compute_height_pre_alloc_pow2{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
alloc_locals;
local bit_length;
%{
x = ids.x
ids.bit_length = x.bit_length()
%}
%{ ids.bit_length = ids.x.bit_length() %}

// Computes N=2^bit_length and n=2^(bit_length-1)
// x is supposed to verify n = 2^(b-1) <= x < N = 2^bit_length <=> x has bit_length bits

Expand Down Expand Up @@ -143,10 +141,8 @@ func compute_height_pre_alloc_pow2{range_check_ptr, pow2_array: felt*}(x: felt)
func compute_first_peak_pos{range_check_ptr, pow2_array: felt*}(mmr_len: felt) -> felt {
alloc_locals;
local bit_length;
%{
mmr_len = ids.mmr_len
ids.bit_length = mmr_len.bit_length()
%}
%{ ids.bit_length = ids.mmr_len.bit_length() %}

// Computes N=2^bit_length and n=2^(bit_length-1)
// x is supposed to verify n = 2^(b-1) <= x < N = 2^bit_length <=> x has bit_length bits

Expand Down Expand Up @@ -238,8 +234,8 @@ func left_child_jump_until_inside_mmr{range_check_ptr, pow2_array: felt*, mmr_le
) -> felt {
alloc_locals;
local in_mmr;

%{ ids.in_mmr = 1 if ids.left_child<=ids.mmr_len else 0 %}

if (in_mmr != 0) {
// Ensure left_child <= mmr_len
assert [range_check_ptr] = mmr_len - left_child;
Expand Down Expand Up @@ -284,6 +280,7 @@ func get_full_mmr_peak_values{
// %{ print(f"Asked position : {ids.position}, mmr_offset : {ids.mmr_offset}") %}
local is_position_in_mmr_array: felt;
%{ ids.is_position_in_mmr_array= 1 if ids.position > ids.mmr_offset else 0 %}

if (is_position_in_mmr_array != 0) {
// %{ print(f'getting from mmr_array at index {ids.position-ids.mmr_offset -1}') %}
// ensure position > mmr_offset
Expand Down
48 changes: 24 additions & 24 deletions lib/mpt.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,15 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
let list_prefix = extract_byte_at_pos(rlp[0], 0, pow2_array);
local long_short_list: felt; // 0 for short, !=0 for long.
%{
if 0xc0 <= ids.list_prefix <= 0xf7:
from tools.py.hints import is_short_list, is_long_list
if is_short_list(ids.list_prefix):
ids.long_short_list = 0
#print("List type : short")
elif 0xf8 <= ids.list_prefix <= 0xff:
elif is_long_list(ids.list_prefix):
ids.long_short_list = 1
#print("List type: long")
else:
print("Not a list.")
raise ValueError(f"Invalid list prefix: {hex(ids.list_prefix)}. Not a recognized list type.")
%}

local first_item_start_offset: felt;
local list_len: felt; // Bytes length of the list. (not including the prefix)

Expand Down Expand Up @@ -211,16 +211,19 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
// (range [0x80, 0xb7] (dec. [128, 183])).

local first_item_type;
local first_item_len;
local second_item_starts_at_byte;
%{
if 0 <= ids.first_item_prefix <= 0x7f:
ids.first_item_type = 0 # Single byte
elif 0x80 <= ids.first_item_prefix <= 0xb7:
ids.first_item_type = 1 # Short string
from tools.py.hints import is_single_byte, is_short_string
if is_single_byte(ids.first_item_prefix):
ids.first_item_type = 0
elif is_short_string(ids.first_item_prefix):
ids.first_item_type = 1
else:
print(f"Unsupported first item type for prefix {ids.first_item_prefix=}")
raise ValueError(f"Unsupported first item prefix: {hex(ids.first_item_prefix)}.")
%}

local first_item_len;
local second_item_starts_at_byte;

if (first_item_type != 0) {
// Short string
assert [range_check_ptr + 3] = first_item_prefix - 0x80;
Expand Down Expand Up @@ -251,17 +254,15 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
// %{ print("second_item_prefix", hex(ids.second_item_prefix)) %}
local second_item_type: felt;
%{
if 0x00 <= ids.second_item_prefix <= 0x7f:
from tools.py.hints import is_single_byte, is_short_string, is_long_string
if is_single_byte(ids.second_item_prefix):
ids.second_item_type = 0
#print(f"2nd item : single byte")
elif 0x80 <= ids.second_item_prefix <= 0xb7:
elif is_short_string(ids.second_item_prefix):
ids.second_item_type = 1
#print(f"2nd item : short string {ids.second_item_prefix - 0x80} bytes")
elif 0xb8 <= ids.second_item_prefix <= 0xbf:
elif is_long_string(ids.second_item_prefix):
ids.second_item_type = 2
#print(f"2nd item : long string (len_len {ids.second_item_prefix - 0xb7} bytes)")
else:
print(f"2nd item : unknown type {ids.second_item_prefix}")
raise ValueError(f"Unsupported second item prefix: {hex(ids.second_item_prefix)}.")
%}

local second_item_bytes_len;
Expand Down Expand Up @@ -670,14 +671,13 @@ func jump_branch_node_till_element_at_index{range_check_ptr, bitwise_ptr: Bitwis
let item_prefix = extract_byte_at_pos(rlp[prefix_start_word], prefix_start_offset, pow2_array);
local item_type: felt;
%{
if 0x00 <= ids.item_prefix <= 0x7f:
from tools.py.hints import is_single_byte, is_short_string
if is_single_byte(ids.item_prefix):
ids.item_type = 0
#print(f"item : single byte")
elif 0x80 <= ids.item_prefix <= 0xb7:
elif is_short_string(ids.item_prefix):
ids.item_type = 1
#print(f"item : short string at item {ids.item_start_index} {ids.item_prefix - 0x80} bytes")
else:
print(f"item : unknown type {ids.item_prefix} for a branch node. Should be single byte or short string only.")
raise ValueError(f"Unsupported item prefix: {hex(ids.item_prefix)} for a branch node. Should be single byte or short string only.")
%}

if (item_type == 0) {
Expand Down
42 changes: 15 additions & 27 deletions lib/rlp_little.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,8 @@ func count_leading_zeroes_from_uint256_le_before_reversion{bitwise_ptr: BitwiseB
alloc_locals;
%{
from tools.py.utils import parse_int_to_bytes, count_leading_zero_nibbles_from_hex
input_ = ids.x.low + 2**128*ids.x.high
input_bytes = parse_int_to_bytes(input_)
#print(f"input hex {input_bytes.hex()}")
reversed_bytes = input_bytes[::-1]
#print("reversed bytes", reversed_bytes)
reversed_hex = reversed_bytes.hex()
#print("reversed hex", reversed_hex)
if ids.cut_nibble == 1:
reversed_hex = reversed_hex[1:]
#print(f"Reversed hex final : {reversed_hex}")
expected_leading_zeroes = count_leading_zero_nibbles_from_hex(reversed_hex)
#print(f"Expected leading zeroes {expected_leading_zeroes}")
reversed_hex = parse_int_to_bytes(ids.x.low + (2 ** 128) * ids.x.high)[::-1].hex()
expected_leading_zeroes = count_leading_zero_nibbles_from_hex(reversed_hex[1:] if ids.cut_nibble == 1 else reversed_hex)
%}
local x_f: Uint256;
local first_nibble_is_zero;
Expand Down Expand Up @@ -424,18 +414,16 @@ func extract_nibble_from_key_be{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
// Consenquently, we get the nibble from high part of the key only if :
// - nibble_index is in [0, 31] and key_nibbles > 32
%{ ids.get_nibble_from_low = 1 if (0 <= ids.nibble_index <= 31 and ids.key_nibbles <= 32) or (32 <= ids.nibble_index <= 63 and ids.key_nibbles > 32) else 0 %}
if (key.high != 0) {
// key_nibbles > 32
}
// %{
// print(f"Key low: {hex(ids.key.low)}")
// print(f"Key high: {hex(ids.key.high)}")
// print(f"nibble_index: {ids.nibble_index}")
// print(f"key_nibbles: {ids.key_nibbles}")
// print(f"key_leading_zeroes_nibbles: {ids.key_leading_zeroes_nibbles}")
// %}
%{
#print(f"Key low: {hex(ids.key.low)}")
#print(f"Key high: {hex(ids.key.high)}")
#print(f"nibble_index: {ids.nibble_index}")
#print(f"key_nibbles: {ids.key_nibbles}")
#print(f"key_leading_zeroes_nibbles: {ids.key_leading_zeroes_nibbles}")
key_hex = ids.key_leading_zeroes_nibbles*'0'+hex(ids.key.low + 2**128*ids.key.high)[2:]
#print(f"Key hex: {key_hex}")
expected_nibble = int(key_hex[ids.nibble_index+ids.key_leading_zeroes_nibbles], 16)
key_hex = ids.key_leading_zeroes_nibbles * '0' + hex(ids.key.low + (2 ** 128) * ids.key.high)[2:]
expected_nibble = int(key_hex[ids.nibble_index + ids.key_leading_zeroes_nibbles], 16)
%}
if (get_nibble_from_low != 0) {
local offset;
Expand Down Expand Up @@ -619,10 +607,10 @@ func extract_n_bytes_from_le_64_chunks_array{range_check_ptr}(
// Inlined felt_divmod (unsigned_div_rem).
let q = [ap];
let r = [ap + 1];
%{
ids.q, ids.r = divmod(memory[ids.array + ids.start_word + ids.i], ids.pow_cut)
#print(f"val={memory[ids.array + ids.start_word + ids.i]} q={ids.q} r={ids.r}")
%}
%{ ids.q, ids.r = divmod(memory[ids.array + ids.start_word + ids.i], ids.pow_cut) %}
// %{
// print(f"val={memory[ids.array + ids.start_word + ids.i]} q={ids.q} r={ids.r}")
// %}
ap += 2;
tempvar offset = 3 * n_words_handled;
assert [range_check_ptr + offset] = q;
Expand Down
65 changes: 20 additions & 45 deletions lib/utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func count_trailing_zeroes_128{bitwise_ptr: BitwiseBuiltin*}(x: felt, pow2_array
%{
from tools.py.utils import count_trailing_zero_bytes_from_int
ids.trailing_zeroes_bytes = count_trailing_zero_bytes_from_int(ids.x)
#print(f"Input: {hex(ids.x)}_{ids.trailing_zeroes_bytes}Tr_Zerobytes")
%}
// Verify.
if (trailing_zeroes_bytes == 0) {
Expand Down Expand Up @@ -303,10 +302,7 @@ func get_felt_bitlength{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
}
alloc_locals;
local bit_length;
%{
x = ids.x
ids.bit_length = x.bit_length()
%}
%{ ids.bit_length = ids.x.bit_length() %}
// Computes N=2^bit_length and n=2^(bit_length-1)
// x is supposed to verify n = 2^(b-1) <= x < N = 2^bit_length <=> x has bit_length bits
tempvar N = pow2_array[bit_length];
Expand Down Expand Up @@ -334,11 +330,8 @@ func get_felt_bitlength_128{range_check_ptr, pow2_array: felt*}(x: felt) -> felt
}
alloc_locals;
local bit_length;
%{ ids.bit_length = ids.x.bit_length() %}

%{
x = ids.x
ids.bit_length = x.bit_length()
%}
if (bit_length == 128) {
assert [range_check_ptr] = x - 2 ** 127;
tempvar range_check_ptr = range_check_ptr + 1;
Expand Down Expand Up @@ -393,10 +386,10 @@ func felt_divmod_2pow32{range_check_ptr}(value: felt) -> (q: felt, r: felt) {
%{
from starkware.cairo.common.math_utils import assert_integer
assert_integer(ids.DIV_32)
assert 0 < ids.DIV_32 <= PRIME // range_check_builtin.bound, \
f'div={hex(ids.DIV_32)} is out of the valid range.'
ids.q, ids.r = divmod(ids.value, ids.DIV_32)
if not (0 < ids.DIV_32 <= PRIME):
raise ValueError(f'div={hex(ids.DIV_32)} is out of the valid range.')
%}
%{ ids.q, ids.r = divmod(ids.value, ids.DIV_32) %}
assert [range_check_ptr + 2] = DIV_32_MINUS_1 - r;
let range_check_ptr = range_check_ptr + 3;

Expand Down Expand Up @@ -435,10 +428,10 @@ func felt_divmod{range_check_ptr}(value, div) -> (q: felt, r: felt) {
%{
from starkware.cairo.common.math_utils import assert_integer
assert_integer(ids.div)
assert 0 < ids.div <= PRIME // range_check_builtin.bound, \
f'div={hex(ids.div)} is out of the valid range.'
ids.q, ids.r = divmod(ids.value, ids.div)
if not (0 < ids.div <= PRIME):
raise ValueError(f'div={hex(ids.div)} is out of the valid range.')
%}
%{ ids.q, ids.r = divmod(ids.value, ids.div) %}
assert [range_check_ptr + 2] = div - 1 - r;
let range_check_ptr = range_check_ptr + 3;

Expand Down Expand Up @@ -480,11 +473,8 @@ func word_reverse_endian_64{bitwise_ptr: BitwiseBuiltin*}(word: felt) -> (res: f
// res: the byte-reversed integer.
func word_reverse_endian_16_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**16
word_bytes=word.to_bytes(2, byteorder='big')
for i in range(2):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 2, memory, ap)
%}
ap += 2;

Expand All @@ -510,11 +500,8 @@ func word_reverse_endian_16_RC{range_check_ptr}(word: felt) -> felt {
// res: the byte-reversed integer.
func word_reverse_endian_24_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**24
word_bytes=word.to_bytes(3, byteorder='big')
for i in range(3):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 3, memory, ap)
%}
ap += 3;

Expand Down Expand Up @@ -543,11 +530,8 @@ func word_reverse_endian_24_RC{range_check_ptr}(word: felt) -> felt {
// res: the byte-reversed integer.
func word_reverse_endian_32_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**32
word_bytes=word.to_bytes(4, byteorder='big')
for i in range(4):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 4, memory, ap)
%}
ap += 4;

Expand Down Expand Up @@ -579,11 +563,8 @@ func word_reverse_endian_32_RC{range_check_ptr}(word: felt) -> felt {
// res: the byte-reversed integer.
func word_reverse_endian_40_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**40
word_bytes=word.to_bytes(5, byteorder='big')
for i in range(5):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 5, memory, ap)
%}
ap += 5;

Expand Down Expand Up @@ -618,11 +599,8 @@ func word_reverse_endian_40_RC{range_check_ptr}(word: felt) -> felt {
// res: the byte-reversed integer.
func word_reverse_endian_48_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**48
word_bytes=word.to_bytes(6, byteorder='big')
for i in range(6):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 6, memory, ap)
%}
ap += 6;

Expand Down Expand Up @@ -660,11 +638,8 @@ func word_reverse_endian_48_RC{range_check_ptr}(word: felt) -> felt {
// res: the byte-reversed integer.
func word_reverse_endian_56_RC{range_check_ptr}(word: felt) -> felt {
%{
word = ids.word
assert word < 2**56
word_bytes=word.to_bytes(7, byteorder='big')
for i in range(7):
memory[ap+i] = word_bytes[i]
from tools.py.hints import write_word_to_memory
write_word_to_memory(ids.word, 7, memory, ap)
%}
ap += 7;

Expand Down
30 changes: 30 additions & 0 deletions tools/py/hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
def is_single_byte(prefix):
"""Check if the prefix indicates a single byte (0x00 to 0x7f)."""
return 0x00 <= prefix <= 0x7F


def is_short_string(prefix):
"""Check if the prefix indicates a short string (0x80 to 0xb7)."""
return 0x80 <= prefix <= 0xB7


def is_long_string(prefix):
"""Check if the prefix indicates a long string (0xb8 to 0xbf)."""
return 0xB8 <= prefix <= 0xBF


def is_short_list(prefix):
"""Check if the prefix indicates a short list (0xc0 to 0xf7)."""
return 0xC0 <= prefix <= 0xF7


def is_long_list(prefix):
"""Check if the prefix indicates a long list (0xf8 to 0xff)."""
return 0xF8 <= prefix <= 0xFF


def write_word_to_memory(word: int, n: int, memory, ap) -> None:
assert word < 2 ** (8 * n), f"Word value {word} exceeds {8 * n} bits."
word_bytes = word.to_bytes(n, byteorder="big")
for i in range(n):
memory[ap + i] = word_bytes[i]
Loading