Skip to content

Commit

Permalink
merge byetarray and list handling
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Dec 20, 2024
1 parent 6267f66 commit 51b2e04
Showing 1 changed file with 17 additions and 28 deletions.
45 changes: 17 additions & 28 deletions cairo/tests/utils/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,12 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
if "__main__" in full_path:
full_path = self.main_part + full_path[full_path.index("__main__") + 1 :]
python_cls = to_python_type(full_path)
origin_cls = get_origin(python_cls) or python_cls

if get_origin(python_cls) is Annotated:
if origin_cls is Annotated:
python_cls, _ = get_args(python_cls)

if get_origin(python_cls) is Union:
if origin_cls is Union:
value_ptr = self.serialize_pointers(path, ptr)["value"]
if value_ptr is None:
return None
Expand All @@ -162,7 +163,7 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:

return self._serialize(variant.cairo_type, value_ptr + variant.offset)

if get_origin(python_cls) is list:
if origin_cls in (list, bytearray):
mapping_struct_ptr = self.serialize_pointers(path, ptr)["value"]
mapping_struct_path = (
get_struct_definition(self.program, path)
Expand All @@ -182,17 +183,23 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
pointers = self.serialize_pointers(mapping_struct_path, mapping_struct_ptr)
segment_size = pointers["dict_ptr"] - pointers["dict_ptr_start"]
dict_ptr = pointers["dict_ptr_start"]
stack_len = pointers["len"]
data_len = pointers["len"]

dict_repr = {
self._serialize(key_type, dict_ptr + i): self._serialize(
value_type, dict_ptr + i + 2
)
for i in range(0, segment_size, 3)
}
return [dict_repr[i] for i in range(stack_len)]
if origin_cls is bytearray:
# For bytearray, convert Bytes1 objects to integers
return bytearray(
int.from_bytes(dict_repr[i], "little") for i in range(data_len)
)

if get_origin(python_cls) in (tuple, list, Sequence, abc.Sequence):
return [dict_repr[i] for i in range(data_len)]

if origin_cls in (tuple, Sequence, abc.Sequence):
# Tuple and list are represented as structs with a pointer to the first element and the length.
# The value field is a list of Relocatable (pointers to each element) or Felt (tuple of felts).
# In usual cairo, a pointer to a struct, (e.g. Uint256*) is actually a pointer to one single
Expand All @@ -205,7 +212,7 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
.cairo_type.pointee.scope.path
)
members = get_struct_definition(self.program, tuple_struct_path).members
if get_origin(python_cls) is tuple and Ellipsis not in get_args(python_cls):
if origin_cls is tuple and Ellipsis not in get_args(python_cls):
# These are regular tuples with a given size.
return tuple(
self._serialize(member.cairo_type, tuple_struct_ptr + member.offset)
Expand All @@ -216,9 +223,7 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
raw = self.serialize_pointers(tuple_struct_path, tuple_struct_ptr)
tuple_item_path = members["data"].cairo_type.pointee.scope.path
resolved_cls = (
get_origin(python_cls)
if get_origin(python_cls) not in (Sequence, abc.Sequence)
else list
origin_cls if origin_cls not in (Sequence, abc.Sequence) else list
)
return resolved_cls(
[
Expand All @@ -227,10 +232,7 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
]
)

if (
get_origin(python_cls) in (Mapping, abc.Mapping, set)
or python_cls is bytearray
):
if origin_cls in (Mapping, abc.Mapping, set):
mapping_struct_ptr = self.serialize_pointers(path, ptr)["value"]
mapping_struct_path = (
get_struct_definition(self.program, path)
Expand All @@ -251,25 +253,12 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
segment_size = pointers["dict_ptr"] - pointers["dict_ptr_start"]
dict_ptr = pointers["dict_ptr_start"]

if get_origin(python_cls) is set:
if origin_cls is set:
return {
self._serialize(key_type, dict_ptr + i)
for i in range(0, segment_size, 3)
}

if python_cls is bytearray:
# For bytearray, we reconstruct it from the dictionary values up to length
d = {
self._serialize(key_type, dict_ptr + i): self._serialize(
value_type, dict_ptr + i + 2
)
for i in range(0, segment_size, 3)
}
length = pointers["len"]
return bytearray(
[int.from_bytes(d[i], "little") for i in range(length)]
)

return {
self._serialize(key_type, dict_ptr + i): self._serialize(
value_type, dict_ptr + i + 2
Expand Down

0 comments on commit 51b2e04

Please sign in to comment.