diff --git a/src/mrpro/data/AcqInfo.py b/src/mrpro/data/AcqInfo.py index a264db5f6..926ad6024 100644 --- a/src/mrpro/data/AcqInfo.py +++ b/src/mrpro/data/AcqInfo.py @@ -2,7 +2,9 @@ from collections.abc import Sequence from dataclasses import dataclass +from typing import TypeAlias +import einops import ismrmrd import numpy as np import torch @@ -14,6 +16,9 @@ # Conversion functions for units T = TypeVar('T', float, torch.Tensor) +# We use this for runtime dtype checking in the dataclasses +LongTensor: TypeAlias = torch.Tensor | torch.LongTensor | torch.IntTensor + def ms_to_s(ms: T) -> T: """Convert ms to s.""" @@ -25,64 +30,114 @@ def mm_to_m(m: T) -> T: return m / 1000 +class _InvariantsAcqInfo: + __slots__ = '__broadcasted_shape', '__typehints' + __broadcasted_shape: torch.Size | None + + def __post_init__(self): + self._check_invariants() + + def _check_invariants(self): + shapes = [] + for name in self.__slots__: + expected_type = self.__annotations__[name] + value = getattr(self, name) + if not isinstance(value, expected_type): + raise TypeError(f'{name} must be of type {expected_type}, got {type(value)} instead') + if hasattr(value, 'shape'): + # isinstance(value, torch.Tensor | SpatialDimension | Rotation): + shape = value.shape + if len(shape) < 5: + raise ValueError(f'{name} must have at least 5 dimensions') + if shape[-1] != 1: + raise ValueError(f'{name} must have a k0 dimension of size 1') + if shape[-4] != 1: + raise ValueError(f'{name} must have a coil dimension of size 1') + shapes.append(value.shape) + + if hasattr(value, 'dtype') and value.dtype.is_complex: + raise ValueError(f'{name} must not be complex.') + if expected_type == LongTensor and value.dtype not in ( + torch.int64, + torch.uint64, + torch.int32, + torch.uint32, + ): + raise ValueError(f'{name} must be integer.') + + elif hasattr(value, 'broadcasted_shape'): + shapes.append(value.broadcasted_shape) + try: + broadcasted_shape = torch.broadcast_shapes(*shapes) + except RuntimeError: + raise ValueError(f'The Acquisition information tensors {self.__slots__} must be broadcastable.') from None + + self.__broadcasted_shape = broadcasted_shape + + @property + def broadcasted_shape(self) -> torch.Size: + assert self.__broadcasted_shape is not None # noqa: S101 # mypy hint + return self.__broadcasted_shape + + @dataclass(slots=True) -class AcqIdx(MoveDataMixin): +class AcqIdx(MoveDataMixin, _InvariantsAcqInfo): """Acquisition index for each readout.""" - k1: torch.Tensor + k1: LongTensor """First phase encoding.""" - k2: torch.Tensor + k2: LongTensor """Second phase encoding.""" - average: torch.Tensor + average: LongTensor """Signal average.""" - slice: torch.Tensor + slice: LongTensor """Slice number (multi-slice 2D).""" - contrast: torch.Tensor + contrast: LongTensor """Echo number in multi-echo.""" - phase: torch.Tensor + phase: LongTensor """Cardiac phase.""" - repetition: torch.Tensor + repetition: LongTensor """Counter in repeated/dynamic acquisitions.""" - set: torch.Tensor + set: LongTensor """Sets of different preparation, e.g. flow encoding, diffusion weighting.""" - segment: torch.Tensor + segment: LongTensor """Counter for segmented acquisitions.""" - user0: torch.Tensor + user0: LongTensor """User index 0.""" - user1: torch.Tensor + user1: LongTensor """User index 1.""" - user2: torch.Tensor + user2: LongTensor """User index 2.""" - user3: torch.Tensor + user3: LongTensor """User index 3.""" - user4: torch.Tensor + user4: LongTensor """User index 4.""" - user5: torch.Tensor + user5: LongTensor """User index 5.""" - user6: torch.Tensor + user6: LongTensor """User index 6.""" - user7: torch.Tensor + user7: LongTensor """User index 7.""" @dataclass(slots=True) -class UserValues(MoveDataMixin): +class UserValues(MoveDataMixin, _InvariantsAcqInfo): """User-defined values for each readout.""" float0: torch.Tensor @@ -109,28 +164,28 @@ class UserValues(MoveDataMixin): float7: torch.Tensor """User float 7.""" - int0: torch.Tensor + int0: LongTensor """User int 0.""" - int1: torch.Tensor + int1: LongTensor """User int 1.""" - int2: torch.Tensor + int2: LongTensor """User int 2.""" - int3: torch.Tensor + int3: LongTensor """User int 3.""" - int4: torch.Tensor + int4: LongTensor """User int 4.""" - int5: torch.Tensor + int5: LongTensor """User int 5.""" - int6: torch.Tensor + int6: LongTensor """User int 6.""" - int7: torch.Tensor + int7: LongTensor """User int 7.""" @@ -144,34 +199,34 @@ class AcqInfo(MoveDataMixin): acquisition_time_stamp: torch.Tensor """Clock time stamp. Not in s but in vendor-specific time units (e.g. 2.5ms for Siemens)""" - active_channels: torch.Tensor + active_channels: LongTensor """Number of active receiver coil elements.""" - available_channels: torch.Tensor + available_channels: LongTensor """Number of available receiver coil elements.""" - center_sample: torch.Tensor + center_sample: LongTensor """Index of the readout sample corresponding to k-space center (zero indexed).""" - channel_mask: torch.Tensor + channel_mask: LongTensor """Bit mask indicating active coils (64*16 = 1024 bits).""" - discard_post: torch.Tensor + discard_post: LongTensor """Number of readout samples to be discarded at the end (e.g. if the ADC is active during gradient events).""" - discard_pre: torch.Tensor + discard_pre: LongTensor """Number of readout samples to be discarded at the beginning (e.g. if the ADC is active during gradient events)""" - encoding_space_ref: torch.Tensor + encoding_space_ref: LongTensor """Indexed reference to the encoding spaces enumerated in the MRD (xml) header.""" - flags: torch.Tensor + flags: LongTensor """A bit mask of common attributes applicable to individual acquisition readouts.""" - measurement_uid: torch.Tensor + measurement_uid: LongTensor """Unique ID corresponding to the readout.""" - number_of_samples: torch.Tensor + number_of_samples: LongTensor """Number of sample points per readout (readouts may have different number of sample points).""" patient_table_position: SpatialDimension[torch.Tensor] @@ -192,20 +247,20 @@ class AcqInfo(MoveDataMixin): sample_time_us: torch.Tensor """Readout bandwidth, as time between samples [us].""" - scan_counter: torch.Tensor + scan_counter: LongTensor """Zero-indexed incrementing counter for readouts.""" slice_dir: SpatialDimension[torch.Tensor] """Directional cosine of slice normal, i.e. cross-product of read_dir and phase_dir.""" - trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists. + trajectory_dimensions: LongTensor # =3. We only support 3D Trajectories: kz always exists. """Dimensionality of the k-space trajectory vector.""" user: UserValues - """User-defined values.""" + """User-defined values for each readout.""" - version: torch.Tensor - """Major version number.""" + version: LongTensor + """Major version number""" @classmethod def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) -> Self: @@ -236,6 +291,7 @@ def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) idx = headers['idx'] def tensor(data: np.ndarray) -> torch.Tensor: + """Convert to tensor with shape (other=1, coil=1, k2=1, k1=n, k0=1).""" # we have to convert first as pytoch cant create tensors from np.uint16 arrays # we use int32 for uint16 and int64 for uint32 to fit largest values. match data.dtype: @@ -243,18 +299,9 @@ def tensor(data: np.ndarray) -> torch.Tensor: data = data.astype(np.int32) case np.uint32 | np.uint64: data = data.astype(np.int64) - # Remove any uncessary dimensions - return torch.tensor(np.squeeze(data)) - - def tensor_2d(data: np.ndarray) -> torch.Tensor: - # Convert tensor to torch dtypes and ensure it is atleast 2D - data_tensor = tensor(data) - # Ensure that data is (k1*k2*other, >=1) - if data_tensor.ndim == 1: - data_tensor = data_tensor[:, None] - elif data_tensor.ndim == 0: - data_tensor = data_tensor[None, None] - return data_tensor + tensor = torch.from_numpy(data) + tensor = einops.repeat(tensor, '... -> other coil k2 (...) k0', other=1, coil=1, k2=1, k0=1) + return tensor def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: # Ensure spatial dimension is (k1*k2*other, 1, 3) @@ -283,49 +330,48 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]: user6=tensor(idx['user'][:, 6]), user7=tensor(idx['user'][:, 7]), ) - user_values = UserValues( - float0=tensor_2d(headers['user_float'][:, 0]), - float1=tensor_2d(headers['user_float'][:, 1]), - float2=tensor_2d(headers['user_float'][:, 2]), - float3=tensor_2d(headers['user_float'][:, 3]), - float4=tensor_2d(headers['user_float'][:, 4]), - float5=tensor_2d(headers['user_float'][:, 5]), - float6=tensor_2d(headers['user_float'][:, 6]), - float7=tensor_2d(headers['user_float'][:, 7]), - int0=tensor_2d(headers['user_int'][:, 0]), - int1=tensor_2d(headers['user_int'][:, 1]), - int2=tensor_2d(headers['user_int'][:, 2]), - int3=tensor_2d(headers['user_int'][:, 3]), - int4=tensor_2d(headers['user_int'][:, 4]), - int5=tensor_2d(headers['user_int'][:, 5]), - int6=tensor_2d(headers['user_int'][:, 6]), - int7=tensor_2d(headers['user_int'][:, 7]), + float0=tensor(headers['user_float'][:, 0]), + float1=tensor(headers['user_float'][:, 1]), + float2=tensor(headers['user_float'][:, 2]), + float3=tensor(headers['user_float'][:, 3]), + float4=tensor(headers['user_float'][:, 4]), + float5=tensor(headers['user_float'][:, 5]), + float6=tensor(headers['user_float'][:, 6]), + float7=tensor(headers['user_float'][:, 7]), + int0=tensor(headers['user_int'][:, 0]), + int1=tensor(headers['user_int'][:, 1]), + int2=tensor(headers['user_int'][:, 2]), + int3=tensor(headers['user_int'][:, 3]), + int4=tensor(headers['user_int'][:, 4]), + int5=tensor(headers['user_int'][:, 5]), + int6=tensor(headers['user_int'][:, 6]), + int7=tensor(headers['user_int'][:, 7]), ) acq_info = cls( idx=acq_idx, - acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']), - active_channels=tensor_2d(headers['active_channels']), - available_channels=tensor_2d(headers['available_channels']), - center_sample=tensor_2d(headers['center_sample']), - channel_mask=tensor_2d(headers['channel_mask']), - discard_post=tensor_2d(headers['discard_post']), - discard_pre=tensor_2d(headers['discard_pre']), - encoding_space_ref=tensor_2d(headers['encoding_space_ref']), - flags=tensor_2d(headers['flags']), - measurement_uid=tensor_2d(headers['measurement_uid']), - number_of_samples=tensor_2d(headers['number_of_samples']), + acquisition_time_stamp=tensor(headers['acquisition_time_stamp']), + active_channels=tensor(headers['active_channels']), + available_channels=tensor(headers['available_channels']), + center_sample=tensor(headers['center_sample']), + channel_mask=tensor(headers['channel_mask']), + discard_post=tensor(headers['discard_post']), + discard_pre=tensor(headers['discard_pre']), + encoding_space_ref=tensor(headers['encoding_space_ref']), + flags=tensor(headers['flags']), + measurement_uid=tensor(headers['measurement_uid']), + number_of_samples=tensor(headers['number_of_samples']), patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m), phase_dir=spatialdimension_2d(headers['phase_dir']), - physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']), + physiology_time_stamp=tensor(headers['physiology_time_stamp']), position=spatialdimension_2d(headers['position']).apply_(mm_to_m), read_dir=spatialdimension_2d(headers['read_dir']), - sample_time_us=tensor_2d(headers['sample_time_us']), - scan_counter=tensor_2d(headers['scan_counter']), + sample_time_us=tensor(headers['sample_time_us']), + scan_counter=tensor(headers['scan_counter']), slice_dir=spatialdimension_2d(headers['slice_dir']), - trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above + trajectory_dimensions=tensor(headers['trajectory_dimensions']).fill_(3), # see above user=user_values, - version=tensor_2d(headers['version']), + version=tensor(headers['version']), ) return acq_info