Skip to content

Commit

Permalink
Merge pull request #1111 from effigies/enh/cifti_dtype_arg
Browse files Browse the repository at this point in the history
ENH: Add dtype argument to Cifti2Image
  • Loading branch information
effigies authored Jun 3, 2022
2 parents d0532ec + 7933fae commit a7e1e0e
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 22 deletions.
25 changes: 18 additions & 7 deletions nibabel/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,13 +1020,24 @@ def to_file_map(self, file_map=None, dtype=None):
inter = hdr['scl_inter'].item() if hdr.has_data_intercept else np.nan
# Check whether to calculate slope / inter
scale_me = np.all(np.isnan((slope, inter)))
if scale_me:
arr_writer = make_array_writer(data,
out_dtype,
hdr.has_data_slope,
hdr.has_data_intercept)
else:
arr_writer = ArrayWriter(data, out_dtype, check_scaling=False)
try:
if scale_me:
arr_writer = make_array_writer(data,
out_dtype,
hdr.has_data_slope,
hdr.has_data_intercept)
else:
arr_writer = ArrayWriter(data, out_dtype, check_scaling=False)
except WriterError:
# Restore any changed consumable values, in case caller catches
# Should match cleanup at the end of the method
hdr.set_data_offset(offset)
hdr.set_data_dtype(data_dtype)
if hdr.has_data_slope:
hdr['scl_slope'] = slope
if hdr.has_data_intercept:
hdr['scl_inter'] = inter
raise
hdr_fh, img_fh = self._get_fileholders(file_map)
# Check if hdr and img refer to same file; this can happen with odd
# analyze images but most often this is because it's a single nifti
Expand Down
38 changes: 32 additions & 6 deletions nibabel/cifti2/cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@
import re
from collections.abc import MutableSequence, MutableMapping, Iterable
from collections import OrderedDict
from warnings import warn

import numpy as np

from .. import xmlutils as xml
from ..filebasedimages import FileBasedHeader, SerializableImage
from ..dataobj_images import DataobjImage
from ..nifti1 import Nifti1Extensions
from ..nifti2 import Nifti2Image, Nifti2Header
from ..arrayproxy import reshape_dataobj
from ..caret import CaretMetaData
from warnings import warn
from ..volumeutils import make_dt_codes


def _float_01(val):
Expand All @@ -41,6 +45,22 @@ class Cifti2HeaderError(Exception):
"""


_dtdefs = ( # code, label, dtype definition, niistring
(2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),
(4, 'int16', np.int16, "NIFTI_TYPE_INT16"),
(8, 'int32', np.int32, "NIFTI_TYPE_INT32"),
(16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),
(64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),
(256, 'int8', np.int8, "NIFTI_TYPE_INT8"),
(512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),
(768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),
(1024, 'int64', np.int64, "NIFTI_TYPE_INT64"),
(1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),
)

# Make full code alias bank, including dtype column
data_type_codes = make_dt_codes(_dtdefs)

CIFTI_MAP_TYPES = ('CIFTI_INDEX_TYPE_BRAIN_MODELS',
'CIFTI_INDEX_TYPE_PARCELS',
'CIFTI_INDEX_TYPE_SERIES',
Expand Down Expand Up @@ -103,6 +123,10 @@ def _underscore(string):
return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', string).lower()


class LimitedNifti2Header(Nifti2Header):
_data_type_codes = data_type_codes


class Cifti2MetaData(CaretMetaData):
""" A list of name-value pairs
Expand Down Expand Up @@ -1363,7 +1387,8 @@ def __init__(self,
header=None,
nifti_header=None,
extra=None,
file_map=None):
file_map=None,
dtype=None):
""" Initialize image
The image is a combination of (dataobj, header), with optional metadata
Expand Down Expand Up @@ -1392,12 +1417,13 @@ def __init__(self,
header = Cifti2Header.from_axes(header)
super(Cifti2Image, self).__init__(dataobj, header=header,
extra=extra, file_map=file_map)
self._nifti_header = Nifti2Header.from_header(nifti_header)
self._nifti_header = LimitedNifti2Header.from_header(nifti_header)

# if NIfTI header not specified, get data type from input array
if nifti_header is None:
if hasattr(dataobj, 'dtype'):
self._nifti_header.set_data_dtype(dataobj.dtype)
if dtype is not None:
self.set_data_dtype(dtype)
elif nifti_header is None and hasattr(dataobj, 'dtype'):
self.set_data_dtype(dataobj.dtype)
self.update_headers()

if self._dataobj.shape != self.header.matrix.get_data_shape():
Expand Down
6 changes: 4 additions & 2 deletions nibabel/cifti2/tests/test_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest

from nibabel.tests.test_dataobj_images import TestDataobjAPI as _TDA
from nibabel.tests.test_image_api import SerializeMixin
from nibabel.tests.test_image_api import SerializeMixin, DtypeOverrideMixin


def compare_xml_leaf(str1, str2):
Expand Down Expand Up @@ -415,7 +415,7 @@ def test_underscoring():
assert ci.cifti2._underscore(camel) == underscored


class TestCifti2ImageAPI(_TDA, SerializeMixin):
class TestCifti2ImageAPI(_TDA, SerializeMixin, DtypeOverrideMixin):
""" Basic validation for Cifti2Image instances
"""
# A callable returning an image from ``image_maker(data, header)``
Expand All @@ -426,6 +426,8 @@ class TestCifti2ImageAPI(_TDA, SerializeMixin):
ni_header_maker = Nifti2Header
example_shapes = ((2,), (2, 3), (2, 3, 4))
standard_extension = '.nii'
storable_dtypes = (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32,
np.int64, np.uint64, np.float32, np.float64)

def make_imaker(self, arr, header=None, ni_header=None):
for idx, sz in enumerate(arr.shape):
Expand Down
56 changes: 49 additions & 7 deletions nibabel/tests/test_image_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
from .test_parrec import EXAMPLE_IMAGES as PARREC_EXAMPLE_IMAGES
from .test_brikhead import EXAMPLE_IMAGES as AFNI_EXAMPLE_IMAGES

from nibabel.arraywriters import WriterError


def maybe_deprecated(meth_name):
return pytest.deprecated_call() if meth_name == 'get_data' else nullcontext()
Expand Down Expand Up @@ -181,7 +183,7 @@ def validate_get_data_deprecated(self, imaker, params):
assert_array_equal(np.asanyarray(img.dataobj), data)


class GetSetDtypeMixin(object):
class GetSetDtypeMixin:
""" Adds dtype tests
Add this one if your image has ``get_data_dtype`` and ``set_data_dtype``.
Expand Down Expand Up @@ -666,6 +668,46 @@ def prox_imaker():
yield make_prox_imaker(arr.copy(), aff, hdr), params


class DtypeOverrideMixin(GetSetDtypeMixin):
""" Test images that can accept ``dtype`` arguments to ``__init__`` and
``to_file_map``
"""

def validate_init_dtype_override(self, imaker, params):
img = imaker()
klass = img.__class__
for dtype in self.storable_dtypes:
if hasattr(img, 'affine'):
new_img = klass(img.dataobj, img.affine, header=img.header, dtype=dtype)
else: # XXX This is for CIFTI-2, these validators might need refactoring
new_img = klass(img.dataobj, header=img.header, dtype=dtype)
assert new_img.get_data_dtype() == dtype

if self.has_scaling and self.can_save:
with np.errstate(invalid='ignore'):
rt_img = bytesio_round_trip(new_img)
assert rt_img.get_data_dtype() == dtype

def validate_to_file_dtype_override(self, imaker, params):
if not self.can_save:
raise unittest.SkipTest
img = imaker()
orig_dtype = img.get_data_dtype()
fname = 'image' + self.standard_extension
with InTemporaryDirectory():
for dtype in self.storable_dtypes:
try:
img.to_filename(fname, dtype=dtype)
except WriterError:
# It's possible to try to save to a dtype that requires
# scaling, and images without scale factors will fail.
# We're not testing that here.
continue
rt_img = img.__class__.from_filename(fname)
assert rt_img.get_data_dtype() == dtype
assert img.get_data_dtype() == orig_dtype


class ImageHeaderAPI(MakeImageAPI):
""" When ``self.image_maker`` is an image class, make header from class
"""
Expand All @@ -674,7 +716,12 @@ def header_maker(self):
return self.image_maker.header_class()


class TestAnalyzeAPI(ImageHeaderAPI):
class TestSpatialImageAPI(ImageHeaderAPI):
klass = image_maker = SpatialImage
can_save = False


class TestAnalyzeAPI(TestSpatialImageAPI, DtypeOverrideMixin):
""" General image validation API instantiated for Analyze images
"""
klass = image_maker = AnalyzeImage
Expand All @@ -685,11 +732,6 @@ class TestAnalyzeAPI(ImageHeaderAPI):
storable_dtypes = (np.uint8, np.int16, np.int32, np.float32, np.float64)


class TestSpatialImageAPI(TestAnalyzeAPI):
klass = image_maker = SpatialImage
can_save = False


class TestSpm99AnalyzeAPI(TestAnalyzeAPI):
# SPM-type analyze need scipy for mat file IO
klass = image_maker = Spm99AnalyzeImage
Expand Down

0 comments on commit a7e1e0e

Please sign in to comment.