Skip to content

Commit

Permalink
feat: rust vm hints (#412)
Browse files Browse the repository at this point in the history
Adds most common hints to run the tests in the Rust VM.
  • Loading branch information
enitrat authored Jan 16, 2025
1 parent 5ae70e7 commit 096c1b6
Show file tree
Hide file tree
Showing 34 changed files with 914 additions and 238 deletions.
7 changes: 4 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ module_name_repetitions = "allow"
[workspace.dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
proptest = "1.0"
arbitrary = "1.3"
thiserror = "1.0"
url = "2.5"
reqwest = { version = "0.12", features = ["json", "multipart"] }
Expand All @@ -59,4 +57,4 @@ cairo-vm = { git = "https://github.com/lambdaclass/cairo-vm.git", tag = "v2.0.0-
] }

[patch."https://github.com/lambdaclass/cairo-vm.git"]
cairo-vm = { git = "https://github.com/kkrt-labs/cairo-vm", rev = "dd3d3f6e76248fc02395b31cbefe5fe8183222f1" }
cairo-vm = { git = "https://github.com/kkrt-labs/cairo-vm", rev = "11ecc932cfabf0653a008a7f408d698ec780cc24" }
9 changes: 3 additions & 6 deletions cairo/ethereum/cancun/state.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ from ethereum.cancun.trie import (
from ethereum_types.bytes import Bytes, Bytes32
from ethereum_types.numeric import U256, U256Struct, Bool, bool

from src.utils.dict import hashdict_read, hashdict_write, hashdict_get
from src.utils.dict import hashdict_read, hashdict_write, hashdict_get, dict_new_empty

struct AddressTrieBytes32U256DictAccess {
key: Address,
Expand Down Expand Up @@ -262,9 +262,7 @@ func set_storage{poseidon_ptr: PoseidonBuiltin*, state: State}(
}(1, &address.value);

if (storage_trie_pointer == 0) {
// dict_new expects an initial_dict hint argument.
%{ initial_dict = {} %}
let (new_mapping_dict_ptr) = dict_new();
let (new_mapping_dict_ptr) = dict_new_empty();
tempvar new_storage_trie = new TrieBytes32U256Struct(
secured=bool(1),
default=U256(new U256Struct(0, 0)),
Expand Down Expand Up @@ -540,8 +538,7 @@ func set_transient_storage{poseidon_ptr: PoseidonBuiltin*, transient_storage: Tr
let (trie_ptr) = hashdict_get{dict_ptr=transient_storage_tries_dict_ptr}(1, &address.value);

if (trie_ptr == 0) {
%{ initial_dict = {} %}
let (empty_dict) = dict_new();
let (empty_dict) = dict_new_empty();
tempvar new_trie = new TrieBytes32U256Struct(
secured=Bool(1),
default=U256(new U256Struct(0, 0)),
Expand Down
57 changes: 19 additions & 38 deletions cairo/ethereum/cancun/trie.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ from starkware.cairo.common.cairo_builtins import KeccakBuiltin
from starkware.cairo.common.memcpy import memcpy

from src.utils.bytes import uint256_to_bytes32_little
from src.utils.dict import hashdict_read, hashdict_write
from src.utils.dict import hashdict_read, hashdict_write, dict_new_empty
from ethereum.crypto.hash import keccak256
from ethereum.utils.numeric import min, is_zero
from ethereum.rlp import encode, _encode_bytes, _encode
Expand Down Expand Up @@ -351,11 +351,7 @@ func copy_TrieAddressOptionalAccount{range_check_ptr, trie: TrieAddressOptionalA

local new_dict_ptr: AddressAccountDictAccess*;
tempvar original_mapping = trie.value._data.value;
%{
dict_tracker = __dict_manager.get_tracker(ids.original_mapping.dict_ptr)
copied_data = dict_tracker.data
ids.new_dict_ptr = __dict_manager.new_dict(segments, copied_data)
%}
%{ copy_dict_segment %}

tempvar res = TrieAddressOptionalAccount(
new TrieAddressOptionalAccountStruct(
Expand All @@ -375,13 +371,7 @@ func copy_trieBytes32U256{range_check_ptr, trie: TrieBytes32U256}() -> TrieBytes

local new_dict_ptr: Bytes32U256DictAccess*;
tempvar original_mapping = trie.value._data.value;
%{
from starkware.cairo.lang.vm.crypto import poseidon_hash_many
dict_tracker = __dict_manager.get_tracker(ids.original_mapping.dict_ptr)
copied_data = dict_tracker.data
ids.new_dict_ptr = __dict_manager.new_dict(segments, copied_data)
%}
%{ copy_dict_segment %}

tempvar res = TrieBytes32U256(
new TrieBytes32U256Struct(
Expand Down Expand Up @@ -725,7 +715,7 @@ func _search_common_prefix_length{
return current_length;
}

let preimage = _get_preimage_for_key(obj, dict_ptr_stop);
let preimage = _get_preimage_for_key(obj.key.value, dict_ptr_stop);
tempvar sliced_key = Bytes(
new BytesStruct(preimage.value.data + level.value, preimage.value.len - level.value)
);
Expand All @@ -751,7 +741,7 @@ func _get_branch_for_nibble_at_level_inner{poseidon_ptr: PoseidonBuiltin*}(
return (branch_ptr, value);
}

let preimage = _get_preimage_for_key(dict_ptr, dict_ptr_stop);
let preimage = _get_preimage_for_key(dict_ptr.key.value, dict_ptr_stop);

// Check cases
let is_value_case = is_zero(preimage.value.len - level);
Expand All @@ -774,14 +764,11 @@ func _get_branch_for_nibble_at_level_inner{poseidon_ptr: PoseidonBuiltin*}(
assert [branch_ptr].prev_value = dict_ptr.prev_value;
assert [branch_ptr].new_value = dict_ptr.new_value;
// Add an entry to the dict_tracker
%{
obj_tracker = __dict_manager.get_tracker(ids.dict_ptr_stop.address_)
dict_tracker = __dict_manager.get_tracker(ids.branch_ptr.address_)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = next(key for key in obj_tracker.data.keys() if poseidon_hash_many(key) == ids.dict_ptr.key.value)
dict_tracker.data[preimage] = obj_tracker.data[preimage]
%}
// Copy the entry from the dict_ptr's tracker to the branch_ptr's tracker
let source_key = dict_ptr.key.value;
let source_ptr_stop = dict_ptr_stop;
let dest_ptr = branch_ptr;
%{ copy_hashdict_tracker_entry %}

return _get_branch_for_nibble_at_level_inner(
dict_ptr + BytesBytesDictAccess.SIZE,
Expand Down Expand Up @@ -824,8 +811,7 @@ func _get_branch_for_nibble_at_level{poseidon_ptr: PoseidonBuiltin*}(
alloc_locals;
// Allocate a segment for the branch and register an associated tracker
// dict_new expectes an initial_dict hint argument.
%{ initial_dict = {} %}
let (branch_start_: DictAccess*) = dict_new();
let (branch_start_: DictAccess*) = dict_new_empty();
let branch_start = cast(branch_start_, BytesBytesDictAccess*);
let dict_ptr_stop = obj.value.dict_ptr;

Expand Down Expand Up @@ -965,7 +951,7 @@ func _get_branches{poseidon_ptr: PoseidonBuiltin*}(obj: MappingBytesBytes, level
assert value = value_15;
assert value_set = 1;
}
%{ ids.value_set = memory.get(fp + 2) or 0 %}
%{ fp_plus_2_or_0 %}
if (value_set != 1) {
let (data: felt*) = alloc();
tempvar empty_bytes = Bytes(new BytesStruct(data, 0));
Expand All @@ -978,28 +964,22 @@ func _get_branches{poseidon_ptr: PoseidonBuiltin*}(obj: MappingBytesBytes, level

// @notice Given a key (inside `dict_ptr`), returns the preimage of the key registered in the tracker.
// The preimage is validated to be correctly provided by the prover by hashing it and comparing it to the key.
// @param key - The key to get the preimage for. Either a hashed or non-hashed key - but it must be a felt.
// @param dict_ptr_stop - The pointer to the end of the dict segment, the one registered in the tracker.
func _get_preimage_for_key{poseidon_ptr: PoseidonBuiltin*}(
dict_ptr: BytesBytesDictAccess*, dict_ptr_stop: BytesBytesDictAccess*
key: felt, dict_ptr_stop: BytesBytesDictAccess*
) -> Bytes {
alloc_locals;

// Get preimage data
let (local preimage_data: felt*) = alloc();
local preimage_len;
%{
from starkware.cairo.lang.vm.crypto import poseidon_hash_many
hashed_value = ids.dict_ptr.key.value
dict_tracker = __dict_manager.get_tracker(ids.dict_ptr_stop)
# Get the key in the dict that matches the hashed value
preimage = bytes(next(key for key in dict_tracker.data.keys() if poseidon_hash_many(key) == hashed_value))
segments.write_arg(ids.preimage_data, preimage)
ids.preimage_len = len(preimage)
%}
%{ get_preimage_for_key %}

// Verify preimage
let (preimage_hash) = poseidon_hash_many(preimage_len, preimage_data);
with_attr error_message("preimage_hash != key") {
assert preimage_hash = dict_ptr.key.value;
assert preimage_hash = key;
}

tempvar res = Bytes(new BytesStruct(preimage_data, preimage_len));
Expand All @@ -1024,7 +1004,8 @@ func patricialize{
}

let arbitrary_value = obj.value.dict_ptr_start.new_value;
let preimage = _get_preimage_for_key(obj.value.dict_ptr_start, obj.value.dict_ptr);
let current_key = obj.value.dict_ptr_start.key.value;
let preimage = _get_preimage_for_key(current_key, obj.value.dict_ptr);

// if leaf node
if (len == 1) {
Expand Down
13 changes: 1 addition & 12 deletions cairo/ethereum/utils/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,7 @@ func Bytes__eq__(_self: Bytes, other: Bytes) -> bool {
// return the first different byte index, and assert in cairo that the a[index] != b[index]
tempvar is_diff;
tempvar diff_index;
%{
self_bytes = b''.join([memory[ids._self.value.data + i].to_bytes(1, "little") for i in range(ids._self.value.len)])
other_bytes = b''.join([memory[ids.other.value.data + i].to_bytes(1, "little") for i in range(ids.other.value.len)])
diff_index = next((i for i, (b_self, b_other) in enumerate(zip(self_bytes, other_bytes)) if b_self != b_other), None)
if diff_index is not None:
ids.is_diff = 1
ids.diff_index = diff_index
else:
# No differences found in common prefix. Lengths were checked before
ids.is_diff = 0
ids.diff_index = 0
%}
%{ Bytes__eq__ %}

if (is_diff == 1) {
// Assert that the bytes are different at the first different index
Expand Down
2 changes: 1 addition & 1 deletion cairo/ethereum/utils/numeric.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ func min{range_check_ptr}(a: felt, b: felt) -> felt {
alloc_locals;

tempvar is_min_b;
%{ memory[ap - 1] = 1 if ids.b <= ids.a else 0 %}
%{ b_le_a %}
jmp min_is_b if is_min_b != 0;

min_is_a:
Expand Down
39 changes: 10 additions & 29 deletions cairo/src/utils/dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@ from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.squash_dict import squash_dict
from starkware.cairo.common.uint256 import Uint256

from ethereum_types.numeric import U256, U256Struct
from ethereum_types.bytes import Bytes32
from ethereum.utils.numeric import U256__eq__
from ethereum.cancun.fork_types import Address, Account, AccountStruct, Account__eq__

from src.utils.maths import unsigned_div_rem

// @ notice: Creates a new, empty dict, does not require an `initial_dict` argument.
func dict_new_empty() -> (res: DictAccess*) {
%{ dict_new_empty %}
ap += 1;
return (res=cast([ap - 1], DictAccess*));
}

func dict_copy{range_check_ptr}(dict_start: DictAccess*, dict_end: DictAccess*) -> (
DictAccess*, DictAccess*
) {
Expand Down Expand Up @@ -69,13 +75,7 @@ func hashdict_read{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
}

local value;
%{
dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)])
# Not using [] here because it will register the value for that key in the tracker.
ids.value = dict_tracker.data.get(preimage, dict_tracker.data.default_factory())
%}
%{ hashdict_read %}
dict_ptr.key = felt_key;
dict_ptr.prev_value = value;
dict_ptr.new_value = value;
Expand Down Expand Up @@ -104,16 +104,7 @@ func hashdict_get{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
}

local value;
%{
from collections import defaultdict
dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)])
if isinstance(dict_tracker.data, defaultdict):
ids.value = dict_tracker.data[preimage]
else:
ids.value = dict_tracker.data.get(preimage, 0)
%}
%{ hashdict_get %}
dict_ptr.key = felt_key;
dict_ptr.prev_value = value;
dict_ptr.new_value = value;
Expand All @@ -139,17 +130,7 @@ func hashdict_write{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
assert felt_key = felt_key_;
tempvar poseidon_ptr = poseidon_ptr;
}
%{
from collections import defaultdict
dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)])
if isinstance(dict_tracker.data, defaultdict):
ids.dict_ptr.prev_value = dict_tracker.data[preimage]
else:
ids.dict_ptr.prev_value = 0
dict_tracker.data[preimage] = ids.new_value
%}
%{ hashdict_write %}
dict_ptr.key = felt_key;
dict_ptr.new_value = new_value;
let dict_ptr = dict_ptr + DictAccess.SIZE;
Expand Down
3 changes: 0 additions & 3 deletions cairo/tests/ethereum/cancun/test_fork_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import pytest
from hypothesis import given

from ethereum.cancun.fork_types import EMPTY_ACCOUNT, Account

pytestmark = pytest.mark.python_vm


class TestForkTypes:
def test_account_default(self, cairo_run):
Expand Down
2 changes: 0 additions & 2 deletions cairo/tests/ethereum/cancun/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
from tests.utils.args_gen import TransientStorage
from tests.utils.strategies import address, bytes32, state, transient_storage

pytestmark = pytest.mark.python_vm


@composite
def state_and_address_and_optional_key(
Expand Down
Loading

0 comments on commit 096c1b6

Please sign in to comment.