Skip to content

Commit

Permalink
make bytearray generic collection type
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Dec 20, 2024
1 parent 51b2e04 commit 0511ab4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 15 deletions.
7 changes: 2 additions & 5 deletions cairo/ethereum/cancun/vm/memory.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ func buffer_read{range_check_ptr}(buffer: Bytes, start_position: U256, size: U25
let start_position_felt = start_position.value.low;
let size_felt = size.value.low;

Internals._buffer_read_internal(
buffer_len, buffer_data, start_position_felt, size_felt, output
);
Internals._buffer_read(buffer_len, buffer_data, start_position_felt, size_felt, output);
tempvar result = Bytes(new BytesStruct(output, size_felt));
return result;
}
Expand All @@ -119,7 +117,6 @@ namespace Internals {
// @param data Pointer to the bytes data.
// @param len Length of bytes to write.
func _write_bytes{dict_ptr: DictAccess*}(start_position: felt, data: felt*, len: felt) {
alloc_locals;
if (len == 0) {
return ();
}
Expand Down Expand Up @@ -181,7 +178,7 @@ namespace Internals {
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @param output Pointer to write output bytes to.
func _buffer_read_internal{range_check_ptr}(
func _buffer_read{range_check_ptr}(
data_len: felt, data: felt*, start_position: felt, size: felt, output: felt*
) {
alloc_locals;
Expand Down
15 changes: 5 additions & 10 deletions cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,10 @@ def _gen_arg(
segments.load_data(struct_ptr, data)
return struct_ptr

if arg_type_origin is list:
# A `list` is represented as a Dict[felt, V] along with a length field.
value_type = get_args(arg_type)[0] # Get the concrete type parameter
if arg_type_origin in (list, bytearray):
# Collection types are represented as a Dict[felt, V] along with a length field.
# Get the concrete type parameter. For bytearray, the value type is int.
value_type = next(iter(get_args(arg_type)), int)
data = defaultdict(int, {k: v for k, v in enumerate(arg)})
base = _gen_arg(dict_manager, segments, Dict[Uint, value_type], data)
segments.load_data(base + 2, [len(arg)])
Expand Down Expand Up @@ -330,19 +331,13 @@ def _gen_arg(
segments.load_data(struct_ptr, [instances_ptr, len(arg)])
return struct_ptr

if arg_type_origin in (dict, ChainMap, abc.Mapping, set) or arg_type is bytearray:
if arg_type_origin in (dict, ChainMap, abc.Mapping, set):
dict_ptr = segments.add()
assert dict_ptr.segment_index not in dict_manager.trackers

if arg_type_origin is set:
arg = {k: True for k in arg}
arg_type = Mapping[type(next(iter(arg))), bool]
elif arg_type is bytearray:
# Create a dict with one byte per value and include length
data = defaultdict(int, {k: v for k, v in enumerate(arg)})
base = _gen_arg(dict_manager, segments, Mapping[int, int], data)
segments.load_data(base + 2, [len(arg)]) # Store length after dict pointers
return base

data = {
_gen_arg(dict_manager, segments, get_args(arg_type)[0], k): _gen_arg(
Expand Down

0 comments on commit 0511ab4

Please sign in to comment.