Skip to content

Commit

Permalink
feat: EELS stack
Browse files Browse the repository at this point in the history
adress simple comments

use generic List strategy

inline errors

use patch ethereum specs
  • Loading branch information
enitrat committed Dec 20, 2024
1 parent 25f5a05 commit dc351fa
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 15 deletions.
9 changes: 9 additions & 0 deletions cairo/ethereum/cancun/vm/exceptions.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ethereum_types.bytes import BytesStruct

struct StackUnderflowError {
value: BytesStruct*,
}

struct StackOverflowError {
value: BytesStruct*,
}
64 changes: 64 additions & 0 deletions cairo/ethereum/cancun/vm/stack.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from ethereum_types.numeric import U256, U256Struct
from ethereum_types.bytes import BytesStruct
from starkware.cairo.common.dict import DictAccess, dict_read, dict_write
from ethereum.cancun.vm.exceptions import StackOverflowError, StackUnderflowError

struct Stack {
value: StackStruct*,
}

struct StackStruct {
dict_ptr_start: StackDictAccess*,
dict_ptr: StackDictAccess*,
len: felt,
}

struct StackDictAccess {
key: felt,
prev_value: U256,
new_value: U256,
}

const STACK_MAX_SIZE = 1024;

func pop{stack: Stack}() -> (U256, StackUnderflowError) {
alloc_locals;
let len = stack.value.len;
if (len == 0) {
tempvar err = StackUnderflowError(new BytesStruct(cast(0, felt*), 0));
let val = U256(cast(0, U256Struct*));
return (val, err);
}

let dict_ptr = cast(stack.value.dict_ptr, DictAccess*);
with dict_ptr {
let (pointer) = dict_read(len - 1);
}
let new_dict_ptr = cast(dict_ptr, StackDictAccess*);

tempvar stack = Stack(new StackStruct(stack.value.dict_ptr_start, new_dict_ptr, len - 1));
tempvar value = U256(cast(pointer, U256Struct*));

let ok_ = StackUnderflowError(cast(0, BytesStruct*));
return (value, ok_);
}

func push{stack: Stack}(value: U256) -> StackOverflowError {
alloc_locals;
let len = stack.value.len;
if (len == STACK_MAX_SIZE) {
tempvar err = StackOverflowError(new BytesStruct(cast(0, felt*), 0));
return err;
}

let dict_ptr = cast(stack.value.dict_ptr, DictAccess*);
with dict_ptr {
dict_write(len, cast(value.value, felt));
}
let new_dict_ptr = cast(dict_ptr, StackDictAccess*);

tempvar stack = Stack(new StackStruct(stack.value.dict_ptr_start, new_dict_ptr, len + 1));
let ok_ = StackOverflowError(cast(0, BytesStruct*));

return ok_;
}
2 changes: 1 addition & 1 deletion cairo/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ profile = "black"
src_paths = ["src", "tests"]

[tool.uv.sources]
ethereum = { git = "https://github.com/ethereum/execution-specs.git", rev = "1adcc1bfe774798bcacc685aebc17bd9935078c3" }
ethereum = { git = "https://github.com/kkrt-labs/execution-specs.git", branch = "dev/change-type-branch-nodes" }

[build-system]
requires = ["hatchling"]
Expand Down
40 changes: 40 additions & 0 deletions cairo/tests/ethereum/cancun/vm/test_stack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import List

import pytest
from ethereum_types.numeric import U256
from hypothesis import assume, given

from ethereum.cancun.vm.exceptions import StackOverflowError, StackUnderflowError
from ethereum.cancun.vm.stack import pop, push


class TestStack:
def test_pop_underflow(self, cairo_run):
stack = []
with pytest.raises(StackUnderflowError):
cairo_run("pop", stack)
with pytest.raises(StackUnderflowError):
pop(stack)

@given(stack=...)
def test_pop_success(self, cairo_run, stack: List[U256]):
assume(len(stack) > 0)

(new_stack_cairo, popped_value_cairo) = cairo_run("pop", stack)
popped_value_py = pop(stack)
assert new_stack_cairo == stack
assert popped_value_cairo == popped_value_py

@given(value=...)
def test_push_overflow(self, cairo_run, value: U256):
stack = [U256(0)] * 1024
with pytest.raises(StackOverflowError):
cairo_run("push", stack, value)
with pytest.raises(StackOverflowError):
push(stack, value)

@given(stack=..., value=...)
def test_push_success(self, cairo_run, stack: List[U256], value: U256):
new_stack_cairo = cairo_run("push", stack, value)
push(stack, value)
assert new_stack_cairo == stack
5 changes: 3 additions & 2 deletions cairo/tests/fixtures/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from tests.utils.args_gen import to_cairo_type, to_python_type
from tests.utils.hints import debug_info, get_op, oracle
from tests.utils.reporting import profile_from_tracer_data
from tests.utils.serde import Serde
from tests.utils.serde import NO_ERROR_FLAG, Serde

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
Expand Down Expand Up @@ -332,12 +332,13 @@ def _factory(entrypoint, *args, **kwargs):
final_output = serde.serialize_list(output_ptr)

cumulative_retdata_offsets = serde.get_offsets(return_data_types)
function_output = [
unfiltered_output = [
serde.serialize(return_data_type, runner.vm.run_context.ap, offset)
for offset, return_data_type in zip(
cumulative_retdata_offsets, return_data_types
)
]
function_output = [x for x in unfiltered_output if x is not NO_ERROR_FLAG]

if final_output is not None:
if len(function_output) > 0:
Expand Down
30 changes: 28 additions & 2 deletions cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
Any,
Dict,
ForwardRef,
List,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -97,6 +98,7 @@
Transaction,
)
from ethereum.cancun.trie import BranchNode, ExtensionNode, InternalNode, LeafNode, Node
from ethereum.cancun.vm.exceptions import StackOverflowError, StackUnderflowError
from ethereum.cancun.vm.gas import MessageCallGas
from ethereum.crypto.hash import Hash32
from ethereum.exceptions import EthereumException
Expand Down Expand Up @@ -176,6 +178,21 @@
Address, Account
],
("ethereum", "exceptions", "EthereumException"): EthereumException,
("ethereum", "cancun", "vm", "stack", "Stack"): List[U256],
(
"ethereum",
"cancun",
"vm",
"exceptions",
"StackUnderflowError",
): StackUnderflowError,
(
"ethereum",
"cancun",
"vm",
"exceptions",
"StackOverflowError",
): StackOverflowError,
}

# In the EELS, some functions are annotated with Sequence while it's actually just Bytes.
Expand Down Expand Up @@ -264,10 +281,20 @@ def _gen_arg(
segments.load_data(struct_ptr, data)
return struct_ptr

if arg_type_origin in (tuple, list, Sequence, abc.Sequence):
if arg_type_origin is list:
# A `list` is represented as a Dict[felt, V] along with a length field.
value_type = get_args(arg_type)[0] # Get the concrete type parameter
data = defaultdict(int, {k: v for k, v in enumerate(arg)})
base = _gen_arg(dict_manager, segments, Dict[Uint, value_type], data)
segments.load_data(base + 2, [len(arg)])
return base

if arg_type_origin in (tuple, Sequence, abc.Sequence):
if arg_type_origin is tuple and (
Ellipsis not in get_args(arg_type) or annotations
):
# Case a tuple with a fixed number of elements, all of different types.
# These are represented as a pointer to a struct with a pointer to each element.
element_types = get_args(arg_type)

# Handle fixed-size tuples with size annotation (e.g. Annotated[Tuple[T], N])
Expand All @@ -277,7 +304,6 @@ def _gen_arg(
raise ValueError(
f"Invalid tuple size annotation for {arg_type} with annotations {annotations}"
)

struct_ptr = segments.add()
data = [
_gen_arg(dict_manager, segments, element_type, value)
Expand Down
43 changes: 37 additions & 6 deletions cairo/tests/utils/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,36 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:

return self._serialize(variant.cairo_type, value_ptr + variant.offset)

if get_origin(python_cls) is list:
mapping_struct_ptr = self.serialize_pointers(path, ptr)["value"]
mapping_struct_path = (
get_struct_definition(self.program, path)
.members["value"]
.cairo_type.pointee.scope.path
)
dict_access_path = (
get_struct_definition(self.program, mapping_struct_path)
.members["dict_ptr"]
.cairo_type.pointee.scope.path
)
dict_access_types = get_struct_definition(
self.program, dict_access_path
).members
key_type = dict_access_types["key"].cairo_type
value_type = dict_access_types["new_value"].cairo_type
pointers = self.serialize_pointers(mapping_struct_path, mapping_struct_ptr)
segment_size = pointers["dict_ptr"] - pointers["dict_ptr_start"]
dict_ptr = pointers["dict_ptr_start"]
stack_len = pointers["len"]

dict_repr = {
self._serialize(key_type, dict_ptr + i): self._serialize(
value_type, dict_ptr + i + 2
)
for i in range(0, segment_size, 3)
}
return [dict_repr[i] for i in range(stack_len)]

if get_origin(python_cls) in (tuple, list, Sequence, abc.Sequence):
# Tuple and list are represented as structs with a pointer to the first element and the length.
# The value field is a list of Relocatable (pointers to each element) or Felt (tuple of felts).
Expand Down Expand Up @@ -267,14 +297,19 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
if python_cls is None:
return kwargs

value = kwargs.get("value")
if isinstance(members["value"].cairo_type, TypePointer) and value is None:
# A None pointer is valid for pointer types, meaning just that the struct is not present.
return None

if python_cls in (U256, Hash32, Bytes32):
value = kwargs["value"]["low"] + kwargs["value"]["high"] * 2**128
value = value["low"] + value["high"] * 2**128
if python_cls == U256:
return U256(value)
return python_cls(value.to_bytes(32, "little"))

if python_cls in (Bytes0, Bytes8, Bytes20):
return python_cls(kwargs["value"].to_bytes(python_cls.LENGTH, "little"))
return python_cls(value.to_bytes(python_cls.LENGTH, "little"))

# Because some types are wrapped in a value field, e.g. Account{ value: AccountStruct }
# this may not work, so that we catch the error and try to fallback.
Expand All @@ -284,10 +319,6 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
except TypeError:
pass

value = kwargs.get("value")
if isinstance(members["value"].cairo_type, TypePointer) and value is None:
# A None pointer is valid for pointer types, meaning just that the struct is not present.
return None
if isinstance(value, dict):
signature(python_cls.__init__).bind(None, **value)
return python_cls(**value)
Expand Down
6 changes: 2 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit dc351fa

Please sign in to comment.