diff --git a/Cargo.lock b/Cargo.lock index 8da73d59..bcb9af83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -145,7 +145,7 @@ checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "cairo-vm" version = "2.0.0-rc3" -source = "git+https://github.com/kkrt-labs/cairo-vm?rev=dd3d3f6e76248fc02395b31cbefe5fe8183222f1#dd3d3f6e76248fc02395b31cbefe5fe8183222f1" +source = "git+https://github.com/kkrt-labs/cairo-vm?rev=11ecc932cfabf0653a008a7f408d698ec780cc24#11ecc932cfabf0653a008a7f408d698ec780cc24" dependencies = [ "anyhow", "arbitrary", @@ -183,6 +183,7 @@ dependencies = [ "num-traits", "pyo3", "pyo3-build-config", + "starknet-crypto", ] [[package]] @@ -1505,9 +1506,9 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "starknet-crypto" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ded22ccf4cb9e572ce3f77de6066af53560cd2520d508876c83bb1e6b29d5cbc" +checksum = "039a3bad70806b494c9e6b21c5238a6c8a373d66a26071859deb0ccca6f93634" dependencies = [ "crypto-bigint", "hex", diff --git a/Cargo.toml b/Cargo.toml index 7d9ac5d6..e1fff7d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,8 +47,6 @@ module_name_repetitions = "allow" [workspace.dependencies] serde = { version = "1.0", features = ["derive"] } serde_json = "1" -proptest = "1.0" -arbitrary = "1.3" thiserror = "1.0" url = "2.5" reqwest = { version = "0.12", features = ["json", "multipart"] } @@ -59,4 +57,4 @@ cairo-vm = { git = "https://github.com/lambdaclass/cairo-vm.git", tag = "v2.0.0- ] } [patch."https://github.com/lambdaclass/cairo-vm.git"] -cairo-vm = { git = "https://github.com/kkrt-labs/cairo-vm", rev = "dd3d3f6e76248fc02395b31cbefe5fe8183222f1" } +cairo-vm = { git = "https://github.com/kkrt-labs/cairo-vm", rev = "11ecc932cfabf0653a008a7f408d698ec780cc24" } diff --git a/cairo/ethereum/cancun/state.cairo b/cairo/ethereum/cancun/state.cairo index 82be5050..3983df73 100644 --- a/cairo/ethereum/cancun/state.cairo +++ b/cairo/ethereum/cancun/state.cairo @@ -32,7 +32,7 @@ from ethereum.cancun.trie import ( from ethereum_types.bytes import Bytes, Bytes32 from ethereum_types.numeric import U256, U256Struct, Bool, bool -from src.utils.dict import hashdict_read, hashdict_write, hashdict_get +from src.utils.dict import hashdict_read, hashdict_write, hashdict_get, dict_new_empty struct AddressTrieBytes32U256DictAccess { key: Address, @@ -262,9 +262,7 @@ func set_storage{poseidon_ptr: PoseidonBuiltin*, state: State}( }(1, &address.value); if (storage_trie_pointer == 0) { - // dict_new expects an initial_dict hint argument. - %{ initial_dict = {} %} - let (new_mapping_dict_ptr) = dict_new(); + let (new_mapping_dict_ptr) = dict_new_empty(); tempvar new_storage_trie = new TrieBytes32U256Struct( secured=bool(1), default=U256(new U256Struct(0, 0)), @@ -540,8 +538,7 @@ func set_transient_storage{poseidon_ptr: PoseidonBuiltin*, transient_storage: Tr let (trie_ptr) = hashdict_get{dict_ptr=transient_storage_tries_dict_ptr}(1, &address.value); if (trie_ptr == 0) { - %{ initial_dict = {} %} - let (empty_dict) = dict_new(); + let (empty_dict) = dict_new_empty(); tempvar new_trie = new TrieBytes32U256Struct( secured=Bool(1), default=U256(new U256Struct(0, 0)), diff --git a/cairo/ethereum/cancun/trie.cairo b/cairo/ethereum/cancun/trie.cairo index 395f952c..99ab22ca 100644 --- a/cairo/ethereum/cancun/trie.cairo +++ b/cairo/ethereum/cancun/trie.cairo @@ -10,7 +10,7 @@ from starkware.cairo.common.cairo_builtins import KeccakBuiltin from starkware.cairo.common.memcpy import memcpy from src.utils.bytes import uint256_to_bytes32_little -from src.utils.dict import hashdict_read, hashdict_write +from src.utils.dict import hashdict_read, hashdict_write, dict_new_empty from ethereum.crypto.hash import keccak256 from ethereum.utils.numeric import min, is_zero from ethereum.rlp import encode, _encode_bytes, _encode @@ -351,11 +351,7 @@ func copy_TrieAddressOptionalAccount{range_check_ptr, trie: TrieAddressOptionalA local new_dict_ptr: AddressAccountDictAccess*; tempvar original_mapping = trie.value._data.value; - %{ - dict_tracker = __dict_manager.get_tracker(ids.original_mapping.dict_ptr) - copied_data = dict_tracker.data - ids.new_dict_ptr = __dict_manager.new_dict(segments, copied_data) - %} + %{ copy_dict_segment %} tempvar res = TrieAddressOptionalAccount( new TrieAddressOptionalAccountStruct( @@ -375,13 +371,7 @@ func copy_trieBytes32U256{range_check_ptr, trie: TrieBytes32U256}() -> TrieBytes local new_dict_ptr: Bytes32U256DictAccess*; tempvar original_mapping = trie.value._data.value; - %{ - from starkware.cairo.lang.vm.crypto import poseidon_hash_many - - dict_tracker = __dict_manager.get_tracker(ids.original_mapping.dict_ptr) - copied_data = dict_tracker.data - ids.new_dict_ptr = __dict_manager.new_dict(segments, copied_data) - %} + %{ copy_dict_segment %} tempvar res = TrieBytes32U256( new TrieBytes32U256Struct( @@ -725,7 +715,7 @@ func _search_common_prefix_length{ return current_length; } - let preimage = _get_preimage_for_key(obj, dict_ptr_stop); + let preimage = _get_preimage_for_key(obj.key.value, dict_ptr_stop); tempvar sliced_key = Bytes( new BytesStruct(preimage.value.data + level.value, preimage.value.len - level.value) ); @@ -751,7 +741,7 @@ func _get_branch_for_nibble_at_level_inner{poseidon_ptr: PoseidonBuiltin*}( return (branch_ptr, value); } - let preimage = _get_preimage_for_key(dict_ptr, dict_ptr_stop); + let preimage = _get_preimage_for_key(dict_ptr.key.value, dict_ptr_stop); // Check cases let is_value_case = is_zero(preimage.value.len - level); @@ -774,14 +764,11 @@ func _get_branch_for_nibble_at_level_inner{poseidon_ptr: PoseidonBuiltin*}( assert [branch_ptr].prev_value = dict_ptr.prev_value; assert [branch_ptr].new_value = dict_ptr.new_value; - // Add an entry to the dict_tracker - %{ - obj_tracker = __dict_manager.get_tracker(ids.dict_ptr_stop.address_) - dict_tracker = __dict_manager.get_tracker(ids.branch_ptr.address_) - dict_tracker.current_ptr += ids.DictAccess.SIZE - preimage = next(key for key in obj_tracker.data.keys() if poseidon_hash_many(key) == ids.dict_ptr.key.value) - dict_tracker.data[preimage] = obj_tracker.data[preimage] - %} + // Copy the entry from the dict_ptr's tracker to the branch_ptr's tracker + let source_key = dict_ptr.key.value; + let source_ptr_stop = dict_ptr_stop; + let dest_ptr = branch_ptr; + %{ copy_hashdict_tracker_entry %} return _get_branch_for_nibble_at_level_inner( dict_ptr + BytesBytesDictAccess.SIZE, @@ -824,8 +811,7 @@ func _get_branch_for_nibble_at_level{poseidon_ptr: PoseidonBuiltin*}( alloc_locals; // Allocate a segment for the branch and register an associated tracker // dict_new expectes an initial_dict hint argument. - %{ initial_dict = {} %} - let (branch_start_: DictAccess*) = dict_new(); + let (branch_start_: DictAccess*) = dict_new_empty(); let branch_start = cast(branch_start_, BytesBytesDictAccess*); let dict_ptr_stop = obj.value.dict_ptr; @@ -965,7 +951,7 @@ func _get_branches{poseidon_ptr: PoseidonBuiltin*}(obj: MappingBytesBytes, level assert value = value_15; assert value_set = 1; } - %{ ids.value_set = memory.get(fp + 2) or 0 %} + %{ fp_plus_2_or_0 %} if (value_set != 1) { let (data: felt*) = alloc(); tempvar empty_bytes = Bytes(new BytesStruct(data, 0)); @@ -978,28 +964,22 @@ func _get_branches{poseidon_ptr: PoseidonBuiltin*}(obj: MappingBytesBytes, level // @notice Given a key (inside `dict_ptr`), returns the preimage of the key registered in the tracker. // The preimage is validated to be correctly provided by the prover by hashing it and comparing it to the key. +// @param key - The key to get the preimage for. Either a hashed or non-hashed key - but it must be a felt. +// @param dict_ptr_stop - The pointer to the end of the dict segment, the one registered in the tracker. func _get_preimage_for_key{poseidon_ptr: PoseidonBuiltin*}( - dict_ptr: BytesBytesDictAccess*, dict_ptr_stop: BytesBytesDictAccess* + key: felt, dict_ptr_stop: BytesBytesDictAccess* ) -> Bytes { alloc_locals; // Get preimage data let (local preimage_data: felt*) = alloc(); local preimage_len; - %{ - from starkware.cairo.lang.vm.crypto import poseidon_hash_many - hashed_value = ids.dict_ptr.key.value - dict_tracker = __dict_manager.get_tracker(ids.dict_ptr_stop) - # Get the key in the dict that matches the hashed value - preimage = bytes(next(key for key in dict_tracker.data.keys() if poseidon_hash_many(key) == hashed_value)) - segments.write_arg(ids.preimage_data, preimage) - ids.preimage_len = len(preimage) - %} + %{ get_preimage_for_key %} // Verify preimage let (preimage_hash) = poseidon_hash_many(preimage_len, preimage_data); with_attr error_message("preimage_hash != key") { - assert preimage_hash = dict_ptr.key.value; + assert preimage_hash = key; } tempvar res = Bytes(new BytesStruct(preimage_data, preimage_len)); @@ -1024,7 +1004,8 @@ func patricialize{ } let arbitrary_value = obj.value.dict_ptr_start.new_value; - let preimage = _get_preimage_for_key(obj.value.dict_ptr_start, obj.value.dict_ptr); + let current_key = obj.value.dict_ptr_start.key.value; + let preimage = _get_preimage_for_key(current_key, obj.value.dict_ptr); // if leaf node if (len == 1) { diff --git a/cairo/ethereum/utils/bytes.cairo b/cairo/ethereum/utils/bytes.cairo index 43598978..b7461b4b 100644 --- a/cairo/ethereum/utils/bytes.cairo +++ b/cairo/ethereum/utils/bytes.cairo @@ -13,18 +13,7 @@ func Bytes__eq__(_self: Bytes, other: Bytes) -> bool { // return the first different byte index, and assert in cairo that the a[index] != b[index] tempvar is_diff; tempvar diff_index; - %{ - self_bytes = b''.join([memory[ids._self.value.data + i].to_bytes(1, "little") for i in range(ids._self.value.len)]) - other_bytes = b''.join([memory[ids.other.value.data + i].to_bytes(1, "little") for i in range(ids.other.value.len)]) - diff_index = next((i for i, (b_self, b_other) in enumerate(zip(self_bytes, other_bytes)) if b_self != b_other), None) - if diff_index is not None: - ids.is_diff = 1 - ids.diff_index = diff_index - else: - # No differences found in common prefix. Lengths were checked before - ids.is_diff = 0 - ids.diff_index = 0 - %} + %{ Bytes__eq__ %} if (is_diff == 1) { // Assert that the bytes are different at the first different index diff --git a/cairo/ethereum/utils/numeric.cairo b/cairo/ethereum/utils/numeric.cairo index 953ba4d3..b1a7d5b7 100644 --- a/cairo/ethereum/utils/numeric.cairo +++ b/cairo/ethereum/utils/numeric.cairo @@ -8,7 +8,7 @@ func min{range_check_ptr}(a: felt, b: felt) -> felt { alloc_locals; tempvar is_min_b; - %{ memory[ap - 1] = 1 if ids.b <= ids.a else 0 %} + %{ b_le_a %} jmp min_is_b if is_min_b != 0; min_is_a: diff --git a/cairo/src/utils/dict.cairo b/cairo/src/utils/dict.cairo index e295ee8c..7d411d9c 100644 --- a/cairo/src/utils/dict.cairo +++ b/cairo/src/utils/dict.cairo @@ -5,7 +5,6 @@ from starkware.cairo.common.dict_access import DictAccess from starkware.cairo.common.memcpy import memcpy from starkware.cairo.common.squash_dict import squash_dict from starkware.cairo.common.uint256 import Uint256 - from ethereum_types.numeric import U256, U256Struct from ethereum_types.bytes import Bytes32 from ethereum.utils.numeric import U256__eq__ @@ -13,6 +12,13 @@ from ethereum.cancun.fork_types import Address, Account, AccountStruct, Account_ from src.utils.maths import unsigned_div_rem +// @ notice: Creates a new, empty dict, does not require an `initial_dict` argument. +func dict_new_empty() -> (res: DictAccess*) { + %{ dict_new_empty %} + ap += 1; + return (res=cast([ap - 1], DictAccess*)); +} + func dict_copy{range_check_ptr}(dict_start: DictAccess*, dict_end: DictAccess*) -> ( DictAccess*, DictAccess* ) { @@ -69,13 +75,7 @@ func hashdict_read{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}( } local value; - %{ - dict_tracker = __dict_manager.get_tracker(ids.dict_ptr) - dict_tracker.current_ptr += ids.DictAccess.SIZE - preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)]) - # Not using [] here because it will register the value for that key in the tracker. - ids.value = dict_tracker.data.get(preimage, dict_tracker.data.default_factory()) - %} + %{ hashdict_read %} dict_ptr.key = felt_key; dict_ptr.prev_value = value; dict_ptr.new_value = value; @@ -104,16 +104,7 @@ func hashdict_get{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}( } local value; - %{ - from collections import defaultdict - dict_tracker = __dict_manager.get_tracker(ids.dict_ptr) - dict_tracker.current_ptr += ids.DictAccess.SIZE - preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)]) - if isinstance(dict_tracker.data, defaultdict): - ids.value = dict_tracker.data[preimage] - else: - ids.value = dict_tracker.data.get(preimage, 0) - %} + %{ hashdict_get %} dict_ptr.key = felt_key; dict_ptr.prev_value = value; dict_ptr.new_value = value; @@ -139,17 +130,7 @@ func hashdict_write{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}( assert felt_key = felt_key_; tempvar poseidon_ptr = poseidon_ptr; } - %{ - from collections import defaultdict - dict_tracker = __dict_manager.get_tracker(ids.dict_ptr) - dict_tracker.current_ptr += ids.DictAccess.SIZE - preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)]) - if isinstance(dict_tracker.data, defaultdict): - ids.dict_ptr.prev_value = dict_tracker.data[preimage] - else: - ids.dict_ptr.prev_value = 0 - dict_tracker.data[preimage] = ids.new_value - %} + %{ hashdict_write %} dict_ptr.key = felt_key; dict_ptr.new_value = new_value; let dict_ptr = dict_ptr + DictAccess.SIZE; diff --git a/cairo/tests/ethereum/cancun/test_fork_types.py b/cairo/tests/ethereum/cancun/test_fork_types.py index 8b66d928..e273514d 100644 --- a/cairo/tests/ethereum/cancun/test_fork_types.py +++ b/cairo/tests/ethereum/cancun/test_fork_types.py @@ -1,10 +1,7 @@ -import pytest from hypothesis import given from ethereum.cancun.fork_types import EMPTY_ACCOUNT, Account -pytestmark = pytest.mark.python_vm - class TestForkTypes: def test_account_default(self, cairo_run): diff --git a/cairo/tests/ethereum/cancun/test_state.py b/cairo/tests/ethereum/cancun/test_state.py index 057b36ac..ed851d0f 100644 --- a/cairo/tests/ethereum/cancun/test_state.py +++ b/cairo/tests/ethereum/cancun/test_state.py @@ -29,8 +29,6 @@ from tests.utils.args_gen import TransientStorage from tests.utils.strategies import address, bytes32, state, transient_storage -pytestmark = pytest.mark.python_vm - @composite def state_and_address_and_optional_key( diff --git a/cairo/tests/ethereum/cancun/test_trie.py b/cairo/tests/ethereum/cancun/test_trie.py index caca4aa8..51e47951 100644 --- a/cairo/tests/ethereum/cancun/test_trie.py +++ b/cairo/tests/ethereum/cancun/test_trie.py @@ -26,8 +26,6 @@ from tests.utils.hints import patch_hint from tests.utils.strategies import bytes32, nibble, uint4 -pytestmark = pytest.mark.python_vm - class TestTrie: @given(node=...) @@ -36,8 +34,6 @@ def test_encode_internal_node(self, cairo_run, node: Optional[InternalNode]): encode_internal_node(node), cairo_run("encode_internal_node", node) ) - @pytest.mark.slow - @settings(max_examples=20) # for max_examples=2, it takes 129.91s in local @given(node=..., storage_root=...) def test_encode_node(self, cairo_run, node: Node, storage_root: Optional[Bytes]): assume(node is not None) @@ -49,27 +45,16 @@ def test_encode_node(self, cairo_run, node: Node, storage_root: Optional[Bytes]) @given(node=...) def test_encode_account_should_fail_without_storage_root( - self, cairo_run, node: Account + self, cairo_run_py, node: Account ): with pytest.raises(AssertionError): encode_node(node, None) with cairo_error(message="encode_node"): - cairo_run("encode_node", node, None) - - # def test_copy_trie(self, cairo_run, trie): - # assert copy_trie(trie) == cairo_run("copy_trie", trie) - - # @given(key=..., value=...) - # def test_trie_set(self, cairo_run, key: K, value: V): - # assert trie_set(trie, key, value) == cairo_run("trie_set", trie, key, value) - - # @given(key=...) - # def test_trie_get(self, cairo_run, key: K): - # assert trie_get(trie, key) == cairo_run("trie_get", trie, key) + cairo_run_py("encode_node", node, None) @given(a=..., b=...) - def test_common_prefix_length(self, cairo_run, a: Bytes, b: Bytes): - assert common_prefix_length(a, b) == cairo_run("common_prefix_length", a, b) + def test_common_prefix_length(self, cairo_run_py, a: Bytes, b: Bytes): + assert common_prefix_length(a, b) == cairo_run_py("common_prefix_length", a, b) @given(a=..., b=...) def test_common_prefix_length_should_fail( @@ -93,7 +78,7 @@ def test_nibble_list_to_compact(self, cairo_run, x, is_leaf: bool): @given(x=nibble.filter(lambda x: len(x) != 0), is_leaf=...) def test_nibble_list_to_compact_should_raise_when_wrong_remainder( - self, cairo_program, cairo_run, x, is_leaf: bool + self, cairo_program, cairo_run_py, x, is_leaf: bool ): with ( patch_hint( @@ -103,11 +88,14 @@ def test_nibble_list_to_compact_should_raise_when_wrong_remainder( ), cairo_error(message="nibble_list_to_compact: invalid remainder"), ): - cairo_run("nibble_list_to_compact", x, is_leaf) + # Always run patch_hint tests with the python VM + cairo_run_py("nibble_list_to_compact", x, is_leaf) @given(bytes_=...) - def test_bytes_to_nibble_list(self, cairo_run, bytes_: Bytes): - assert bytes_to_nibble_list(bytes_) == cairo_run("bytes_to_nibble_list", bytes_) + def test_bytes_to_nibble_list(self, cairo_run_py, bytes_: Bytes): + assert bytes_to_nibble_list(bytes_) == cairo_run_py( + "bytes_to_nibble_list", bytes_ + ) # def test_root(self, cairo_run, trie, get_storage_root): # assert root(trie, get_storage_root) == cairo_run("root", trie, get_storage_root) @@ -147,66 +135,72 @@ def test_get_branches(self, cairo_run, obj, level): assert value == obj.get(level, b"") @pytest.mark.slow - @settings(max_examples=5) # for max_examples=2, it takes 239.03s in local - @given(obj=st.dictionaries(nibble, bytes32)) - def test_patricialize(self, cairo_run, obj: Mapping[Bytes, Bytes]): - assert patricialize(obj, Uint(0)) == cairo_run("patricialize", obj, Uint(0)) + @settings(max_examples=20) + @given(obj=st.dictionaries(nibble, bytes32, max_size=100)) + def test_patricialize(self, cairo_run_py, obj: Mapping[Bytes, Bytes]): + assert patricialize(obj, Uint(0)) == cairo_run_py("patricialize", obj, Uint(0)) class TestTrieOperations: - @given(trie=..., key=...) - def test_trie_get_TrieAddressOptionalAccount( - self, cairo_run, trie: Trie[Address, Optional[Account]], key: Address - ): - [trie_cairo, result_cairo] = cairo_run( - "trie_get_TrieAddressOptionalAccount", trie, key - ) - result_py = trie_get(trie, key) - assert result_cairo == result_py - assert trie_cairo == trie - - @given(trie=..., key=...) - def test_trie_get_TrieBytes32U256( - self, cairo_run, trie: Trie[Bytes32, U256], key: Bytes32 - ): - [trie_cairo, result_cairo] = cairo_run("trie_get_TrieBytes32U256", trie, key) - result_py = trie_get(trie, key) - assert result_cairo == result_py - assert trie_cairo == trie - - @given(trie=..., key=..., value=...) - def test_trie_set_TrieAddressOptionalAccount( - self, - cairo_run, - trie: Trie[Address, Optional[Account]], - key: Address, - value: Account, - ): - cairo_trie = cairo_run("trie_set_TrieAddressOptionalAccount", trie, key, value) - trie_set(trie, key, value) - assert cairo_trie == trie - - @given(trie=..., key=..., value=...) - def test_trie_set_TrieBytes32U256( - self, cairo_run, trie: Trie[Bytes32, U256], key: Bytes32, value: U256 - ): - cairo_trie = cairo_run("trie_set_TrieBytes32U256", trie, key, value) - trie_set(trie, key, value) - assert cairo_trie == trie - - @given(trie=...) - def test_copy_trie_AddressAccount( - self, cairo_run, trie: Trie[Address, Optional[Account]] - ): - [original_trie, copied_trie] = cairo_run( - "copy_TrieAddressOptionalAccount", trie - ) - trie_copy_py = copy_trie(trie) - assert original_trie == trie - assert copied_trie == trie_copy_py - - @given(trie=...) - def test_copy_trie_Bytes32U256(self, cairo_run, trie: Trie[Bytes32, U256]): - [original_trie, copied_trie] = cairo_run("copy_trieBytes32U256", trie) - copy_trie(trie) - assert original_trie == trie + class TestGet: + @given(trie=..., key=...) + def test_trie_get_TrieAddressOptionalAccount( + self, cairo_run, trie: Trie[Address, Optional[Account]], key: Address + ): + trie_cairo, result_cairo = cairo_run( + "trie_get_TrieAddressOptionalAccount", trie, key + ) + result_py = trie_get(trie, key) + assert result_cairo == result_py + assert trie_cairo == trie + + @given(trie=..., key=...) + def test_trie_get_TrieBytes32U256( + self, cairo_run, trie: Trie[Bytes32, U256], key: Bytes32 + ): + trie_cairo, result_cairo = cairo_run("trie_get_TrieBytes32U256", trie, key) + result_py = trie_get(trie, key) + assert result_cairo == result_py + assert trie_cairo == trie + + class TestSet: + @given(trie=..., key=..., value=...) + def test_trie_set_TrieAddressOptionalAccount( + self, + cairo_run, + trie: Trie[Address, Optional[Account]], + key: Address, + value: Account, + ): + cairo_trie = cairo_run( + "trie_set_TrieAddressOptionalAccount", trie, key, value + ) + trie_set(trie, key, value) + assert cairo_trie == trie + + @given(trie=..., key=..., value=...) + def test_trie_set_TrieBytes32U256( + self, cairo_run, trie: Trie[Bytes32, U256], key: Bytes32, value: U256 + ): + cairo_trie = cairo_run("trie_set_TrieBytes32U256", trie, key, value) + trie_set(trie, key, value) + assert cairo_trie == trie + + class TestCopy: + @given(trie=...) + def test_copy_trie_AddressAccount( + self, cairo_run, trie: Trie[Address, Optional[Account]] + ): + original_trie, copied_trie_cairo = cairo_run( + "copy_TrieAddressOptionalAccount", trie + ) + copied_trie_py = copy_trie(trie) + assert original_trie == trie + assert copied_trie_cairo == copied_trie_py + + @given(trie=...) + def test_copy_trie_Bytes32U256(self, cairo_run, trie: Trie[Bytes32, U256]): + original_trie, copied_trie_cairo = cairo_run("copy_trieBytes32U256", trie) + copied_trie_py = copy_trie(trie) + assert original_trie == trie + assert copied_trie_cairo == copied_trie_py diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py b/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py index 5d81e1fd..720dd1c8 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py @@ -15,8 +15,6 @@ from tests.utils.args_gen import Evm from tests.utils.evm_builder import EvmBuilder -pytestmark = pytest.mark.python_vm - class TestControlFlow: @given(evm=EvmBuilder().with_running().build()) diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_stack_instructions.py b/cairo/tests/ethereum/cancun/vm/instructions/test_stack_instructions.py index 2289be65..8479ac63 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_stack_instructions.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_stack_instructions.py @@ -6,8 +6,6 @@ from tests.utils.args_gen import Evm from tests.utils.evm_builder import EvmBuilder -pytestmark = pytest.mark.python_vm - class TestPushN: @pytest.mark.parametrize("num_bytes", range(33)) diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_storage.py b/cairo/tests/ethereum/cancun/vm/instructions/test_storage.py index dc212422..ad9a5a81 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_storage.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_storage.py @@ -12,8 +12,6 @@ from tests.utils.evm_builder import EvmBuilder from tests.utils.strategies import MAX_STORAGE_KEY_SET_SIZE -pytestmark = pytest.mark.python_vm - @composite def evm_with_accessed_storage_keys(draw): diff --git a/cairo/tests/ethereum/cancun/vm/test_gas.py b/cairo/tests/ethereum/cancun/vm/test_gas.py index 27e7dda3..8964e4e7 100644 --- a/cairo/tests/ethereum/cancun/vm/test_gas.py +++ b/cairo/tests/ethereum/cancun/vm/test_gas.py @@ -25,8 +25,6 @@ from tests.utils.args_gen import Evm, Memory from tests.utils.evm_builder import EvmBuilder -pytestmark = pytest.mark.python_vm - @composite def extensions_strategy(draw): diff --git a/cairo/tests/ethereum/test_rlp.py b/cairo/tests/ethereum/test_rlp.py index 72e2e58b..1a266f09 100644 --- a/cairo/tests/ethereum/test_rlp.py +++ b/cairo/tests/ethereum/test_rlp.py @@ -3,7 +3,7 @@ import pytest from ethereum_types.bytes import Bytes, Bytes0, Bytes32 from ethereum_types.numeric import U256, Uint -from hypothesis import assume, given, settings +from hypothesis import assume, given from ethereum.cancun.blocks import Log, Receipt, Withdrawal from ethereum.cancun.fork_types import Account, Address, Bloom, encode_account @@ -49,7 +49,6 @@ def test_encode_bytes(self, cairo_run, raw_bytes: Bytes): assert encode_bytes(raw_bytes) == cairo_run("encode_bytes", raw_bytes) @pytest.mark.slow - @settings(max_examples=300) @given(raw_sequence=...) def test_encode_sequence(self, cairo_run, raw_sequence: Sequence[Extended]): assert encode_sequence(raw_sequence) == cairo_run( @@ -57,7 +56,6 @@ def test_encode_sequence(self, cairo_run, raw_sequence: Sequence[Extended]): ) @pytest.mark.slow - @settings(max_examples=300) @given(raw_sequence=...) def test_get_joined_encodings( self, cairo_run, raw_sequence: Sequence[Extended] @@ -103,13 +101,11 @@ def test_encode_legacy_transaction(self, cairo_run, tx: LegacyTransaction): assert encode(tx) == cairo_run("encode_legacy_transaction", tx) @pytest.mark.slow - @settings(max_examples=300) @given(log=...) def test_encode_log(self, cairo_run, log: Log): assert encode(log) == cairo_run("encode_log", log) @pytest.mark.slow - @settings(max_examples=200) @given(tuple_log=...) def test_encode_tuple_log(self, cairo_run, tuple_log: Tuple[Log, ...]): assert encode(tuple_log) == cairo_run("encode_tuple_log", tuple_log) @@ -119,7 +115,6 @@ def test_encode_bloom(self, cairo_run, bloom: Bloom): assert encode(bloom) == cairo_run("encode_bloom", bloom) @pytest.mark.slow - @settings(max_examples=200) @given(receipt=...) def test_encode_receipt(self, cairo_run, receipt: Receipt): assert encode(receipt) == cairo_run("encode_receipt", receipt) @@ -156,7 +151,6 @@ def test_decode_to_bytes_should_raise(self, cairo_run, encoded_bytes: Bytes): assert decoded_bytes == decode_to_bytes(encoded_bytes) @pytest.mark.slow - @settings(max_examples=300) @given(raw_data=...) def test_decode_to_sequence(self, cairo_run, raw_data: Sequence[Extended]): assume(isinstance(raw_data, list)) diff --git a/cairo/tests/ethereum/utils/test_bytes.py b/cairo/tests/ethereum/utils/test_bytes.py index d4cdcce4..69dcd448 100644 --- a/cairo/tests/ethereum/utils/test_bytes.py +++ b/cairo/tests/ethereum/utils/test_bytes.py @@ -1,9 +1,6 @@ -import pytest from ethereum_types.bytes import Bytes from hypothesis import given -pytestmark = pytest.mark.python_vm - class TestBytes: @given(a=..., b=...) diff --git a/cairo/tests/fixtures/runner.py b/cairo/tests/fixtures/runner.py index f3d3bb4e..6aa5457e 100644 --- a/cairo/tests/fixtures/runner.py +++ b/cairo/tests/fixtures/runner.py @@ -97,27 +97,10 @@ def main_path(request): return request.session.main_paths[request.node.fspath] -@pytest.fixture(scope="module") -def cairo_run( - request, cairo_program: Program, rust_program: RustProgram, cairo_file, main_path -): - """ - Run the cairo program corresponding to the python test file at a given entrypoint with given program inputs as kwargs. - Returns the output of the cairo program put in the output memory segment. - - When --profile-cairo is passed, the cairo program is run with the tracer enabled and the resulting trace is dumped. - - Logic is mainly taken from starkware.cairo.lang.vm.cairo_run with minor updates, mainly builtins discovery from implicit args. - - Type conversion between Python and Cairo is handled by: - - gen_arg: Converts Python arguments to Cairo memory layout when preparing runner inputs - - serde: Converts Cairo memory data to Python types by reading into the segments, used to return python types. +def _run_python_vm(cairo_program: Program, cairo_file, main_path, request): + """Helper function containing Python VM implementation""" - Returns: - The function's return value, converted back to Python types - """ - - def _factory_py(entrypoint, *args, **kwargs): + def _run(entrypoint, *args, **kwargs): logger.debug(f"Running the CairoVM Python VM for {entrypoint}") implicit_args = cairo_program.identifiers.get_by_full_name( ScopedName(path=("__main__", entrypoint, "ImplicitArgs")) @@ -254,6 +237,8 @@ def _factory_py(entrypoint, *args, **kwargs): try: runner.run_until_pc(end, run_resources) except Exception as e: + if "An ASSERT_EQ instruction failed" in str(e): + raise AssertionError(e) from e raise Exception(str(e)) from e runner.end_run(disable_trace_padding=False) @@ -380,7 +365,15 @@ def _factory_py(entrypoint, *args, **kwargs): return final_output[0] if len(final_output) == 1 else final_output - def _factory_rs(entrypoint, *args, **kwargs): + return _run + + +def _run_rust_vm( + cairo_program: Program, rust_program: RustProgram, cairo_file, main_path, request +): + """Helper function containing Rust VM implementation""" + + def _run(entrypoint, *args, **kwargs): logger.debug(f"Running the CairoVM Rust VM for {entrypoint}") implicit_args = cairo_program.identifiers.get_by_full_name( ScopedName(path=("__main__", entrypoint, "ImplicitArgs")) @@ -461,7 +454,13 @@ def _factory_rs(entrypoint, *args, **kwargs): entrypoint=cairo_program.get_label(entrypoint), stack=stack ) - runner.run_until_pc(end, RustRunResources()) + # Bind Cairo's ASSERT_EQ instruction to a Python exception + try: + runner.run_until_pc(end, RustRunResources()) + except Exception as e: + if "An ASSERT_EQ instruction failed" in str(e): + raise AssertionError(e) from e + cumulative_retdata_offsets = serde.get_offsets(return_data_types) first_return_data_offset = ( cumulative_retdata_offsets[0] if cumulative_retdata_offsets else 0 @@ -526,7 +525,37 @@ def _factory_rs(entrypoint, *args, **kwargs): return final_output[0] if len(final_output) == 1 else final_output + return _run + + +@pytest.fixture(scope="module") +def cairo_run_py(request, cairo_program: Program, cairo_file, main_path): + """Run the cairo program using Python VM.""" + return _run_python_vm(cairo_program, cairo_file, main_path, request) + + +@pytest.fixture(scope="module") +def cairo_run( + request, cairo_program: Program, rust_program: RustProgram, cairo_file, main_path +): + """ + Run the cairo program corresponding to the python test file at a given entrypoint with given program inputs as kwargs. + Returns the output of the cairo program put in the output memory segment. + + When --profile-cairo is passed, the cairo program is run with the tracer enabled and the resulting trace is dumped. + + Logic is mainly taken from starkware.cairo.lang.vm.cairo_run with minor updates, mainly builtins discovery from implicit args. + + Type conversion between Python and Cairo is handled by: + - gen_arg: Converts Python arguments to Cairo memory layout when preparing runner inputs + - serde: Converts Cairo memory data to Python types by reading into the segments, used to return python types. + + The VM used for the run depends on the presence of a "python_vm" marker in the test. + + Returns: + The function's return value, converted back to Python types + """ if request.node.get_closest_marker("python_vm"): - return _factory_py - else: - return _factory_rs + return _run_python_vm(cairo_program, cairo_file, main_path, request) + + return _run_rust_vm(cairo_program, rust_program, cairo_file, main_path, request) diff --git a/cairo/tests/utils/args_gen.py b/cairo/tests/utils/args_gen.py index e76fabe3..da64ac62 100644 --- a/cairo/tests/utils/args_gen.py +++ b/cairo/tests/utils/args_gen.py @@ -627,6 +627,7 @@ def _gen_arg( if arg_type_origin is Trie: # In case of a Trie, we need the dict to be a defaultdict with the trie.default as the default value. dict_ptr = segments.memory.get(data[2]) + current_ptr = segments.memory.get(data[2] + 1) if isinstance(dict_manager, DictManager): dict_manager.trackers[dict_ptr.segment_index].data = defaultdict( lambda: data[1], dict_manager.trackers[dict_ptr.segment_index].data @@ -634,7 +635,7 @@ def _gen_arg( else: dict_manager.trackers[dict_ptr.segment_index] = RustDictTracker( data=dict_manager.trackers[dict_ptr.segment_index].data, - current_ptr=dict_ptr, + current_ptr=current_ptr, default_value=data[1], ) return struct_ptr diff --git a/crates/cairo-addons/Cargo.toml b/crates/cairo-addons/Cargo.toml index f0afaedf..13408e03 100644 --- a/crates/cairo-addons/Cargo.toml +++ b/crates/cairo-addons/Cargo.toml @@ -20,6 +20,7 @@ pyo3 = { version = "0.22.4", features = [ cairo-vm = { workspace = true } num-traits = "0.2.18" num-bigint = "0.4.6" +starknet-crypto = "0.7.4" [build-dependencies] pyo3-build-config = "0.22.4" # Should match pyo3 version diff --git a/crates/cairo-addons/src/vm/dict_manager.rs b/crates/cairo-addons/src/vm/dict_manager.rs index 736bda18..ae4167c1 100644 --- a/crates/cairo-addons/src/vm/dict_manager.rs +++ b/crates/cairo-addons/src/vm/dict_manager.rs @@ -61,20 +61,23 @@ pub struct PyTrackerMapping { #[pymethods] impl PyTrackerMapping { - fn __getitem__(&self, key: isize) -> PyResult { + fn __getitem__(&self, segment_index: isize) -> PyResult { self.inner .borrow() .trackers - .get(&key) + .get(&segment_index) .cloned() .map(|tracker| PyDictTracker { inner: tracker }) .ok_or_else(|| { - PyErr::new::(format!("Key {} not found", key)) + PyErr::new::(format!( + "segment_index {} not found", + segment_index + )) }) } - fn __setitem__(&mut self, key: isize, value: PyDictTracker) -> PyResult<()> { - self.inner.borrow_mut().trackers.insert(key, value.inner); + fn __setitem__(&mut self, segment_index: isize, value: PyDictTracker) -> PyResult<()> { + self.inner.borrow_mut().trackers.insert(segment_index, value.inner); Ok(()) } } @@ -123,7 +126,7 @@ impl PyDictManager { } #[pyclass(name = "DictTracker")] -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct PyDictTracker { inner: DictTracker, } diff --git a/crates/cairo-addons/src/vm/felt.rs b/crates/cairo-addons/src/vm/felt.rs index a86f0b59..1e2e0021 100644 --- a/crates/cairo-addons/src/vm/felt.rs +++ b/crates/cairo-addons/src/vm/felt.rs @@ -16,7 +16,7 @@ impl Felt252Input { } #[pyclass(name = "Felt")] -#[derive(Clone, Eq, PartialEq, Hash)] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] pub struct PyFelt { pub(crate) inner: Felt252, } diff --git a/crates/cairo-addons/src/vm/hint_definitions.rs b/crates/cairo-addons/src/vm/hint_definitions.rs new file mode 100644 index 00000000..2468a4ac --- /dev/null +++ b/crates/cairo-addons/src/vm/hint_definitions.rs @@ -0,0 +1,7 @@ +mod dict; +mod hashdict; +mod utils; + +pub use dict::HINTS as DICT_HINTS; +pub use hashdict::HINTS as HASHDICT_HINTS; +pub use utils::HINTS as UTILS_HINTS; diff --git a/crates/cairo-addons/src/vm/hint_definitions/dict.rs b/crates/cairo-addons/src/vm/hint_definitions/dict.rs new file mode 100644 index 00000000..f58c1e8a --- /dev/null +++ b/crates/cairo-addons/src/vm/hint_definitions/dict.rs @@ -0,0 +1,61 @@ +use std::collections::HashMap; + +use cairo_vm::{ + hint_processor::{ + builtin_hint_processor::hint_utils::{ + get_ptr_from_var_name, insert_value_from_var_name, insert_value_into_ap, + }, + hint_processor_definition::HintReference, + }, + serde::deserialize_program::ApTracking, + types::exec_scope::ExecutionScopes, + vm::{errors::hint_errors::HintError, vm_core::VirtualMachine}, + Felt252, +}; + +use crate::vm::hints::Hint; + +pub const HINTS: &[fn() -> Hint] = &[dict_new_empty, copy_dict_segment]; + +pub fn dict_new_empty() -> Hint { + Hint::new( + String::from("dict_new_empty"), + |vm: &mut VirtualMachine, + exec_scopes: &mut ExecutionScopes, + _ids_data: &HashMap, + _ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + let base = + exec_scopes.get_dict_manager()?.borrow_mut().new_dict(vm, Default::default())?; + insert_value_into_ap(vm, base) + }, + ) +} + +pub fn copy_dict_segment() -> Hint { + Hint::new( + String::from("copy_dict_segment"), + |vm: &mut VirtualMachine, + exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + // Get original dict pointer + let original_mapping_ptr = + get_ptr_from_var_name("original_mapping", vm, ids_data, ap_tracking)?; + let original_dict_ptr = vm.get_relocatable((original_mapping_ptr + 1)?)?; + + // Get tracker and copy its data + let dict_manager_ref = exec_scopes.get_dict_manager()?; + let mut dict_manager = dict_manager_ref.borrow_mut(); + let tracker = dict_manager.get_tracker(original_dict_ptr)?; + let copied_data = tracker.get_dictionary_copy(); + + // Create new dict with copied data and insert its pointer + let new_dict_ptr = dict_manager.new_dict(vm, copied_data)?; + insert_value_from_var_name("new_dict_ptr", new_dict_ptr, vm, ids_data, ap_tracking) + }, + ) +} diff --git a/crates/cairo-addons/src/vm/hint_definitions/hashdict.rs b/crates/cairo-addons/src/vm/hint_definitions/hashdict.rs new file mode 100644 index 00000000..981852c3 --- /dev/null +++ b/crates/cairo-addons/src/vm/hint_definitions/hashdict.rs @@ -0,0 +1,258 @@ +use std::collections::HashMap; + +use cairo_vm::{ + hint_processor::{ + builtin_hint_processor::{ + dict_hint_utils::DICT_ACCESS_SIZE, + dict_manager::DictKey, + hint_utils::{ + get_integer_from_var_name, get_maybe_relocatable_from_var_name, + get_ptr_from_var_name, insert_value_from_var_name, + }, + }, + hint_processor_definition::HintReference, + }, + serde::deserialize_program::ApTracking, + types::{ + errors::math_errors::MathError, exec_scope::ExecutionScopes, relocatable::MaybeRelocatable, + }, + vm::{ + errors::{hint_errors::HintError, memory_errors::MemoryError}, + vm_core::VirtualMachine, + }, + Felt252, +}; +use starknet_crypto::poseidon_hash_many; + +use crate::vm::hints::Hint; + +pub const HINTS: &[fn() -> Hint] = &[ + hashdict_read, + hashdict_get, + hashdict_write, + get_preimage_for_key, + copy_hashdict_tracker_entry, +]; + +pub fn hashdict_read() -> Hint { + Hint::new( + String::from("hashdict_read"), + |vm: &mut VirtualMachine, + exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + // Get dictionary pointer and setup tracker + let dict_ptr = get_ptr_from_var_name("dict_ptr", vm, ids_data, ap_tracking)?; + let dict_manager_ref = exec_scopes.get_dict_manager()?; + let mut dict = dict_manager_ref.borrow_mut(); + let tracker = dict.get_tracker_mut(dict_ptr)?; + tracker.current_ptr.offset += DICT_ACCESS_SIZE; + + let key = get_ptr_from_var_name("key", vm, ids_data, ap_tracking)?; + let key_len_felt: Felt252 = + get_integer_from_var_name("key_len", vm, ids_data, ap_tracking)?; + let key_len: usize = key_len_felt + .try_into() + .map_err(|_| MathError::Felt252ToUsizeConversion(Box::new(key_len_felt)))?; + + // Build and process compound key + let dict_key = build_compound_key(vm, &key, key_len)?; + + tracker.get_value(&dict_key).and_then(|value| { + insert_value_from_var_name("value", value.clone(), vm, ids_data, ap_tracking) + }) + }, + ) +} + +/// Same as above, but returns 0 if the key is not found and the dict is NOT a defaultdict. +pub fn hashdict_get() -> Hint { + Hint::new( + String::from("hashdict_get"), + |vm: &mut VirtualMachine, + exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + // Get dictionary pointer and setup tracker + let dict_ptr = get_ptr_from_var_name("dict_ptr", vm, ids_data, ap_tracking)?; + let dict_manager_ref = exec_scopes.get_dict_manager()?; + let mut dict = dict_manager_ref.borrow_mut(); + let tracker = dict.get_tracker_mut(dict_ptr)?; + tracker.current_ptr.offset += DICT_ACCESS_SIZE; + + let key = get_ptr_from_var_name("key", vm, ids_data, ap_tracking)?; + let key_len_felt: Felt252 = + get_integer_from_var_name("key_len", vm, ids_data, ap_tracking)?; + let key_len: usize = key_len_felt + .try_into() + .map_err(|_| MathError::Felt252ToUsizeConversion(Box::new(key_len_felt)))?; + + // Build and process compound key + let dict_key = build_compound_key(vm, &key, key_len)?; + let default_value = MaybeRelocatable::Int(0.into()); + let prev_value = tracker.get_value(&dict_key).unwrap_or(&default_value); + insert_value_from_var_name("value", prev_value.clone(), vm, ids_data, ap_tracking) + }, + ) +} + +pub fn hashdict_write() -> Hint { + Hint::new( + String::from("hashdict_write"), + |vm: &mut VirtualMachine, + exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + // Get dictionary pointer and setup tracker + let dict_ptr = get_ptr_from_var_name("dict_ptr", vm, ids_data, ap_tracking)?; + let dict_manager_ref = exec_scopes.get_dict_manager()?; + let mut dict = dict_manager_ref.borrow_mut(); + let tracker = dict.get_tracker_mut(dict_ptr)?; + tracker.current_ptr.offset += DICT_ACCESS_SIZE; + + let key = get_ptr_from_var_name("key", vm, ids_data, ap_tracking)?; + let key_len_felt: Felt252 = + get_integer_from_var_name("key_len", vm, ids_data, ap_tracking)?; + let key_len: usize = key_len_felt + .try_into() + .map_err(|_| MathError::Felt252ToUsizeConversion(Box::new(key_len_felt)))?; + + // Build compound key and get new value + let dict_key = build_compound_key(vm, &key, key_len)?; + let new_value = + get_maybe_relocatable_from_var_name("new_value", vm, ids_data, ap_tracking)?; + let dict_ptr_prev_value = (dict_ptr + 1_i32)?; + + // Update tracker and memory + let tracker_dict = tracker.get_dictionary_ref(); + let prev_value = + tracker_dict.get(&dict_key).cloned().unwrap_or(MaybeRelocatable::Int(0.into())); + tracker.insert_value(&dict_key, &new_value); + vm.insert_value(dict_ptr_prev_value, prev_value)?; + + Ok(()) + }, + ) +} + +pub fn get_preimage_for_key() -> Hint { + Hint::new( + String::from("get_preimage_for_key"), + |vm: &mut VirtualMachine, + exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + // Get the hashed key value + let hashed_key = get_integer_from_var_name("key", vm, ids_data, ap_tracking)?; + + // Get dictionary tracker + let dict_ptr = get_ptr_from_var_name("dict_ptr_stop", vm, ids_data, ap_tracking)?; + let dict_manager_ref = exec_scopes.get_dict_manager()?; + let dict = dict_manager_ref.borrow(); + let tracker = dict.get_tracker(dict_ptr)?; + + // Find matching preimage from tracker data + let preimage = tracker + .get_dictionary_ref() + .keys() + .find(|key| match key { + DictKey::Compound(values) => { + let felt_values: Vec = + values.iter().filter_map(|v| v.get_int()).collect(); + poseidon_hash_many(felt_values.iter()) == hashed_key + } + _ => false, + }) + .ok_or_else(|| HintError::CustomHint("No matching preimage found".into()))?; + + // Write preimage data to memory + let preimage_data_ptr = + get_ptr_from_var_name("preimage_data", vm, ids_data, ap_tracking)?; + if let DictKey::Compound(values) = preimage { + for (i, value) in values.iter().enumerate() { + vm.insert_value((preimage_data_ptr + i)?, value.clone())?; + } + + // Set preimage length + insert_value_from_var_name( + "preimage_len", + Felt252::from(values.len()), + vm, + ids_data, + ap_tracking, + )?; + } + + Ok(()) + }, + ) +} + +pub fn copy_hashdict_tracker_entry() -> Hint { + Hint::new( + String::from("copy_hashdict_tracker_entry"), + |vm: &mut VirtualMachine, + exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + let source_ptr_stop = + get_ptr_from_var_name("source_ptr_stop", vm, ids_data, ap_tracking)?; + let dest_ptr = get_ptr_from_var_name("dest_ptr", vm, ids_data, ap_tracking)?; + let dict_manager_ref = exec_scopes.get_dict_manager()?; + let mut dict = dict_manager_ref.borrow_mut(); + + let source_tracker = dict.get_tracker(source_ptr_stop)?; + let source_dict = source_tracker.get_dictionary_copy(); + + // Find matching preimage from source tracker data + let key_hash = get_integer_from_var_name("source_key", vm, ids_data, ap_tracking)?; + let preimage = source_dict + .keys() + .find(|key| match key { + DictKey::Compound(values) => { + let felt_values: Vec = + values.iter().filter_map(|v| v.get_int()).collect(); + poseidon_hash_many(felt_values.iter()) == key_hash + } + _ => false, + }) + .ok_or_else(|| HintError::CustomHint("No matching preimage found".into()))?; + let value = source_dict + .get(preimage) + .ok_or_else(|| HintError::CustomHint("No matching preimage found".into()))?; + + // Update destination tracker + let dest_tracker = dict.get_tracker_mut(dest_ptr)?; + dest_tracker.current_ptr.offset += DICT_ACCESS_SIZE; + dest_tracker.insert_value(preimage, &value.clone()); + + Ok(()) + }, + ) +} + +fn build_compound_key( + vm: &VirtualMachine, + key: &cairo_vm::types::relocatable::Relocatable, + key_len: usize, +) -> Result { + (0..key_len) + .map(|i| { + let mem_addr = (*key + i)?; + vm.get_maybe(&mem_addr).ok_or_else(|| { + HintError::Memory(MemoryError::UnknownMemoryCell(Box::from(mem_addr))) + }) + }) + .collect::, _>>() + .map(DictKey::Compound) +} diff --git a/crates/cairo-addons/src/vm/hint_definitions/utils.rs b/crates/cairo-addons/src/vm/hint_definitions/utils.rs new file mode 100644 index 00000000..6dc34566 --- /dev/null +++ b/crates/cairo-addons/src/vm/hint_definitions/utils.rs @@ -0,0 +1,160 @@ +use std::collections::HashMap; + +use cairo_vm::{ + hint_processor::{ + builtin_hint_processor::hint_utils::{ + get_integer_from_var_name, get_ptr_from_var_name, insert_value_from_var_name, + }, + hint_processor_definition::HintReference, + }, + serde::deserialize_program::ApTracking, + types::{ + errors::math_errors::MathError, exec_scope::ExecutionScopes, relocatable::MaybeRelocatable, + }, + vm::{errors::hint_errors::HintError, vm_core::VirtualMachine}, + Felt252, +}; + +use crate::vm::hints::Hint; + +pub const HINTS: &[fn() -> Hint] = &[bytes__eq__, b_le_a, fp_plus_2_or_0, nibble_remainder]; + +#[allow(non_snake_case)] +pub fn bytes__eq__() -> Hint { + Hint::new( + String::from("Bytes__eq__"), + |vm: &mut VirtualMachine, + _exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + let get_bytes_params = |name: &str| -> Result< + (usize, cairo_vm::types::relocatable::Relocatable), + HintError, + > { + let ptr = get_ptr_from_var_name(name, vm, ids_data, ap_tracking)?; + let len_addr = (ptr + 1)?; + + let len_felt = vm.get_integer(len_addr)?.into_owned(); + let len = len_felt + .try_into() + .map_err(|_| MathError::Felt252ToUsizeConversion(Box::new(len_felt)))?; + + let data = vm.get_relocatable(ptr)?; + + Ok((len, data)) + }; + + let (self_len, self_data) = get_bytes_params("_self")?; + let (other_len, other_data) = get_bytes_params("other")?; + + // Compare bytes until we find a difference + for i in 0..std::cmp::min(self_len, other_len) { + let self_byte = vm.get_integer((self_data + i)?)?.into_owned(); + + let other_byte = vm.get_integer((other_data + i)?)?.into_owned(); + + if self_byte != other_byte { + // Found difference - set is_diff=1 and diff_index=i + insert_value_from_var_name( + "is_diff", + MaybeRelocatable::from(1), + vm, + ids_data, + ap_tracking, + )?; + insert_value_from_var_name( + "diff_index", + MaybeRelocatable::from(i), + vm, + ids_data, + ap_tracking, + )?; + return Ok(()); + } + } + + // No differences found in common prefix + // Lengths were checked before this hint + insert_value_from_var_name( + "is_diff", + MaybeRelocatable::from(0), + vm, + ids_data, + ap_tracking, + )?; + insert_value_from_var_name( + "diff_index", + MaybeRelocatable::from(0), + vm, + ids_data, + ap_tracking, + )?; + Ok(()) + }, + ) +} + +pub fn b_le_a() -> Hint { + Hint::new( + String::from("b_le_a"), + |vm: &mut VirtualMachine, + _exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + let a = get_integer_from_var_name("a", vm, ids_data, ap_tracking)?; + let b = get_integer_from_var_name("b", vm, ids_data, ap_tracking)?; + let result = usize::from(b <= a); + insert_value_from_var_name( + "is_min_b", + MaybeRelocatable::from(result), + vm, + ids_data, + ap_tracking, + )?; + Ok(()) + }, + ) +} + +pub fn fp_plus_2_or_0() -> Hint { + Hint::new( + String::from("fp_plus_2_or_0"), + |vm: &mut VirtualMachine, + _exec_scopes: &mut ExecutionScopes, + _ids_data: &HashMap, + _ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + let fp_offset = (vm.get_fp() + 2)?; + let value_set = vm.get_maybe(&fp_offset); + if value_set.is_none() { + vm.insert_value(fp_offset, MaybeRelocatable::from(0))?; + } + Ok(()) + }, + ) +} + +pub fn nibble_remainder() -> Hint { + Hint::new( + String::from("memory[fp + 2] = to_felt_or_relocatable(ids.x.value.len % 2)"), + |vm: &mut VirtualMachine, + _exec_scopes: &mut ExecutionScopes, + ids_data: &HashMap, + ap_tracking: &ApTracking, + _constants: &HashMap| + -> Result<(), HintError> { + let bytes_ptr = get_ptr_from_var_name("x", vm, ids_data, ap_tracking)?; + let len = vm.get_integer((bytes_ptr + 1)?)?.into_owned(); + let len: usize = + len.try_into().map_err(|_| MathError::Felt252ToUsizeConversion(Box::new(len)))?; + let remainder = len % 2; + vm.insert_value((vm.get_fp() + 2)?, MaybeRelocatable::from(remainder))?; + Ok(()) + }, + ) +} diff --git a/crates/cairo-addons/src/vm/hint_loader.rs b/crates/cairo-addons/src/vm/hint_loader.rs new file mode 100644 index 00000000..03839e97 --- /dev/null +++ b/crates/cairo-addons/src/vm/hint_loader.rs @@ -0,0 +1,17 @@ +use pyo3::{prelude::*, types::PyDict}; +use std::collections::HashMap; + +pub fn load_python_hints() -> PyResult> { + Python::with_gil(|py| { + let hints_module = py.import_bound("cairo_addons.hints")?; + let impl_attr = hints_module.getattr("implementations")?; + let implementations = impl_attr.downcast::()?; + + let mut hints = HashMap::new(); + for (key, value) in implementations.iter() { + hints.insert(key.extract::()?, value.extract::()?); + } + + Ok(hints) + }) +} diff --git a/crates/cairo-addons/src/vm/hints.rs b/crates/cairo-addons/src/vm/hints.rs index 1775db98..81569f93 100644 --- a/crates/cairo-addons/src/vm/hints.rs +++ b/crates/cairo-addons/src/vm/hints.rs @@ -16,6 +16,11 @@ use cairo_vm::{ }; use std::{collections::HashMap, fmt, rc::Rc}; +use super::{ + hint_definitions::{DICT_HINTS, HASHDICT_HINTS, UTILS_HINTS}, + hint_loader::load_python_hints, +}; + /// A struct representing a hint. pub struct Hint { /// The hint id, ie the raw string written in the Cairo code in between `%{` and `%}`. @@ -44,22 +49,33 @@ impl Hint { /// A wrapper around [`BuiltinHintProcessor`] to manage hint registration. pub struct HintProcessor { inner: BuiltinHintProcessor, + python_hints: HashMap, } impl HintProcessor { pub fn new(run_resources: RunResources) -> Self { - Self { inner: BuiltinHintProcessor::new(HashMap::new(), run_resources) } + let python_hints = load_python_hints().unwrap(); + Self { inner: BuiltinHintProcessor::new(HashMap::new(), run_resources), python_hints } } #[must_use] - pub fn with_hint(mut self, hint: &Hint) -> Self { - self.inner.add_hint(hint.id.clone(), hint.func.clone()); + pub fn with_hints(mut self, hints: Vec Hint>) -> Self { + for fn_hint in hints { + let hint = fn_hint(); + self.inner.add_hint( + self.python_hints.get(&hint.id).unwrap_or(&hint.id).to_string(), + hint.func.clone(), + ); + } self } #[must_use] pub fn with_run_resources(self, run_resources: RunResources) -> Self { - Self { inner: BuiltinHintProcessor::new(self.inner.extra_hints, run_resources) } + Self { + inner: BuiltinHintProcessor::new(self.inner.extra_hints, run_resources), + python_hints: self.python_hints, + } } pub fn build(self) -> BuiltinHintProcessor { @@ -69,7 +85,12 @@ impl HintProcessor { impl Default for HintProcessor { fn default() -> Self { - Self::new(RunResources::default()).with_hint(&add_segment_hint()) + let mut hints: Vec Hint> = vec![add_segment_hint]; + hints.extend_from_slice(DICT_HINTS); + hints.extend_from_slice(HASHDICT_HINTS); + hints.extend_from_slice(UTILS_HINTS); + + Self::new(RunResources::default()).with_hints(hints) } } diff --git a/crates/cairo-addons/src/vm/maybe_relocatable.rs b/crates/cairo-addons/src/vm/maybe_relocatable.rs index bc14304c..92d90375 100644 --- a/crates/cairo-addons/src/vm/maybe_relocatable.rs +++ b/crates/cairo-addons/src/vm/maybe_relocatable.rs @@ -3,7 +3,7 @@ use cairo_vm::types::relocatable::MaybeRelocatable as RustMaybeRelocatable; use num_bigint::BigUint; use pyo3::{FromPyObject, IntoPy, PyObject, Python}; -#[derive(FromPyObject, Eq, PartialEq, Hash)] +#[derive(FromPyObject, Eq, PartialEq, Hash, Debug)] pub enum PyMaybeRelocatable { #[pyo3(transparent)] Felt(PyFelt), diff --git a/crates/cairo-addons/src/vm/mod.rs b/crates/cairo-addons/src/vm/mod.rs index bf9e2ccb..36241450 100644 --- a/crates/cairo-addons/src/vm/mod.rs +++ b/crates/cairo-addons/src/vm/mod.rs @@ -3,6 +3,8 @@ use pyo3::prelude::*; mod builtins; mod dict_manager; mod felt; +mod hint_definitions; +mod hint_loader; mod hints; mod layout; mod maybe_relocatable; diff --git a/crates/cairo-addons/src/vm/relocatable.rs b/crates/cairo-addons/src/vm/relocatable.rs index 04e6e04b..4039c86f 100644 --- a/crates/cairo-addons/src/vm/relocatable.rs +++ b/crates/cairo-addons/src/vm/relocatable.rs @@ -6,7 +6,7 @@ use pyo3::prelude::*; use super::maybe_relocatable::PyMaybeRelocatable; #[pyclass(name = "Relocatable")] -#[derive(Clone, Eq, PartialEq, Hash)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct PyRelocatable { pub(crate) inner: RustRelocatable, } diff --git a/python/cairo-addons/src/cairo_addons/hints/__init__.py b/python/cairo-addons/src/cairo_addons/hints/__init__.py index d8306fc9..4804fc1f 100644 --- a/python/cairo-addons/src/cairo_addons/hints/__init__.py +++ b/python/cairo-addons/src/cairo_addons/hints/__init__.py @@ -1,7 +1,9 @@ # ruff: noqa: F403 from cairo_addons.hints.decorator import implementations, register_hint from cairo_addons.hints.dict import * +from cairo_addons.hints.hashdict import * from cairo_addons.hints.os import * +from cairo_addons.hints.utils import * __all__ = [ "register_hint", diff --git a/python/cairo-addons/src/cairo_addons/hints/dict.py b/python/cairo-addons/src/cairo_addons/hints/dict.py index bced9093..04bd7f7d 100644 --- a/python/cairo-addons/src/cairo_addons/hints/dict.py +++ b/python/cairo-addons/src/cairo_addons/hints/dict.py @@ -6,6 +6,17 @@ from starkware.cairo.lang.vm.vm_consts import VmConsts +@register_hint +def dict_new_empty( + dict_manager: DictManager, + ids: VmConsts, + segments: MemorySegmentManager, + memory: MemoryDict, + ap: RelocatableValue, +): + memory[ap] = dict_manager.new_dict(segments, {}) + + @register_hint def dict_copy(dict_manager: DictManager, ids: VmConsts): from starkware.cairo.common.dict import DictTracker @@ -37,3 +48,16 @@ def dict_squash( assert base.segment_index not in dict_manager.trackers dict_manager.trackers[base.segment_index] = DictTracker(data=data, current_ptr=base) memory[ap] = base + + +@register_hint +def copy_dict_segment( + dict_manager: DictManager, + ids: VmConsts, + segments: MemorySegmentManager, + memory: MemoryDict, + ap: RelocatableValue, +): + dict_tracker = dict_manager.get_tracker(ids.original_mapping.dict_ptr) + copied_data = dict_tracker.data + ids.new_dict_ptr = dict_manager.new_dict(segments, copied_data) diff --git a/python/cairo-addons/src/cairo_addons/hints/hashdict.py b/python/cairo-addons/src/cairo_addons/hints/hashdict.py new file mode 100644 index 00000000..be2ff886 --- /dev/null +++ b/python/cairo-addons/src/cairo_addons/hints/hashdict.py @@ -0,0 +1,105 @@ +from cairo_addons.hints.decorator import register_hint +from starkware.cairo.common.dict import DictManager +from starkware.cairo.lang.vm.memory_dict import MemoryDict +from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager +from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.cairo.lang.vm.vm_consts import VmConsts + + +@register_hint +def hashdict_read( + dict_manager: DictManager, + ids: VmConsts, + segments: MemorySegmentManager, + memory: MemoryDict, + ap: RelocatableValue, +) -> int: + dict_tracker = dict_manager.get_tracker(ids.dict_ptr) + dict_tracker.current_ptr += ids.DictAccess.SIZE + preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)]) + # Not using [] here because it will register the value for that key in the tracker. + ids.value = dict_tracker.data.get(preimage, dict_tracker.data.default_factory()) + + +@register_hint +def hashdict_get( + dict_manager: DictManager, + ids: VmConsts, + segments: MemorySegmentManager, + memory: MemoryDict, + ap: RelocatableValue, +) -> int: + from collections import defaultdict + + dict_tracker = dict_manager.get_tracker(ids.dict_ptr) + dict_tracker.current_ptr += ids.DictAccess.SIZE + preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)]) + if isinstance(dict_tracker.data, defaultdict): + ids.value = dict_tracker.data[preimage] + else: + ids.value = dict_tracker.data.get(preimage, 0) + + +@register_hint +def hashdict_write( + dict_manager: DictManager, + ids: VmConsts, + segments: MemorySegmentManager, + memory: MemoryDict, + ap: RelocatableValue, +) -> int: + from collections import defaultdict + + dict_tracker = dict_manager.get_tracker(ids.dict_ptr) + dict_tracker.current_ptr += ids.DictAccess.SIZE + preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)]) + if isinstance(dict_tracker.data, defaultdict): + ids.dict_ptr.prev_value = dict_tracker.data[preimage] + else: + ids.dict_ptr.prev_value = 0 + dict_tracker.data[preimage] = ids.new_value + + +@register_hint +def get_preimage_for_key( + dict_manager: DictManager, + ids: VmConsts, + segments: MemorySegmentManager, + memory: MemoryDict, + ap: RelocatableValue, +) -> int: + from starkware.cairo.lang.vm.crypto import poseidon_hash_many + + hashed_value = ids.key + dict_tracker = dict_manager.get_tracker(ids.dict_ptr_stop) + # Get the key in the dict that matches the hashed value + preimage = bytes( + next( + key + for key in dict_tracker.data.keys() + if poseidon_hash_many(key) == hashed_value + ) + ) + segments.write_arg(ids.preimage_data, preimage) + ids.preimage_len = len(preimage) + + +@register_hint +def copy_hashdict_tracker_entry( + dict_manager: DictManager, + ids: VmConsts, + segments: MemorySegmentManager, + memory: MemoryDict, + ap: RelocatableValue, +) -> int: + from starkware.cairo.lang.vm.crypto import poseidon_hash_many + + obj_tracker = dict_manager.get_tracker(ids.dict_ptr_stop.address_) + dict_tracker = dict_manager.get_tracker(ids.branch_ptr.address_) + dict_tracker.current_ptr += ids.DictAccess.SIZE + preimage = next( + key + for key in obj_tracker.data.keys() + if poseidon_hash_many(key) == ids.dict_ptr.key.value + ) + dict_tracker.data[preimage] = obj_tracker.data[preimage] diff --git a/python/cairo-addons/src/cairo_addons/hints/utils.py b/python/cairo-addons/src/cairo_addons/hints/utils.py new file mode 100644 index 00000000..ae7096b9 --- /dev/null +++ b/python/cairo-addons/src/cairo_addons/hints/utils.py @@ -0,0 +1,66 @@ +from cairo_addons.hints.decorator import register_hint +from starkware.cairo.common.dict import DictManager +from starkware.cairo.lang.vm.memory_dict import MemoryDict +from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager +from starkware.cairo.lang.vm.relocatable import RelocatableValue +from starkware.cairo.lang.vm.vm_consts import VmConsts + + +@register_hint +def Bytes__eq__( + dict_manager: DictManager, + ids: VmConsts, + segments: MemorySegmentManager, + memory: MemoryDict, + ap: RelocatableValue, +): + self_bytes = b"".join( + [ + memory[ids._self.value.data + i].to_bytes(1, "little") + for i in range(ids._self.value.len) + ] + ) + other_bytes = b"".join( + [ + memory[ids.other.value.data + i].to_bytes(1, "little") + for i in range(ids.other.value.len) + ] + ) + diff_index = next( + ( + i + for i, (b_self, b_other) in enumerate(zip(self_bytes, other_bytes)) + if b_self != b_other + ), + None, + ) + if diff_index is not None: + ids.is_diff = 1 + ids.diff_index = diff_index + else: + # No differences found in common prefix. Lengths were checked before + ids.is_diff = 0 + ids.diff_index = 0 + + +@register_hint +def b_le_a( + dict_manager: DictManager, + ids: VmConsts, + segments: MemorySegmentManager, + memory: MemoryDict, + ap: RelocatableValue, +): + ids.is_min_b = 1 if ids.b <= ids.a else 0 + + +@register_hint +def fp_plus_2_or_0( + dict_manager: DictManager, + ids: VmConsts, + segments: MemorySegmentManager, + memory: MemoryDict, + ap: RelocatableValue, + fp: RelocatableValue, +): + ids.value_set = memory.get(fp + 2) or 0