diff --git a/nibabel/analyze.py b/nibabel/analyze.py index ddc9d29805..a1c1cf1d2f 100644 --- a/nibabel/analyze.py +++ b/nibabel/analyze.py @@ -1020,13 +1020,24 @@ def to_file_map(self, file_map=None, dtype=None): inter = hdr['scl_inter'].item() if hdr.has_data_intercept else np.nan # Check whether to calculate slope / inter scale_me = np.all(np.isnan((slope, inter))) - if scale_me: - arr_writer = make_array_writer(data, - out_dtype, - hdr.has_data_slope, - hdr.has_data_intercept) - else: - arr_writer = ArrayWriter(data, out_dtype, check_scaling=False) + try: + if scale_me: + arr_writer = make_array_writer(data, + out_dtype, + hdr.has_data_slope, + hdr.has_data_intercept) + else: + arr_writer = ArrayWriter(data, out_dtype, check_scaling=False) + except WriterError: + # Restore any changed consumable values, in case caller catches + # Should match cleanup at the end of the method + hdr.set_data_offset(offset) + hdr.set_data_dtype(data_dtype) + if hdr.has_data_slope: + hdr['scl_slope'] = slope + if hdr.has_data_intercept: + hdr['scl_inter'] = inter + raise hdr_fh, img_fh = self._get_fileholders(file_map) # Check if hdr and img refer to same file; this can happen with odd # analyze images but most often this is because it's a single nifti diff --git a/nibabel/cifti2/cifti2.py b/nibabel/cifti2/cifti2.py index 32b7a7d8d3..b75fd01db9 100644 --- a/nibabel/cifti2/cifti2.py +++ b/nibabel/cifti2/cifti2.py @@ -19,6 +19,10 @@ import re from collections.abc import MutableSequence, MutableMapping, Iterable from collections import OrderedDict +from warnings import warn + +import numpy as np + from .. import xmlutils as xml from ..filebasedimages import FileBasedHeader, SerializableImage from ..dataobj_images import DataobjImage @@ -26,7 +30,7 @@ from ..nifti2 import Nifti2Image, Nifti2Header from ..arrayproxy import reshape_dataobj from ..caret import CaretMetaData -from warnings import warn +from ..volumeutils import make_dt_codes def _float_01(val): @@ -41,6 +45,22 @@ class Cifti2HeaderError(Exception): """ +_dtdefs = ( # code, label, dtype definition, niistring + (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"), + (4, 'int16', np.int16, "NIFTI_TYPE_INT16"), + (8, 'int32', np.int32, "NIFTI_TYPE_INT32"), + (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"), + (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"), + (256, 'int8', np.int8, "NIFTI_TYPE_INT8"), + (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"), + (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"), + (1024, 'int64', np.int64, "NIFTI_TYPE_INT64"), + (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"), +) + +# Make full code alias bank, including dtype column +data_type_codes = make_dt_codes(_dtdefs) + CIFTI_MAP_TYPES = ('CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SERIES', @@ -103,6 +123,10 @@ def _underscore(string): return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', string).lower() +class LimitedNifti2Header(Nifti2Header): + _data_type_codes = data_type_codes + + class Cifti2MetaData(CaretMetaData): """ A list of name-value pairs @@ -1363,7 +1387,8 @@ def __init__(self, header=None, nifti_header=None, extra=None, - file_map=None): + file_map=None, + dtype=None): """ Initialize image The image is a combination of (dataobj, header), with optional metadata @@ -1392,12 +1417,13 @@ def __init__(self, header = Cifti2Header.from_axes(header) super(Cifti2Image, self).__init__(dataobj, header=header, extra=extra, file_map=file_map) - self._nifti_header = Nifti2Header.from_header(nifti_header) + self._nifti_header = LimitedNifti2Header.from_header(nifti_header) # if NIfTI header not specified, get data type from input array - if nifti_header is None: - if hasattr(dataobj, 'dtype'): - self._nifti_header.set_data_dtype(dataobj.dtype) + if dtype is not None: + self.set_data_dtype(dtype) + elif nifti_header is None and hasattr(dataobj, 'dtype'): + self.set_data_dtype(dataobj.dtype) self.update_headers() if self._dataobj.shape != self.header.matrix.get_data_shape(): diff --git a/nibabel/cifti2/tests/test_cifti2.py b/nibabel/cifti2/tests/test_cifti2.py index db65d0f82b..fc64c34554 100644 --- a/nibabel/cifti2/tests/test_cifti2.py +++ b/nibabel/cifti2/tests/test_cifti2.py @@ -12,7 +12,7 @@ import pytest from nibabel.tests.test_dataobj_images import TestDataobjAPI as _TDA -from nibabel.tests.test_image_api import SerializeMixin +from nibabel.tests.test_image_api import SerializeMixin, DtypeOverrideMixin def compare_xml_leaf(str1, str2): @@ -415,7 +415,7 @@ def test_underscoring(): assert ci.cifti2._underscore(camel) == underscored -class TestCifti2ImageAPI(_TDA, SerializeMixin): +class TestCifti2ImageAPI(_TDA, SerializeMixin, DtypeOverrideMixin): """ Basic validation for Cifti2Image instances """ # A callable returning an image from ``image_maker(data, header)`` @@ -426,6 +426,8 @@ class TestCifti2ImageAPI(_TDA, SerializeMixin): ni_header_maker = Nifti2Header example_shapes = ((2,), (2, 3), (2, 3, 4)) standard_extension = '.nii' + storable_dtypes = (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32, + np.int64, np.uint64, np.float32, np.float64) def make_imaker(self, arr, header=None, ni_header=None): for idx, sz in enumerate(arr.shape): diff --git a/nibabel/tests/test_image_api.py b/nibabel/tests/test_image_api.py index 81f83e2100..0519af071d 100644 --- a/nibabel/tests/test_image_api.py +++ b/nibabel/tests/test_image_api.py @@ -55,6 +55,8 @@ from .test_parrec import EXAMPLE_IMAGES as PARREC_EXAMPLE_IMAGES from .test_brikhead import EXAMPLE_IMAGES as AFNI_EXAMPLE_IMAGES +from nibabel.arraywriters import WriterError + def maybe_deprecated(meth_name): return pytest.deprecated_call() if meth_name == 'get_data' else nullcontext() @@ -181,7 +183,7 @@ def validate_get_data_deprecated(self, imaker, params): assert_array_equal(np.asanyarray(img.dataobj), data) -class GetSetDtypeMixin(object): +class GetSetDtypeMixin: """ Adds dtype tests Add this one if your image has ``get_data_dtype`` and ``set_data_dtype``. @@ -666,6 +668,46 @@ def prox_imaker(): yield make_prox_imaker(arr.copy(), aff, hdr), params +class DtypeOverrideMixin(GetSetDtypeMixin): + """ Test images that can accept ``dtype`` arguments to ``__init__`` and + ``to_file_map`` + """ + + def validate_init_dtype_override(self, imaker, params): + img = imaker() + klass = img.__class__ + for dtype in self.storable_dtypes: + if hasattr(img, 'affine'): + new_img = klass(img.dataobj, img.affine, header=img.header, dtype=dtype) + else: # XXX This is for CIFTI-2, these validators might need refactoring + new_img = klass(img.dataobj, header=img.header, dtype=dtype) + assert new_img.get_data_dtype() == dtype + + if self.has_scaling and self.can_save: + with np.errstate(invalid='ignore'): + rt_img = bytesio_round_trip(new_img) + assert rt_img.get_data_dtype() == dtype + + def validate_to_file_dtype_override(self, imaker, params): + if not self.can_save: + raise unittest.SkipTest + img = imaker() + orig_dtype = img.get_data_dtype() + fname = 'image' + self.standard_extension + with InTemporaryDirectory(): + for dtype in self.storable_dtypes: + try: + img.to_filename(fname, dtype=dtype) + except WriterError: + # It's possible to try to save to a dtype that requires + # scaling, and images without scale factors will fail. + # We're not testing that here. + continue + rt_img = img.__class__.from_filename(fname) + assert rt_img.get_data_dtype() == dtype + assert img.get_data_dtype() == orig_dtype + + class ImageHeaderAPI(MakeImageAPI): """ When ``self.image_maker`` is an image class, make header from class """ @@ -674,7 +716,12 @@ def header_maker(self): return self.image_maker.header_class() -class TestAnalyzeAPI(ImageHeaderAPI): +class TestSpatialImageAPI(ImageHeaderAPI): + klass = image_maker = SpatialImage + can_save = False + + +class TestAnalyzeAPI(TestSpatialImageAPI, DtypeOverrideMixin): """ General image validation API instantiated for Analyze images """ klass = image_maker = AnalyzeImage @@ -685,11 +732,6 @@ class TestAnalyzeAPI(ImageHeaderAPI): storable_dtypes = (np.uint8, np.int16, np.int32, np.float32, np.float64) -class TestSpatialImageAPI(TestAnalyzeAPI): - klass = image_maker = SpatialImage - can_save = False - - class TestSpm99AnalyzeAPI(TestAnalyzeAPI): # SPM-type analyze need scipy for mat file IO klass = image_maker = Spm99AnalyzeImage