Skip to content

Commit

Permalink
Make Group.arrays and Group.groups compatible with v2 (#2213)
Browse files Browse the repository at this point in the history
Defines a set of array / group iterators on both AsyncGroup and Group.

- .groups / .arrays: iterate over (name, value) pairs
- .group_keys / .array_keys: iterate over the names (keys) only
- .group_values / .array_values: iterate over the values only

Co-authored-by: Joe Hamman <[email protected]>
  • Loading branch information
TomAugspurger and jhamman authored Sep 20, 2024
1 parent 32540b4 commit c878da2
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 52 deletions.
67 changes: 40 additions & 27 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from zarr.store.common import ensure_no_existing_node

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterable, Iterator
from collections.abc import AsyncGenerator, Generator, Iterable, Iterator
from typing import Any

from zarr.abc.codec import Codec
Expand Down Expand Up @@ -678,29 +678,31 @@ async def contains(self, member: str) -> bool:
else:
return True

# todo: decide if this method should be separate from `groups`
async def group_keys(self) -> AsyncGenerator[str, None]:
async for key, value in self.members():
async def groups(self) -> AsyncGenerator[tuple[str, AsyncGroup], None]:
async for name, value in self.members():
if isinstance(value, AsyncGroup):
yield key
yield name, value

# todo: decide if this method should be separate from `group_keys`
async def groups(self) -> AsyncGenerator[AsyncGroup, None]:
async for _, value in self.members():
if isinstance(value, AsyncGroup):
yield value
async def group_keys(self) -> AsyncGenerator[str, None]:
async for key, _ in self.groups():
yield key

# todo: decide if this method should be separate from `arrays`
async def array_keys(self) -> AsyncGenerator[str, None]:
async def group_values(self) -> AsyncGenerator[AsyncGroup, None]:
async for _, group in self.groups():
yield group

async def arrays(self) -> AsyncGenerator[tuple[str, AsyncArray], None]:
async for key, value in self.members():
if isinstance(value, AsyncArray):
yield key
yield key, value

# todo: decide if this method should be separate from `array_keys`
async def arrays(self) -> AsyncGenerator[AsyncArray, None]:
async for _, value in self.members():
if isinstance(value, AsyncArray):
yield value
async def array_keys(self) -> AsyncGenerator[str, None]:
async for key, _ in self.arrays():
yield key

async def array_values(self) -> AsyncGenerator[AsyncArray, None]:
async for _, array in self.arrays():
yield array

async def tree(self, expand: bool = False, level: int | None = None) -> Any:
raise NotImplementedError
Expand Down Expand Up @@ -861,18 +863,29 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group],
def __contains__(self, member: str) -> bool:
return self._sync(self._async_group.contains(member))

def group_keys(self) -> tuple[str, ...]:
return tuple(self._sync_iter(self._async_group.group_keys()))
def groups(self) -> Generator[tuple[str, Group], None]:
for name, async_group in self._sync_iter(self._async_group.groups()):
yield name, Group(async_group)

def group_keys(self) -> Generator[str, None]:
for name, _ in self.groups():
yield name

def group_values(self) -> Generator[Group, None]:
for _, group in self.groups():
yield group

def groups(self) -> tuple[Group, ...]:
# TODO: in v2 this was a generator that return key: Group
return tuple(Group(obj) for obj in self._sync_iter(self._async_group.groups()))
def arrays(self) -> Generator[tuple[str, Array], None]:
for name, async_array in self._sync_iter(self._async_group.arrays()):
yield name, Array(async_array)

def array_keys(self) -> tuple[str, ...]:
return tuple(self._sync_iter(self._async_group.array_keys()))
def array_keys(self) -> Generator[str, None]:
for name, _ in self.arrays():
yield name

def arrays(self) -> tuple[Array, ...]:
return tuple(Array(obj) for obj in self._sync_iter(self._async_group.arrays()))
def array_values(self) -> Generator[Array, None]:
for _, array in self.arrays():
yield array

def tree(self, expand: bool = False, level: int | None = None) -> Any:
return self._sync(self._async_group.tree(expand=expand, level=level))
Expand Down
44 changes: 19 additions & 25 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,34 +301,28 @@ def test_group_contains(store: Store, zarr_format: ZarrFormat) -> None:
assert "foo" in group


def test_group_subgroups(store: Store, zarr_format: ZarrFormat) -> None:
"""
Test the behavior of `Group` methods for accessing subgroups, namely `Group.group_keys` and `Group.groups`
"""
def test_group_child_iterators(store: Store, zarr_format: ZarrFormat):
group = Group.create(store, zarr_format=zarr_format)
keys = ("foo", "bar")
subgroups_expected = tuple(group.create_group(k) for k in keys)
# create a sub-array as well
_ = group.create_array("array", shape=(10,))
subgroups_observed = group.groups()
assert set(group.group_keys()) == set(keys)
assert len(subgroups_observed) == len(subgroups_expected)
assert all(a in subgroups_observed for a in subgroups_expected)
expected_group_keys = ["g0", "g1"]
expected_group_values = [group.create_group(name=name) for name in expected_group_keys]
expected_groups = list(zip(expected_group_keys, expected_group_values, strict=False))

expected_group_values[0].create_group("subgroup")
expected_group_values[0].create_array("subarray", shape=(1,))

def test_group_subarrays(store: Store, zarr_format: ZarrFormat) -> None:
"""
Test the behavior of `Group` methods for accessing subgroups, namely `Group.group_keys` and `Group.groups`
"""
group = Group.create(store, zarr_format=zarr_format)
keys = ("foo", "bar")
subarrays_expected = tuple(group.create_array(k, shape=(10,)) for k in keys)
# create a sub-group as well
_ = group.create_group("group")
subarrays_observed = group.arrays()
assert set(group.array_keys()) == set(keys)
assert len(subarrays_observed) == len(subarrays_expected)
assert all(a in subarrays_observed for a in subarrays_expected)
expected_array_keys = ["a0", "a1"]
expected_array_values = [
group.create_array(name=name, shape=(1,)) for name in expected_array_keys
]
expected_arrays = list(zip(expected_array_keys, expected_array_values, strict=False))

assert sorted(group.groups(), key=lambda x: x[0]) == expected_groups
assert sorted(group.group_keys()) == expected_group_keys
assert sorted(group.group_values(), key=lambda x: x.name) == expected_group_values

assert sorted(group.arrays(), key=lambda x: x[0]) == expected_arrays
assert sorted(group.array_keys()) == expected_array_keys
assert sorted(group.array_values(), key=lambda x: x.name) == expected_array_values


def test_group_update_attributes(store: Store, zarr_format: ZarrFormat) -> None:
Expand Down

0 comments on commit c878da2

Please sign in to comment.