diff --git a/src/hdmf/data_utils.py b/src/hdmf/data_utils.py index 941c3f8c7..10bae495f 100644 --- a/src/hdmf/data_utils.py +++ b/src/hdmf/data_utils.py @@ -3,7 +3,7 @@ from abc import ABCMeta, abstractmethod from collections.abc import Iterable from warnings import warn -from typing import Tuple, Callable +from typing import Tuple, Callable, List, Optional, Union from itertools import product, chain import h5py @@ -12,6 +12,92 @@ from .utils import docval, getargs, popargs, docval_macro, get_data_shape +def find_nth_none(lst, n): + """ + Finds the index of the nth None in a list. + + Parameters + ---------- + lst : list + The list in which to search for None values. + n : int + The occurrence of None to find (1-based). + + Returns + ------- + int or None + The index of the nth None in the list, or None if there aren't enough None values. + """ + count = 0 + for index, value in enumerate(lst): + if value is None: + count += 1 + if count == n: + return index + return None + + +def array_with_desired_product( + desired_product: int, + upper_limits: Iterable[Optional[int]], + defined_elements: List[Optional[int]], +) -> List[int]: + """ + Adjusts the undefined elements in a list such that their product equals the desired product, + respecting the upper limits for each element. + + Parameters + ---------- + desired_product : int + The desired product of all elements in the list. + upper_limits : List[Optional[int]] + The upper limits for each element in the list. None indicates no limit. + defined_elements : List[Optional[int]] + The list of elements with some values predefined and others as None to be determined. + + Returns + ------- + List[int] + The list with all elements defined such that their product equals the desired product. + + Examples + -------- + >>> array_with_desired_product(30, [None, 3, None], [None, None, 5]) + [6, 1, 5] + + >>> array_with_desired_product(100, [None, None, 10], [2, None, None]) + [2, 10, 1] + + >>> array_with_desired_product(50, [5, 5, None], [None, None, 2]) + [5, 5, 2] + """ + desired_product //= np.prod([x or 1 for x in defined_elements]) + + free_elements = [1 for x in defined_elements if x is None] + upper_limits = [x or float('inf') for x, y in zip(upper_limits, defined_elements) if y is None] + + while free_elements: + #print(f"{free_elements=} {upper_limits=} {defined_elements=}") + candidate_free_elements = free_elements.copy() + idx = np.argmin(free_elements) + candidate_free_elements[idx] += 1 + if np.prod(candidate_free_elements) > desired_product: + break + + free_elements = candidate_free_elements + + if free_elements[idx] == upper_limits[idx]: + new_elem = free_elements.pop(idx) + defined_elements[find_nth_none(defined_elements, idx + 1)] = new_elem + upper_limits.pop(idx) + desired_product //= new_elem + + for x in free_elements: + defined_elements[defined_elements.index(None)] = x + + return defined_elements + + def append_data(data, arg): if isinstance(data, (list, DataIO)): data.append(arg) @@ -215,7 +301,7 @@ def __init__(self, **kwargs): self._dtype = self._get_dtype() self._maxshape = tuple(int(x) for x in self._get_maxshape()) chunk_shape = tuple(int(x) for x in chunk_shape) if chunk_shape else chunk_shape - self.chunk_shape = chunk_shape or self._get_default_chunk_shape(chunk_mb=chunk_mb) + self.chunk_shape = chunk_shape or self._get_default_chunk_shape(chunk_mb=chunk_mb, chunk_shape=chunk_shape) buffer_shape = tuple(int(x) for x in buffer_shape) if buffer_shape else buffer_shape self.buffer_shape = buffer_shape or self._get_default_buffer_shape(buffer_gb=buffer_gb) @@ -286,36 +372,22 @@ def __init__(self, **kwargs): ) self.display_progress = False - @docval( - dict( - name="chunk_mb", - type=(float, int), - doc="Size of the HDF5 chunk in megabytes.", - default=None, - ) - ) - def _get_default_chunk_shape(self, **kwargs) -> Tuple[int, ...]: + def _get_default_chunk_shape(self, chunk_mb: Union[float, int] = 10.0, chunk_shape=Optional[List[Optional[int]]]) -> Tuple[int, ...]: """ Select chunk shape with size in MB less than the threshold of chunk_mb. - Keeps the dimensional ratios of the original data. + Tries to make the dimensions even. """ - chunk_mb = getargs("chunk_mb", kwargs) + assert chunk_mb > 0, f"chunk_mb ({chunk_mb}) must be greater than zero!" - n_dims = len(self.maxshape) - itemsize = self.dtype.itemsize - chunk_bytes = chunk_mb * 1e6 - - min_maxshape = min(self.maxshape) - v = tuple(math.floor(maxshape_axis / min_maxshape) for maxshape_axis in self.maxshape) - prod_v = math.prod(v) - while prod_v * itemsize > chunk_bytes and prod_v != 1: - non_unit_min_v = min(x for x in v if x != 1) - v = tuple(math.floor(x / non_unit_min_v) if x != 1 else x for x in v) - prod_v = math.prod(v) - k = math.floor((chunk_bytes / (prod_v * itemsize)) ** (1 / n_dims)) - return tuple([min(k * x, self.maxshape[dim]) for dim, x in enumerate(v)]) + return tuple( + array_with_desired_product( + chunk_mb * 2 ** 20 / self.dtype.itemsize, + self.maxshape, + defined_elements=chunk_shape, + ) + ) @docval( dict(