Skip to content

Commit

Permalink
feat: set_account (#437)
Browse files Browse the repository at this point in the history
tests are broken so far, will investigate

---------

Co-authored-by: enitrat <[email protected]>
  • Loading branch information
Eikix and enitrat authored Jan 16, 2025
1 parent 7503e09 commit 9605606
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 44 deletions.
6 changes: 5 additions & 1 deletion cairo/ethereum/cancun/fork_types.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ struct Account {
value: AccountStruct*,
}

struct OptionalAccount {
value: AccountStruct*,
}

struct AddressAccountDictAccess {
key: Address,
prev_value: Account,
Expand Down Expand Up @@ -130,7 +134,7 @@ func EMPTY_ACCOUNT() -> Account {
return account;
}

func Account__eq__(a: Account, b: Account) -> bool {
func Account__eq__(a: OptionalAccount, b: OptionalAccount) -> bool {
if (cast(a.value, felt) == 0) {
let b_is_none = is_zero(cast(b.value, felt));
let res = bool(b_is_none);
Expand Down
52 changes: 36 additions & 16 deletions cairo/ethereum/cancun/state.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from starkware.cairo.common.math import assert_not_zero
from ethereum.cancun.fork_types import (
Address,
Account,
OptionalAccount,
MappingAddressAccount,
SetAddress,
EMPTY_ACCOUNT,
Expand All @@ -16,13 +17,14 @@ from ethereum.cancun.fork_types import (
)
from ethereum.cancun.trie import (
TrieBytes32U256,
TrieAddressAccount,
trie_get_TrieAddressAccount,
TrieAddressOptionalAccount,
trie_get_TrieAddressOptionalAccount,
trie_set_TrieAddressOptionalAccount,
trie_get_TrieBytes32U256,
trie_set_TrieBytes32U256,
AccountStruct,
TrieBytes32U256Struct,
TrieAddressAccountStruct,
TrieAddressOptionalAccountStruct,
)
from ethereum_types.bytes import Bytes, Bytes32
from ethereum_types.numeric import U256, U256Struct, Bool, bool
Expand All @@ -47,22 +49,22 @@ struct MappingAddressTrieBytes32U256 {
value: MappingAddressTrieBytes32U256Struct*,
}

struct TupleTrieAddressAccountMappingAddressTrieBytes32U256Struct {
trie_address_account: TrieAddressAccount,
struct TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct {
trie_address_account: TrieAddressOptionalAccount,
mapping_address_trie: MappingAddressTrieBytes32U256,
}

struct TupleTrieAddressAccountMappingAddressTrieBytes32U256 {
value: TupleTrieAddressAccountMappingAddressTrieBytes32U256Struct*,
struct TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256 {
value: TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct*,
}

struct ListTupleTrieAddressAccountMappingAddressTrieBytes32U256Struct {
data: TupleTrieAddressAccountMappingAddressTrieBytes32U256*,
struct ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct {
data: TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256*,
len: felt,
}

struct ListTupleTrieAddressAccountMappingAddressTrieBytes32U256 {
value: ListTupleTrieAddressAccountMappingAddressTrieBytes32U256Struct*,
struct ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256 {
value: ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct*,
}

struct TransientStorageSnapshotsStruct {
Expand All @@ -84,23 +86,22 @@ struct TransientStorage {
}

struct StateStruct {
_main_trie: TrieAddressAccount,
_main_trie: TrieAddressOptionalAccount,
_storage_tries: MappingAddressTrieBytes32U256,
_snapshots: ListTupleTrieAddressAccountMappingAddressTrieBytes32U256,
_snapshots: ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256,
created_accounts: SetAddress,
}

struct State {
value: StateStruct*,
}

using OptionalAccount = Account;
func get_account_optional{poseidon_ptr: PoseidonBuiltin*, state: State}(
address: Address
) -> OptionalAccount {
let trie = state.value._main_trie;
with trie {
let account = trie_get_TrieAddressAccount(address);
let account = trie_get_TrieAddressOptionalAccount(address);
}

return account;
Expand All @@ -114,7 +115,26 @@ func get_account{poseidon_ptr: PoseidonBuiltin*, state: State}(address: Address)
return empty_account;
}

return account;
tempvar res = Account(account.value);
return res;
}

func set_account{poseidon_ptr: PoseidonBuiltin*, state: State}(
address: Address, account: OptionalAccount
) {
let trie = state.value._main_trie;
with trie {
trie_set_TrieAddressOptionalAccount(address, account);
}
tempvar state = State(
new StateStruct(
_main_trie=trie,
_storage_tries=state.value._storage_tries,
_snapshots=state.value._snapshots,
created_accounts=state.value.created_accounts,
),
);
return ();
}

func get_storage{poseidon_ptr: PoseidonBuiltin*, state: State}(
Expand Down
38 changes: 20 additions & 18 deletions cairo/ethereum/cancun/trie.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ from ethereum.cancun.fork_types import (
Account__eq__,
AccountStruct,
Address,
OptionalAccount,
Bytes32U256DictAccess,
MappingAddressAccount,
MappingAddressAccountStruct,
Expand Down Expand Up @@ -167,14 +168,14 @@ struct Node {
value: NodeEnum*,
}

struct TrieAddressAccountStruct {
struct TrieAddressOptionalAccountStruct {
secured: bool,
default: Account,
default: OptionalAccount,
_data: MappingAddressAccount,
}

struct TrieAddressAccount {
value: TrieAddressAccountStruct*,
struct TrieAddressOptionalAccount {
value: TrieAddressOptionalAccountStruct*,
}

struct TrieBytes32U256Struct {
Expand Down Expand Up @@ -334,7 +335,8 @@ func encode_node{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: Kecc
// @notice Copies the trie to a new segment.
// @dev This function simply creates a new segment for the new dict and associates it with the
// dict_tracker of the source dict.
func copy_trieAddressAccount{range_check_ptr, trie: TrieAddressAccount}() -> TrieAddressAccount {
func copy_TrieAddressOptionalAccount{range_check_ptr, trie: TrieAddressOptionalAccount}(
) -> TrieAddressOptionalAccount {
alloc_locals;
// TODO: soundness
// We need to ensure it is sound when finalizing that copy.
Expand All @@ -355,8 +357,8 @@ func copy_trieAddressAccount{range_check_ptr, trie: TrieAddressAccount}() -> Tri
ids.new_dict_ptr = __dict_manager.new_dict(segments, copied_data)
%}

tempvar res = TrieAddressAccount(
new TrieAddressAccountStruct(
tempvar res = TrieAddressOptionalAccount(
new TrieAddressOptionalAccountStruct(
trie.value.secured,
trie.value.default,
MappingAddressAccount(
Expand Down Expand Up @@ -393,9 +395,9 @@ func copy_trieBytes32U256{range_check_ptr, trie: TrieBytes32U256}() -> TrieBytes
return res;
}

func trie_get_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressAccount}(
key: Address
) -> Account {
func trie_get_TrieAddressOptionalAccount{
poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressOptionalAccount
}(key: Address) -> OptionalAccount {
alloc_locals;
let dict_ptr = cast(trie.value._data.value.dict_ptr, DictAccess*);

Expand All @@ -412,10 +414,10 @@ func trie_get_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddre
trie.value._data.value.dict_ptr_start, new_dict_ptr, original_mapping
),
);
tempvar trie = TrieAddressAccount(
new TrieAddressAccountStruct(trie.value.secured, trie.value.default, mapping)
tempvar trie = TrieAddressOptionalAccount(
new TrieAddressOptionalAccountStruct(trie.value.secured, trie.value.default, mapping)
);
tempvar res = Account(cast(pointer, AccountStruct*));
tempvar res = OptionalAccount(cast(pointer, AccountStruct*));
return res;
}

Expand All @@ -441,9 +443,9 @@ func trie_get_TrieBytes32U256{poseidon_ptr: PoseidonBuiltin*, trie: TrieBytes32U
return res;
}

func trie_set_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressAccount}(
key: Address, value: Account
) {
func trie_set_TrieAddressOptionalAccount{
poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressOptionalAccount
}(key: Address, value: OptionalAccount) {
let dict_ptr_start = cast(trie.value._data.value.dict_ptr_start, DictAccess*);
let dict_ptr = cast(trie.value._data.value.dict_ptr, DictAccess*);

Expand Down Expand Up @@ -473,8 +475,8 @@ func trie_set_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddre
trie.value._data.value.original_mapping,
),
);
tempvar trie = TrieAddressAccount(
new TrieAddressAccountStruct(trie.value.secured, trie.value.default, mapping)
tempvar trie = TrieAddressOptionalAccount(
new TrieAddressOptionalAccountStruct(trie.value.secured, trie.value.default, mapping)
);
return ();
}
Expand Down
12 changes: 11 additions & 1 deletion cairo/tests/ethereum/cancun/test_state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Optional

import pytest
from ethereum_types.bytes import Bytes32
from ethereum_types.numeric import U256
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.strategies import composite

from ethereum.cancun.fork_types import Address
from ethereum.cancun.fork_types import Account, Address
from ethereum.cancun.state import (
account_exists,
account_has_code_or_nonce,
Expand All @@ -14,6 +16,7 @@
get_storage,
get_transient_storage,
is_account_empty,
set_account,
set_storage,
set_transient_storage,
)
Expand Down Expand Up @@ -67,6 +70,13 @@ def test_get_account_optional(self, cairo_run, data):
assert result_cairo == get_account_optional(state, address)
assert state_cairo == state

@given(data=state_and_address_and_optional_key(), account=...)
def test_set_account(self, cairo_run, data, account: Optional[Account]):
state, address = data
state_cairo = cairo_run("set_account", state, address, account)
set_account(state, address, account)
assert state_cairo == state

@given(data=state_and_address_and_optional_key())
def test_account_has_code_or_nonce(self, cairo_run, data):
state, address = data
Expand Down
14 changes: 9 additions & 5 deletions cairo/tests/ethereum/cancun/test_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,12 @@ def test_patricialize(self, cairo_run, obj: Mapping[Bytes, Bytes]):

class TestTrieOperations:
@given(trie=..., key=...)
def test_trie_get_TrieAddressAccount(
def test_trie_get_TrieAddressOptionalAccount(
self, cairo_run, trie: Trie[Address, Optional[Account]], key: Address
):
[trie_cairo, result_cairo] = cairo_run("trie_get_TrieAddressAccount", trie, key)
[trie_cairo, result_cairo] = cairo_run(
"trie_get_TrieAddressOptionalAccount", trie, key
)
result_py = trie_get(trie, key)
assert result_cairo == result_py
assert trie_cairo == trie
Expand All @@ -173,14 +175,14 @@ def test_trie_get_TrieBytes32U256(
assert trie_cairo == trie

@given(trie=..., key=..., value=...)
def test_trie_set_TrieAddressAccount(
def test_trie_set_TrieAddressOptionalAccount(
self,
cairo_run,
trie: Trie[Address, Optional[Account]],
key: Address,
value: Account,
):
cairo_trie = cairo_run("trie_set_TrieAddressAccount", trie, key, value)
cairo_trie = cairo_run("trie_set_TrieAddressOptionalAccount", trie, key, value)
trie_set(trie, key, value)
assert cairo_trie == trie

Expand All @@ -196,7 +198,9 @@ def test_trie_set_TrieBytes32U256(
def test_copy_trie_AddressAccount(
self, cairo_run, trie: Trie[Address, Optional[Account]]
):
[original_trie, copied_trie] = cairo_run("copy_trieAddressAccount", trie)
[original_trie, copied_trie] = cairo_run(
"copy_TrieAddressOptionalAccount", trie
)
trie_copy_py = copy_trie(trie)
assert original_trie == trie
assert copied_trie == trie_copy_py
Expand Down
1 change: 1 addition & 0 deletions cairo/tests/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def test_type(
Address,
Root,
Account,
Optional[Account],
Bloom,
VersionedHash,
Tuple[VersionedHash, ...],
Expand Down
7 changes: 4 additions & 3 deletions cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def __eq__(self, other):
("ethereum", "cancun", "fork_types", "SetAddress"): Set[Address],
("ethereum", "cancun", "fork_types", "Root"): Root,
("ethereum", "cancun", "fork_types", "Account"): Account,
("ethereum", "cancun", "fork_types", "OptionalAccount"): Optional[Account],
("ethereum", "cancun", "fork_types", "Bloom"): Bloom,
("ethereum", "cancun", "fork_types", "VersionedHash"): VersionedHash,
("ethereum", "cancun", "fork_types", "TupleVersionedHash"): Tuple[
Expand Down Expand Up @@ -311,7 +312,7 @@ def __eq__(self, other):
("ethereum", "cancun", "trie", "BranchNode"): BranchNode,
("ethereum", "cancun", "trie", "InternalNode"): InternalNode,
("ethereum", "cancun", "trie", "Node"): Node,
("ethereum", "cancun", "trie", "TrieAddressAccount"): Trie[
("ethereum", "cancun", "trie", "TrieAddressOptionalAccount"): Trie[
Address, Optional[Account]
],
("ethereum", "cancun", "trie", "TrieBytes32U256"): Trie[Bytes32, U256],
Expand All @@ -335,13 +336,13 @@ def __eq__(self, other):
"ethereum",
"cancun",
"state",
"TupleTrieAddressAccountMappingAddressTrieBytes32U256",
"TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256",
): Tuple[Trie[Address, Optional[Account]], Mapping[Address, Trie[Bytes32, U256]]],
(
"ethereum",
"cancun",
"state",
"ListTupleTrieAddressAccountMappingAddressTrieBytes32U256",
"ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256",
): List[
Tuple[Trie[Address, Optional[Account]], Mapping[Address, Trie[Bytes32, U256]]]
],
Expand Down
11 changes: 11 additions & 0 deletions cairo/tests/utils/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
python_cls, *annotations = get_args(python_cls)
origin_cls = get_origin(python_cls)

# arg_type = Optional[T] <=> arg_type_origin = Union[T, None]
if origin_cls is Union and get_args(python_cls)[1] is type(None):
# Get the value pointer: if it's zero, return None.
# Otherwise, consider this the non-optional type:
value_ptr = self.serialize_pointers(path, ptr)["value"]
if value_ptr is None:
return None
python_cls = get_args(python_cls)[0]
origin_cls = get_origin(python_cls)

if origin_cls is Union:
value_ptr = self.serialize_pointers(path, ptr)["value"]
if value_ptr is None:
Expand All @@ -176,6 +186,7 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
if value != 0 and value is not None
}
if len(variant_keys) != 1:
breakpoint()
raise ValueError(
f"Expected 1 item only to be relocatable in enum, got {len(variant_keys)}"
)
Expand Down

0 comments on commit 9605606

Please sign in to comment.