From 0511ab4757e6b5c248fa880583c613e7789a7c71 Mon Sep 17 00:00:00 2001 From: enitrat Date: Fri, 20 Dec 2024 17:06:45 +0100 Subject: [PATCH] make bytearray generic collection type --- cairo/ethereum/cancun/vm/memory.cairo | 7 ++----- cairo/tests/utils/args_gen.py | 15 +++++---------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/cairo/ethereum/cancun/vm/memory.cairo b/cairo/ethereum/cancun/vm/memory.cairo index 2bc45633..9d064d72 100644 --- a/cairo/ethereum/cancun/vm/memory.cairo +++ b/cairo/ethereum/cancun/vm/memory.cairo @@ -106,9 +106,7 @@ func buffer_read{range_check_ptr}(buffer: Bytes, start_position: U256, size: U25 let start_position_felt = start_position.value.low; let size_felt = size.value.low; - Internals._buffer_read_internal( - buffer_len, buffer_data, start_position_felt, size_felt, output - ); + Internals._buffer_read(buffer_len, buffer_data, start_position_felt, size_felt, output); tempvar result = Bytes(new BytesStruct(output, size_felt)); return result; } @@ -119,7 +117,6 @@ namespace Internals { // @param data Pointer to the bytes data. // @param len Length of bytes to write. func _write_bytes{dict_ptr: DictAccess*}(start_position: felt, data: felt*, len: felt) { - alloc_locals; if (len == 0) { return (); } @@ -181,7 +178,7 @@ namespace Internals { // @param start_position Starting position to read from. // @param size Number of bytes to read. // @param output Pointer to write output bytes to. - func _buffer_read_internal{range_check_ptr}( + func _buffer_read{range_check_ptr}( data_len: felt, data: felt*, start_position: felt, size: felt, output: felt* ) { alloc_locals; diff --git a/cairo/tests/utils/args_gen.py b/cairo/tests/utils/args_gen.py index 9da7f60c..1be5c16a 100644 --- a/cairo/tests/utils/args_gen.py +++ b/cairo/tests/utils/args_gen.py @@ -291,9 +291,10 @@ def _gen_arg( segments.load_data(struct_ptr, data) return struct_ptr - if arg_type_origin is list: - # A `list` is represented as a Dict[felt, V] along with a length field. - value_type = get_args(arg_type)[0] # Get the concrete type parameter + if arg_type_origin in (list, bytearray): + # Collection types are represented as a Dict[felt, V] along with a length field. + # Get the concrete type parameter. For bytearray, the value type is int. + value_type = next(iter(get_args(arg_type)), int) data = defaultdict(int, {k: v for k, v in enumerate(arg)}) base = _gen_arg(dict_manager, segments, Dict[Uint, value_type], data) segments.load_data(base + 2, [len(arg)]) @@ -330,19 +331,13 @@ def _gen_arg( segments.load_data(struct_ptr, [instances_ptr, len(arg)]) return struct_ptr - if arg_type_origin in (dict, ChainMap, abc.Mapping, set) or arg_type is bytearray: + if arg_type_origin in (dict, ChainMap, abc.Mapping, set): dict_ptr = segments.add() assert dict_ptr.segment_index not in dict_manager.trackers if arg_type_origin is set: arg = {k: True for k in arg} arg_type = Mapping[type(next(iter(arg))), bool] - elif arg_type is bytearray: - # Create a dict with one byte per value and include length - data = defaultdict(int, {k: v for k, v in enumerate(arg)}) - base = _gen_arg(dict_manager, segments, Mapping[int, int], data) - segments.load_data(base + 2, [len(arg)]) # Store length after dict pointers - return base data = { _gen_arg(dict_manager, segments, get_args(arg_type)[0], k): _gen_arg(