Skip to content

Commit

Permalink
Merge pull request #1096 from effigies/enh/dtype_aliases
Browse files Browse the repository at this point in the history
ENH: Add static and dynamic dtype aliases to NIfTI images
  • Loading branch information
effigies authored Jun 3, 2022
2 parents a7e1e0e + 58d37a2 commit 1312493
Show file tree
Hide file tree
Showing 2 changed files with 341 additions and 9 deletions.
293 changes: 284 additions & 9 deletions nibabel/nifti1.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,26 +898,28 @@ def set_data_dtype(self, datatype):
>>> hdr.set_data_dtype(np.dtype(np.uint8))
>>> hdr.get_data_dtype()
dtype('uint8')
>>> hdr.set_data_dtype('implausible') #doctest: +IGNORE_EXCEPTION_DETAIL
>>> hdr.set_data_dtype('implausible')
Traceback (most recent call last):
...
HeaderDataError: data dtype "implausible" not recognized
>>> hdr.set_data_dtype('none') #doctest: +IGNORE_EXCEPTION_DETAIL
nibabel.spatialimages.HeaderDataError: data dtype "implausible" not recognized
>>> hdr.set_data_dtype('none')
Traceback (most recent call last):
...
HeaderDataError: data dtype "none" known but not supported
>>> hdr.set_data_dtype(np.void) #doctest: +IGNORE_EXCEPTION_DETAIL
nibabel.spatialimages.HeaderDataError: data dtype "none" known but not supported
>>> hdr.set_data_dtype(np.void)
Traceback (most recent call last):
...
HeaderDataError: data dtype "<type 'numpy.void'>" known but not supported
>>> hdr.set_data_dtype('int') #doctest: +IGNORE_EXCEPTION_DETAIL
nibabel.spatialimages.HeaderDataError: data dtype "<class 'numpy.void'>" known
but not supported
>>> hdr.set_data_dtype('int')
Traceback (most recent call last):
...
ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16.
>>> hdr.set_data_dtype(int) #doctest: +IGNORE_EXCEPTION_DETAIL
>>> hdr.set_data_dtype(int)
Traceback (most recent call last):
...
ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16.
ValueError: Invalid data type <class 'int'>. Specify a sized integer, e.g., 'uint8' or
numpy.int16.
>>> hdr.set_data_dtype('int64')
>>> hdr.get_data_dtype() == np.dtype('int64')
True
Expand Down Expand Up @@ -1799,6 +1801,10 @@ class Nifti1Pair(analyze.AnalyzeImage):
_meta_sniff_len = header_class.sizeof_hdr
rw = True

# If a _dtype_alias has been set, it can only be resolved by inspecting
# the data at serialization time
_dtype_alias = None

def __init__(self, dataobj, affine, header=None,
extra=None, file_map=None, dtype=None):
# Special carve-out for 64 bit integers
Expand Down Expand Up @@ -2043,6 +2049,137 @@ def set_sform(self, affine, code=None, **kwargs):
else:
self._affine[:] = self._header.get_best_affine()

def set_data_dtype(self, datatype):
    """ Set numpy dtype for data from code, dtype, type or alias

    Using :py:class:`int` or ``"int"`` is disallowed, as these types
    will be interpreted as ``np.int64``, which is almost never desired.
    ``np.int64`` is permitted for those intent on making poor choices.

    The following aliases are defined to allow for flexible specification:

      * ``'mask'`` - Alias for ``uint8``
      * ``'compat'`` - The nearest Analyze-compatible datatype
        (``uint8``, ``int16``, ``int32``, ``float32``)
      * ``'smallest'`` - The smallest Analyze-compatible integer
        (``uint8``, ``int16``, ``int32``)

    Dynamic aliases are resolved when ``get_data_dtype()`` is called
    with a ``finalize=True`` flag. Until then, these aliases are not
    written to the header and will not persist to new images.

    If the requested data type is rejected, the exception propagates and
    any previously requested dynamic alias is left in place.

    Parameters
    ----------
    datatype : code, dtype, type or str
        Concrete data type specification, the static alias ``'mask'``,
        or one of the dynamic aliases ``'compat'`` / ``'smallest'``.

    Examples
    --------
    >>> ints = np.arange(24, dtype='i4').reshape((2,3,4))

    >>> img = Nifti1Image(ints, np.eye(4))

    >>> img.set_data_dtype(np.uint8)
    >>> img.get_data_dtype()
    dtype('uint8')
    >>> img.set_data_dtype('mask')
    >>> img.get_data_dtype()
    dtype('uint8')
    >>> img.set_data_dtype('compat')
    >>> img.get_data_dtype()
    'compat'
    >>> img.get_data_dtype(finalize=True)
    dtype('<i4')
    >>> img.get_data_dtype()
    dtype('<i4')
    >>> img.set_data_dtype('smallest')
    >>> img.get_data_dtype()
    'smallest'
    >>> img.get_data_dtype(finalize=True)
    dtype('uint8')
    >>> img.get_data_dtype()
    dtype('uint8')

    Note that floating point values will not be coerced to ``int``

    >>> floats = np.arange(24, dtype='f4').reshape((2,3,4))
    >>> img = Nifti1Image(floats, np.eye(4))
    >>> img.set_data_dtype('smallest')
    >>> img.get_data_dtype(finalize=True)
    Traceback (most recent call last):
       ...
    ValueError: Cannot automatically cast array (of type float32) to an integer
    type with fewer than 64 bits. Please set_data_dtype() to an explicit data type.

    >>> arr = np.arange(1000, 1024, dtype='i4').reshape((2,3,4))
    >>> img = Nifti1Image(arr, np.eye(4))
    >>> img.set_data_dtype('smallest')
    >>> img.set_data_dtype('implausible')
    Traceback (most recent call last):
       ...
    nibabel.spatialimages.HeaderDataError: data dtype "implausible" not recognized
    >>> img.set_data_dtype('none')
    Traceback (most recent call last):
       ...
    nibabel.spatialimages.HeaderDataError: data dtype "none" known but not supported
    >>> img.set_data_dtype(np.void)
    Traceback (most recent call last):
       ...
    nibabel.spatialimages.HeaderDataError: data dtype "<class 'numpy.void'>" known
    but not supported
    >>> img.set_data_dtype('int')
    Traceback (most recent call last):
       ...
    ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16.
    >>> img.set_data_dtype(int)
    Traceback (most recent call last):
       ...
    ValueError: Invalid data type <class 'int'>. Specify a sized integer, e.g., 'uint8' or
    numpy.int16.
    >>> img.set_data_dtype('int64')
    >>> img.get_data_dtype() == np.dtype('int64')
    True
    """
    # Comparing dtypes to strings, numpy will attempt to call, e.g., dtype('mask'),
    # so only check for aliases if the type is a string
    # See https://github.com/numpy/numpy/issues/7242
    if isinstance(datatype, str):
        # Static aliases
        if datatype == 'mask':
            datatype = 'u1'
        # Dynamic aliases
        elif datatype in ('compat', 'smallest'):
            self._dtype_alias = datatype
            return

    # Let the parent validate and store the concrete dtype first: if it
    # raises, the image keeps whatever alias was previously requested
    # instead of being left half-updated.
    super().set_data_dtype(datatype)
    self._dtype_alias = None

def get_data_dtype(self, finalize=False):
    """ Get numpy dtype for data

    When no alias is pending, the parent implementation's result is
    returned. When an alias was requested through ``set_data_dtype()``
    and ``finalize`` is ``False``, the alias string itself is returned.

    With ``finalize=True`` a pending alias is resolved by inspecting
    the image data object; the resulting dtype is stored through
    ``set_data_dtype()`` (which clears the alias) and then returned.

    Parameters
    ----------
    finalize : bool, optional
        Resolve a pending dynamic alias into a concrete dtype.

    Raises
    ------
    ValueError
        If the pending alias is unknown, or if the resolver reports that
        no suitable dtype exists for the data.
    """
    alias = self._dtype_alias
    if alias is None:
        return super().get_data_dtype()
    if not finalize:
        return alias

    # Pick the resolver and the wording used if it comes back empty-handed
    if alias == 'compat':
        resolve = _get_analyze_compat_dtype
        goal = "an Analyze-compatible dtype"
    elif alias == 'smallest':
        resolve = _get_smallest_dtype
        goal = "an integer type with fewer than 64 bits"
    else:
        raise ValueError(f"Unknown dtype alias {alias}.")

    resolved = resolve(self._dataobj)
    if resolved is None:
        found = get_obj_dtype(self._dataobj)
        raise ValueError(f"Cannot automatically cast array (of type {found}) to {goal}."
                         " Please set_data_dtype() to an explicit data type.")

    # Storing a concrete dtype drops the alias as a side effect
    self.set_data_dtype(resolved)
    return super().get_data_dtype()

def as_reoriented(self, ornt):
"""Apply an orientation change and return a new image
Expand Down Expand Up @@ -2136,3 +2273,141 @@ def save(img, filename):
Nifti1Image.instance_to_filename(img, filename)
except ImageFileError:
Nifti1Pair.instance_to_filename(img, filename)


def _get_smallest_dtype(
arr,
itypes=(np.uint8, np.int16, np.int32),
ftypes=(),
):
""" Return the smallest "sensible" dtype that will hold the array data
The purpose of this function is to support automatic type selection
for serialization, so "sensible" here means well-supported in the NIfTI-1 world.
For floating point data, select between single- and double-precision.
For integer data, select among uint8, int16 and int32.
The test is for min/max range, so float64 is pretty unlikely to be hit.
Returns ``None`` if these dtypes do not suffice.
>>> _get_smallest_dtype(np.array([0, 1]))
dtype('uint8')
>>> _get_smallest_dtype(np.array([-1, 1]))
dtype('int16')
>>> _get_smallest_dtype(np.array([0, 256]))
dtype('int16')
>>> _get_smallest_dtype(np.array([-65536, 65536]))
dtype('int32')
>>> _get_smallest_dtype(np.array([-2147483648, 2147483648]))
By default floating point types are not searched:
>>> _get_smallest_dtype(np.array([1.]))
>>> _get_smallest_dtype(np.array([2. ** 1000]))
>>> _get_smallest_dtype(np.longdouble(2) ** 2000)
>>> _get_smallest_dtype(np.array([1+0j]))
However, this function can be passed "legal" floating point types, and
the logic works the same.
>>> _get_smallest_dtype(np.array([1.]), ftypes=('float32',))
dtype('float32')
>>> _get_smallest_dtype(np.array([2. ** 1000]), ftypes=('float32',))
>>> _get_smallest_dtype(np.longdouble(2) ** 2000, ftypes=('float32',))
>>> _get_smallest_dtype(np.array([1+0j]), ftypes=('float32',))
"""
arr = np.asanyarray(arr)
if np.issubdtype(arr.dtype, np.floating):
test_dts = ftypes
info = np.finfo
elif np.issubdtype(arr.dtype, np.integer):
test_dts = itypes
info = np.iinfo
else:
return None

mn, mx = np.min(arr), np.max(arr)
for dt in test_dts:
dtinfo = info(dt)
if dtinfo.min <= mn and mx <= dtinfo.max:
return np.dtype(dt)


def _get_analyze_compat_dtype(arr):
""" Return an Analyze-compatible dtype that ``arr`` can be safely cast to
Analyze-compatible types are returned without inspection:
>>> _get_analyze_compat_dtype(np.uint8([0, 1]))
dtype('uint8')
>>> _get_analyze_compat_dtype(np.int16([0, 1]))
dtype('int16')
>>> _get_analyze_compat_dtype(np.int32([0, 1]))
dtype('int32')
>>> _get_analyze_compat_dtype(np.float32([0, 1]))
dtype('float32')
Signed ``int8`` are cast to ``uint8`` or ``int16`` based on value ranges:
>>> _get_analyze_compat_dtype(np.int8([0, 1]))
dtype('uint8')
>>> _get_analyze_compat_dtype(np.int8([-1, 1]))
dtype('int16')
Unsigned ``uint16`` are cast to ``int16`` or ``int32`` based on value ranges:
>>> _get_analyze_compat_dtype(np.uint16([32767]))
dtype('int16')
>>> _get_analyze_compat_dtype(np.uint16([65535]))
dtype('int32')
``int32`` is returned for integer types and ``float32`` for floating point types:
>>> _get_analyze_compat_dtype(np.array([-1, 1]))
dtype('int32')
>>> _get_analyze_compat_dtype(np.array([-1., 1.]))
dtype('float32')
If the value ranges exceed 4 bytes or cannot be cast, then a ``ValueError`` is raised:
>>> _get_analyze_compat_dtype(np.array([0, 4294967295]))
Traceback (most recent call last):
...
ValueError: Cannot find analyze-compatible dtype for array with dtype=int64
(min=0, max=4294967295)
>>> _get_analyze_compat_dtype([0., 2.e40])
Traceback (most recent call last):
...
ValueError: Cannot find analyze-compatible dtype for array with dtype=float64
(min=0.0, max=2e+40)
Note that real-valued complex arrays cannot be safely cast.
>>> _get_analyze_compat_dtype(np.array([1+0j]))
Traceback (most recent call last):
...
ValueError: Cannot find analyze-compatible dtype for array with dtype=complex128
(min=(1+0j), max=(1+0j))
"""
arr = np.asanyarray(arr)
dtype = arr.dtype
if dtype in (np.uint8, np.int16, np.int32, np.float32):
return dtype

if dtype == np.int8:
return np.dtype('uint8' if arr.min() >= 0 else 'int16')
elif dtype == np.uint16:
return np.dtype('int16' if arr.max() <= np.iinfo(np.int16).max else 'int32')

mn, mx = arr.min(), arr.max()
if np.can_cast(mn, np.int32) and np.can_cast(mx, np.int32):
return np.dtype('int32')
if np.can_cast(mn, np.float32) and np.can_cast(mx, np.float32):
return np.dtype('float32')

raise ValueError(
f"Cannot find analyze-compatible dtype for array with dtype={dtype} (min={mn}, max={mx})"
)
57 changes: 57 additions & 0 deletions nibabel/tests/test_nifti1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,63 @@ def test_write_scaling(self):
with np.errstate(invalid='ignore'):
self._check_write_scaling(slope, inter, e_slope, e_inter)

def test_dynamic_dtype_aliases(self):
    # Each row: (input dtype, min value, max value, expected resolved dtype);
    # an expected dtype of None means finalizing must raise ValueError.
    compat_rows = [
        (np.uint8, 0, 255, np.uint8),
        (np.int8, 0, 127, np.uint8),
        (np.int8, -128, 127, np.int16),
        (np.int16, -32768, 32767, np.int16),
        (np.uint16, 0, 32767, np.int16),
        (np.uint16, 0, 65535, np.int32),
        (np.int32, -2**31, 2**31-1, np.int32),
        (np.uint32, 0, 2**31-1, np.int32),
        (np.uint32, 0, 2**32-1, None),
        (np.int64, -2**31, 2**31-1, np.int32),
        (np.uint64, 0, 2**31-1, np.int32),
        (np.int64, 0, 2**32-1, None),
        (np.uint64, 0, 2**32-1, None),
        (np.float32, 0, 1e30, np.float32),
        (np.float64, 0, 1e30, np.float32),
        (np.float64, 0, 1e40, None),
    ]
    smallest_rows = [
        (np.int64, 0, 255, np.uint8),
        (np.int64, 0, 256, np.int16),
        (np.int64, -1, 255, np.int16),
        (np.int64, 0, 32768, np.int32),
        (np.int64, 0, 4294967296, None),
        (np.float32, 0, 1, None),
        (np.float64, 0, 1, None),
    ]
    cases = [(dt, lo, hi, 'compat', want) for dt, lo, hi, want in compat_rows]
    cases += [(dt, lo, hi, 'smallest', want) for dt, lo, hi, want in smallest_rows]

    for source_dt, lo, hi, alias, expected_dt in cases:
        data = np.arange(24, dtype=source_dt).reshape((2, 3, 4))
        # Plant the extreme values that drive the alias resolution
        data[0, 0, :2] = [lo, hi]
        img = self.image_class(data, np.eye(4), dtype=alias)
        # Stored as alias
        assert img.get_data_dtype() == alias
        if expected_dt is None:
            with pytest.raises(ValueError):
                img.get_data_dtype(finalize=True)
            continue
        # Finalizing sets and clears the alias
        assert img.get_data_dtype(finalize=True) == expected_dt
        assert img.get_data_dtype() == expected_dt
        # Re-set to alias
        img.set_data_dtype(alias)
        assert img.get_data_dtype() == alias
        round_tripped = bytesio_round_trip(img)
        assert round_tripped.get_data_dtype() == expected_dt
        # Serializing does not finalize the source image
        assert img.get_data_dtype() == alias

def test_static_dtype_aliases(self):
    # Static aliases resolve immediately, whatever the dtype of the input data
    expected_by_alias = [
        ("mask", np.uint8),
    ]
    source_dtypes = ('u1', 'i8', 'f4')
    for alias, expected_dt in expected_by_alias:
        for source_dt in source_dtypes:
            data = np.arange(24, dtype=source_dt).reshape((2, 3, 4))
            img = self.image_class(data, np.eye(4), dtype=alias)
            assert img.get_data_dtype() == expected_dt
            # The resolved dtype must also survive serialization
            round_tripped = bytesio_round_trip(img)
            assert round_tripped.get_data_dtype() == expected_dt


class TestNifti1Image(TestNifti1Pair):
# Run analyze-flavor spatialimage tests
Expand Down

0 comments on commit 1312493

Please sign in to comment.