Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: EELS memory #307

Merged
merged 5 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions cairo/ethereum/cancun/vm/memory.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// SPDX-License-Identifier: MIT

from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.bool import FALSE, TRUE
from starkware.cairo.common.default_dict import default_dict_new, default_dict_finalize
from starkware.cairo.common.dict import DictAccess, dict_read, dict_write
from starkware.cairo.common.memset import memset
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc
from starkware.cairo.common.math import assert_le, assert_lt
from starkware.cairo.common.math_cmp import is_le, is_not_zero

from ethereum_types.bytes import Bytes, BytesStruct, Bytes1DictAccess
from ethereum_types.numeric import U256
from ethereum.utils.numeric import max

struct MemoryStruct {
dict_ptr_start: Bytes1DictAccess*,
dict_ptr: Bytes1DictAccess*,
len: felt,
}

struct Memory {
value: MemoryStruct*,
}

// @notice Write bytes to memory at a given position.
// @param memory The pointer to the bytearray.
// @param start_position Starting position to write at.
// @param value Bytes to write.
func memory_write{range_check_ptr, memory: Memory}(start_position: U256, value: Bytes) {
alloc_locals;
let bytes_len = value.value.len;
let start_position_felt = start_position.value.low;
ClementWalter marked this conversation as resolved.
Show resolved Hide resolved
with_attr error_message("memory_write: start_position > 2**128 || value.len > 2**128") {
assert start_position.value.high = 0;
}

let bytes_data = value.value.data;
let dict_ptr = cast(memory.value.dict_ptr, DictAccess*);
with dict_ptr {
_write_bytes(start_position_felt, bytes_data, bytes_len);
}
let new_dict_ptr = cast(dict_ptr, Bytes1DictAccess*);

let len = max(memory.value.len, start_position.value.low + value.value.len);
tempvar memory = Memory(new MemoryStruct(memory.value.dict_ptr_start, new_dict_ptr, len));
return ();
}

// @notice Read bytes from memory.
// @param memory The pointer to the bytearray.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @return The bytes read from memory.
func memory_read_bytes{memory: Memory}(start_position: U256, size: U256) -> Bytes {
alloc_locals;

with_attr error_message("memory_read_bytes: start_position > 2**128 || size > 2**128") {
assert start_position.value.high = 0;
assert size.value.high = 0;
}

let (local output: felt*) = alloc();
let dict_ptr = cast(memory.value.dict_ptr, DictAccess*);
let start_position_felt = start_position.value.low;
let size_felt = size.value.low;

with dict_ptr {
_read_bytes(start_position_felt, size_felt, output);
}
let new_dict_ptr = cast(dict_ptr, Bytes1DictAccess*);

tempvar memory = Memory(
new MemoryStruct(memory.value.dict_ptr_start, new_dict_ptr, memory.value.len)
);
tempvar result = Bytes(new BytesStruct(output, size_felt));
return result;
}

// @notice Read bytes from a buffer with zero padding.
// @dev assumption: start_position < 2**128
// @dev assumption: size < 2**128
// @param buffer Source bytes to read from.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @return The bytes read from the buffer.
func buffer_read{range_check_ptr}(buffer: Bytes, start_position: U256, size: U256) -> Bytes {
alloc_locals;
let (local output: felt*) = alloc();
let buffer_len = buffer.value.len;
let buffer_data = buffer.value.data;
let start_position_felt = start_position.value.low;
let size_felt = size.value.low;
with_attr error_message("buffer_read: start_position > 2**128 || size > 2**128") {
assert start_position.value.high = 0;
assert size.value.high = 0;
}

_buffer_read(buffer_len, buffer_data, start_position_felt, size_felt, output);
tempvar result = Bytes(new BytesStruct(output, size_felt));
return result;
}

// @notice Internal function to write bytes to memory.
// @param start_position Starting position to write at.
// @param data Pointer to the bytes data.
// @param len Length of bytes to write.
func _write_bytes{dict_ptr: DictAccess*}(start_position: felt, data: felt*, len: felt) {
if (len == 0) {
return ();
}

tempvar index = len;
tempvar dict_ptr = dict_ptr;

body:
let index = [ap - 2] - 1;
let dict_ptr = cast([ap - 1], DictAccess*);
let start_position = [fp - 5];
let data = cast([fp - 4], felt*);

dict_write(start_position + index, data[index]);

tempvar index = index;
tempvar dict_ptr = dict_ptr;
jmp body if index != 0;

end:
return ();
}

// @notice Internal function to read bytes from memory.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @param output Pointer to write output bytes to.
func _read_bytes{dict_ptr: DictAccess*}(start_position: felt, size: felt, output: felt*) {
alloc_locals;
if (size == 0) {
return ();
}

tempvar dict_index = start_position + size;
tempvar dict_ptr = dict_ptr;

body:
let dict_index = [ap - 2] - 1;
let dict_ptr = cast([ap - 1], DictAccess*);
let output = cast([fp - 3], felt*);
let start_position = [fp - 5];
tempvar output_index = dict_index - start_position;

let (value) = dict_read(dict_index);
assert output[output_index] = value;

tempvar dict_index = dict_index;
tempvar dict_ptr = dict_ptr;
jmp body if output_index != 0;

return ();
}

// @notice Internal function to read bytes from a buffer with zero padding.
// @param data_len Length of the buffer.
// @param data Pointer to the buffer data.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @param output Pointer to write output bytes to.
func _buffer_read{range_check_ptr}(
data_len: felt, data: felt*, start_position: felt, size: felt, output: felt*
) {
alloc_locals;
if (size == 0) {
return ();
}

// Check if start position is beyond buffer length
let start_oob = is_le(data_len, start_position);
if (start_oob == TRUE) {
memset(output, 0, size);
return ();
}

// Check if read extends past end of buffer
let end_oob = is_le(data_len, start_position + size);
if (end_oob == TRUE) {
let available_size = data_len - start_position;
memcpy(output, data + start_position, available_size);

let remaining_size = size - available_size;
memset(output + available_size, 0, remaining_size);
} else {
memcpy(output, data + start_position, size);
}
return ();
}
9 changes: 9 additions & 0 deletions cairo/ethereum_types/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ from ethereum_types.numeric import U128
struct Bytes0 {
value: felt,
}
struct Bytes1 {
value: felt,
}
struct Bytes8 {
value: felt,
}
Expand Down Expand Up @@ -119,3 +122,9 @@ struct TupleBytes32Struct {
struct TupleBytes32 {
value: TupleBytes32Struct*,
}

struct Bytes1DictAccess {
key: felt,
prev_value: Bytes1,
new_value: Bytes1,
}
4 changes: 1 addition & 3 deletions cairo/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ dev-dependencies = [
"pytest-xdist>=3.6.1",
"pytest>=8.3.3",
"pydantic>=2.9.1",
"polars>=1.18.0",
]

[tool.isort]
Expand All @@ -159,8 +160,5 @@ ethereum = { git = "https://github.com/kkrt-labs/execution-specs.git", rev = "b2
requires = ["hatchling"]
build-backend = "hatchling.build"

[dependency-groups]
dev = ["polars>=1.17.1"]

[tool.hatch.build.targets.wheel]
packages = ["src"]
70 changes: 70 additions & 0 deletions cairo/tests/ethereum/cancun/vm/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from ethereum_types.bytes import Bytes
from ethereum_types.numeric import U256
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.strategies import composite

from ethereum.cancun.vm.memory import buffer_read, memory_read_bytes, memory_write

# NOTE: The testing strategy always assume that memory accesses are within bounds.
# Because the memory is always extended to the proper size _before_ being accessed.


@composite
def memory_write_strategy(draw):
# Higher than 2**10 will cause a HealthCheck too large error.
memory_size = draw(st.integers(min_value=0, max_value=2**10))
memory = draw(st.binary(min_size=memory_size, max_size=memory_size).map(bytearray))

# Generate a start position in bounds with existing memory
start_position = draw(st.integers(min_value=0, max_value=memory_size).map(U256))

# Generate value with size that won't overflow memory
max_value_size = memory_size - int(start_position)
value = draw(st.binary(min_size=0, max_size=max_value_size))

return memory, start_position, value


@composite
def memory_read_strategy(draw):
memory_size = draw(st.integers(min_value=0, max_value=2**10))
memory = draw(st.binary(min_size=memory_size, max_size=memory_size).map(bytearray))

start_position = draw(st.integers(min_value=0, max_value=memory_size).map(U256))
size = draw(
st.integers(min_value=0, max_value=memory_size - int(start_position)).map(U256)
)

return memory, start_position, size


class TestMemory:
@given(memory_write_strategy())
def test_memory_write(self, cairo_run, params):
memory, start_position, value = params
cairo_memory = cairo_run("memory_write", memory, start_position, Bytes(value))
memory_write(memory, start_position, Bytes(value))
assert cairo_memory == memory

@given(memory_read_strategy())
def test_memory_read(self, cairo_run, params):
memory, start_position, size = params
(cairo_memory, cairo_value) = cairo_run(
"memory_read_bytes", memory, start_position, size
)
python_value = memory_read_bytes(memory, start_position, size)
assert cairo_memory == memory
assert cairo_value == python_value

@given(
buffer=st.binary(min_size=0, max_size=2**10).map(Bytes),
start_position=st.integers(min_value=0, max_value=2**128 - 1).map(U256),
size=st.integers(min_value=0, max_value=2**10).map(U256),
)
def test_buffer_read(
self, cairo_run, buffer: Bytes, start_position: U256, size: U256
):
assert buffer_read(buffer, start_position, size) == cairo_run(
"buffer_read", buffer, start_position, size
)
24 changes: 20 additions & 4 deletions cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,15 @@
get_origin,
)

from ethereum_types.bytes import Bytes, Bytes0, Bytes8, Bytes20, Bytes32, Bytes256
from ethereum_types.bytes import (
Bytes,
Bytes0,
Bytes1,
Bytes8,
Bytes20,
Bytes32,
Bytes256,
)
from ethereum_types.numeric import U64, U256, Uint
from starkware.cairo.common.dict import DictManager, DictTracker
from starkware.cairo.lang.compiler.ast.cairo_types import (
Expand Down Expand Up @@ -105,6 +113,11 @@
from ethereum.rlp import Extended, Simple
from tests.utils.helpers import flatten


class Memory(bytearray):
pass


_cairo_struct_to_python_type: Dict[Tuple[str, ...], Any] = {
("ethereum_types", "others", "None"): type(None),
("ethereum_types", "numeric", "bool"): bool,
Expand All @@ -114,6 +127,7 @@
("ethereum_types", "numeric", "SetUint"): Set[Uint],
("ethereum_types", "numeric", "UnionUintU256"): Union[Uint, U256],
("ethereum_types", "bytes", "Bytes0"): Bytes0,
("ethereum_types", "bytes", "Bytes1"): Bytes1,
("ethereum_types", "bytes", "Bytes8"): Bytes8,
("ethereum_types", "bytes", "Bytes20"): Bytes20,
("ethereum_types", "bytes", "Bytes32"): Bytes32,
Expand Down Expand Up @@ -184,6 +198,7 @@
Address, Account
],
("ethereum", "exceptions", "EthereumException"): EthereumException,
("ethereum", "cancun", "vm", "memory", "Memory"): Memory,
("ethereum", "cancun", "vm", "stack", "Stack"): List[U256],
(
"ethereum",
Expand Down Expand Up @@ -287,9 +302,10 @@ def _gen_arg(
segments.load_data(struct_ptr, data)
return struct_ptr

if arg_type_origin is list:
# A `list` is represented as a Dict[felt, V] along with a length field.
value_type = get_args(arg_type)[0] # Get the concrete type parameter
if arg_type_origin in (list, Memory):
# Collection types are represented as a Dict[felt, V] along with a length field.
# Get the concrete type parameter. For bytearray, the value type is int.
value_type = next(iter(get_args(arg_type)), int)
data = defaultdict(int, {k: v for k, v in enumerate(arg)})
base = _gen_arg(dict_manager, segments, Dict[Uint, value_type], data)
segments.load_data(base + 2, [len(arg)])
Expand Down
Loading
Loading