Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

propose alternative chunk shape algorithm #996

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 98 additions & 26 deletions src/hdmf/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable
from warnings import warn
from typing import Tuple, Callable
from typing import Tuple, Callable, List, Optional, Union
from itertools import product, chain

import h5py
Expand All @@ -12,6 +12,92 @@
from .utils import docval, getargs, popargs, docval_macro, get_data_shape


# NOTE(review): in the review thread a contributor suggested making this helper
# private if it is only used by array_with_desired_product; the author agreed and
# added that an alternative solution proposed in review would make this function
# unnecessary. The public name is kept here so existing callers keep working.
def find_nth_none(lst, n):
    """
    Finds the index of the nth None in a list.

    Parameters
    ----------
    lst : list
        The list in which to search for None values. Only elements that are
        identical to None count; other falsy values (0, False, "") do not.
    n : int
        The occurrence of None to find (1-based).

    Returns
    -------
    int or None
        The index of the nth None in the list, or None if there aren't enough
        None values (which includes any n smaller than 1).
    """
    if n < 1:
        # Occurrences are counted from 1, so there is no zeroth or negative match.
        return None

    remaining = n
    for index, value in enumerate(lst):
        if value is None:
            remaining -= 1
            if remaining == 0:
                return index
    return None


def array_with_desired_product(
    desired_product: int,
    upper_limits: Iterable[Optional[int]],
    defined_elements: List[Optional[int]],
) -> List[int]:
    """
    Fill the undefined elements of a vector so that its product approaches a target from below.

    The maths is kept separate from its application on purpose (see the review discussion of this
    function): what is wanted is a vector whose product is near a target and whose sum is small,
    given constraints on the vector. Callers then use the result, for example, as a chunk shape.

    Every undefined element starts at 1. The smallest growable element (the first one on ties) is
    incremented repeatedly, which keeps the undefined elements within 1 of each other until one of
    them hits its upper limit. The search stops as soon as the next increment would push the
    product of the whole vector above ``desired_product``, or when every undefined element has
    reached its upper limit.

    Parameters
    ----------
    desired_product : int
        The target for the product of all elements. The result's product never exceeds it unless
        the predefined elements alone already do.
    upper_limits : Iterable[Optional[int]]
        The upper limit for each element. None (or 0) indicates no limit. Limits are only applied
        to elements that are undefined in ``defined_elements``.
    defined_elements : List[Optional[int]]
        The elements with some values predefined and others as None to be determined. This
        sequence is not modified.

    Returns
    -------
    List[int]
        A new list in which every None has been replaced by a positive integer; predefined
        elements are returned unchanged.

    Raises
    ------
    ValueError
        If ``upper_limits`` and ``defined_elements`` differ in length.

    Examples
    --------
    >>> array_with_desired_product(30, [None, 3, None], [None, None, 5])
    [3, 2, 5]

    >>> array_with_desired_product(100, [None, None, 10], [2, None, None])
    [2, 7, 7]

    >>> array_with_desired_product(50, [5, 5, None], [None, None, 2])
    [5, 5, 2]

    >>> array_with_desired_product(10, [1, None], [None, None])
    [1, 10]
    """
    limits = list(upper_limits)
    # Work on a copy so the caller's sequence (which may be an immutable tuple) is left untouched.
    filled = list(defined_elements)
    if len(limits) != len(filled):
        raise ValueError(
            f"upper_limits has {len(limits)} elements but defined_elements has {len(filled)}."
        )

    # Running product of the whole vector, kept as a Python int so it cannot overflow.
    # A predefined 0 is counted as 1, and every undefined element starts at 1.
    product = 1
    growable = []  # positions of undefined elements that may still be incremented
    for position, (value, limit) in enumerate(zip(filled, limits)):
        if value is None:
            filled[position] = 1
            # An element whose limit is 1 is already at its limit and must never grow.
            if not limit or limit > 1:
                growable.append(position)
        else:
            product *= value or 1

    while growable:
        # min() returns the first smallest entry, so ties go to the lowest position.
        position = min(growable, key=filled.__getitem__)
        current = filled[position]
        # `current` is a factor of `product`, so this division is exact.
        grown_product = product // current * (current + 1)
        if grown_product > desired_product:
            break

        filled[position] = current + 1
        product = grown_product

        if limits[position] and filled[position] >= limits[position]:
            growable.remove(position)

    return filled


def append_data(data, arg):
if isinstance(data, (list, DataIO)):
data.append(arg)
Expand Down Expand Up @@ -215,7 +301,7 @@ def __init__(self, **kwargs):
self._dtype = self._get_dtype()
self._maxshape = tuple(int(x) for x in self._get_maxshape())
chunk_shape = tuple(int(x) for x in chunk_shape) if chunk_shape else chunk_shape
self.chunk_shape = chunk_shape or self._get_default_chunk_shape(chunk_mb=chunk_mb)
self.chunk_shape = chunk_shape or self._get_default_chunk_shape(chunk_mb=chunk_mb, chunk_shape=chunk_shape)
buffer_shape = tuple(int(x) for x in buffer_shape) if buffer_shape else buffer_shape
self.buffer_shape = buffer_shape or self._get_default_buffer_shape(buffer_gb=buffer_gb)

Expand Down Expand Up @@ -286,36 +372,22 @@ def __init__(self, **kwargs):
)
self.display_progress = False

def _get_default_chunk_shape(
    self,
    chunk_mb: Union[float, int] = 10.0,
    chunk_shape: Optional[List[Optional[int]]] = None,
) -> Tuple[int, ...]:
    """
    Select a chunk shape whose size aims to stay within the threshold of chunk_mb.

    The element budget of one chunk is chunk_mb * 2**20 bytes divided by the item size of
    self.dtype. The search for the shape is delegated to array_with_desired_product, with
    self.maxshape passed as the per-axis upper limits.

    :param chunk_mb: Size budget of one chunk, in units of 2**20 bytes. Must be greater than zero.
    :param chunk_shape: Optional partially specified chunk shape, one entry per axis of
        self.maxshape. Entries that are None are determined here. When no shape is given
        (None or empty), every axis is determined here.
    :returns: The chunk shape, one int per axis.
    :raises AssertionError: If chunk_mb is not greater than zero.
    """
    # Kept as an assert (rather than raising ValueError) so the exception type seen by
    # existing callers does not change.
    assert chunk_mb > 0, f"chunk_mb ({chunk_mb}) must be greater than zero!"

    # __init__ only falls back to this method when its own chunk_shape is falsy, so a missing
    # shape has to be expanded to "every axis undefined" before it can be iterated.
    if chunk_shape:
        # Copy so the caller's sequence is never modified by the search.
        partial_shape = list(chunk_shape)
    else:
        partial_shape = [None] * len(self.maxshape)

    # Integer number of items that fit in the byte budget.
    target_items = int(chunk_mb * 2 ** 20) // self.dtype.itemsize

    return tuple(
        int(axis_length)
        for axis_length in array_with_desired_product(
            target_items,
            self.maxshape,
            defined_elements=partial_shape,
        )
    )

@docval(
dict(
Expand Down
Loading