Skip to content

Commit

Permalink
acqinfo 5d
Browse files Browse the repository at this point in the history
ghstack-source-id: a9f96af0bf38962401845f1748d0fca0e5e30f63
ghstack-comment-id: 2441863549
Pull Request resolved: #490
  • Loading branch information
fzimmermann89 committed Oct 28, 2024
1 parent f17f7c9 commit 795a701
Showing 1 changed file with 133 additions and 87 deletions.
220 changes: 133 additions & 87 deletions src/mrpro/data/AcqInfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias

import einops
import ismrmrd
import numpy as np
import torch
Expand All @@ -14,6 +16,9 @@
# Conversion functions for units
T = TypeVar('T', float, torch.Tensor)

# We use this for runtime dtype checking in the dataclasses
LongTensor: TypeAlias = torch.Tensor | torch.LongTensor | torch.IntTensor


def ms_to_s(ms: T) -> T:
"""Convert ms to s."""
Expand All @@ -25,64 +30,114 @@ def mm_to_m(m: T) -> T:
return m / 1000


class _InvariantsAcqInfo:
__slots__ = '__broadcasted_shape', '__typehints'
__broadcasted_shape: torch.Size | None

def __post_init__(self):
self._check_invariants()

def _check_invariants(self):
shapes = []
for name in self.__slots__:
expected_type = self.__annotations__[name]
value = getattr(self, name)
if not isinstance(value, expected_type):
raise TypeError(f'{name} must be of type {expected_type}, got {type(value)} instead')
if hasattr(value, 'shape'):
# isinstance(value, torch.Tensor | SpatialDimension | Rotation):
shape = value.shape
if len(shape) < 5:
raise ValueError(f'{name} must have at least 5 dimensions')
if shape[-1] != 1:
raise ValueError(f'{name} must have a k0 dimension of size 1')
if shape[-4] != 1:
raise ValueError(f'{name} must have a coil dimension of size 1')
shapes.append(value.shape)

if hasattr(value, 'dtype') and value.dtype.is_complex:
raise ValueError(f'{name} must not be complex.')
if expected_type == LongTensor and value.dtype not in (
torch.int64,
torch.uint64,
torch.int32,
torch.uint32,
):
raise ValueError(f'{name} must be integer.')

elif hasattr(value, 'broadcasted_shape'):
shapes.append(value.broadcasted_shape)
try:
broadcasted_shape = torch.broadcast_shapes(*shapes)
except RuntimeError:
raise ValueError(f'The Acquisition information tensors {self.__slots__} must be broadcastable.') from None

self.__broadcasted_shape = broadcasted_shape

@property
def broadcasted_shape(self) -> torch.Size:
assert self.__broadcasted_shape is not None # noqa: S101 # mypy hint
return self.__broadcasted_shape


@dataclass(slots=True)
class AcqIdx(MoveDataMixin):
class AcqIdx(MoveDataMixin, _InvariantsAcqInfo):
"""Acquisition index for each readout."""

k1: torch.Tensor
k1: LongTensor
"""First phase encoding."""

k2: torch.Tensor
k2: LongTensor
"""Second phase encoding."""

average: torch.Tensor
average: LongTensor
"""Signal average."""

slice: torch.Tensor
slice: LongTensor
"""Slice number (multi-slice 2D)."""

contrast: torch.Tensor
contrast: LongTensor
"""Echo number in multi-echo."""

phase: torch.Tensor
phase: LongTensor
"""Cardiac phase."""

repetition: torch.Tensor
repetition: LongTensor
"""Counter in repeated/dynamic acquisitions."""

set: torch.Tensor
set: LongTensor
"""Sets of different preparation, e.g. flow encoding, diffusion weighting."""

segment: torch.Tensor
segment: LongTensor
"""Counter for segmented acquisitions."""

user0: torch.Tensor
user0: LongTensor
"""User index 0."""

user1: torch.Tensor
user1: LongTensor
"""User index 1."""

user2: torch.Tensor
user2: LongTensor
"""User index 2."""

user3: torch.Tensor
user3: LongTensor
"""User index 3."""

user4: torch.Tensor
user4: LongTensor
"""User index 4."""

user5: torch.Tensor
user5: LongTensor
"""User index 5."""

user6: torch.Tensor
user6: LongTensor
"""User index 6."""

user7: torch.Tensor
user7: LongTensor
"""User index 7."""


@dataclass(slots=True)
class UserValues(MoveDataMixin):
class UserValues(MoveDataMixin, _InvariantsAcqInfo):
"""User-defined values for each readout."""

float0: torch.Tensor
Expand All @@ -109,28 +164,28 @@ class UserValues(MoveDataMixin):
float7: torch.Tensor
"""User float 7."""

int0: torch.Tensor
int0: LongTensor
"""User int 0."""

int1: torch.Tensor
int1: LongTensor
"""User int 1."""

int2: torch.Tensor
int2: LongTensor
"""User int 2."""

int3: torch.Tensor
int3: LongTensor
"""User int 3."""

int4: torch.Tensor
int4: LongTensor
"""User int 4."""

int5: torch.Tensor
int5: LongTensor
"""User int 5."""

int6: torch.Tensor
int6: LongTensor
"""User int 6."""

int7: torch.Tensor
int7: LongTensor
"""User int 7."""


Expand All @@ -144,34 +199,34 @@ class AcqInfo(MoveDataMixin):
acquisition_time_stamp: torch.Tensor
"""Clock time stamp. Not in s but in vendor-specific time units (e.g. 2.5ms for Siemens)"""

active_channels: torch.Tensor
active_channels: LongTensor
"""Number of active receiver coil elements."""

available_channels: torch.Tensor
available_channels: LongTensor
"""Number of available receiver coil elements."""

center_sample: torch.Tensor
center_sample: LongTensor
"""Index of the readout sample corresponding to k-space center (zero indexed)."""

channel_mask: torch.Tensor
channel_mask: LongTensor
"""Bit mask indicating active coils (64*16 = 1024 bits)."""

discard_post: torch.Tensor
discard_post: LongTensor
"""Number of readout samples to be discarded at the end (e.g. if the ADC is active during gradient events)."""

discard_pre: torch.Tensor
discard_pre: LongTensor
"""Number of readout samples to be discarded at the beginning (e.g. if the ADC is active during gradient events)"""

encoding_space_ref: torch.Tensor
encoding_space_ref: LongTensor
"""Indexed reference to the encoding spaces enumerated in the MRD (xml) header."""

flags: torch.Tensor
flags: LongTensor
"""A bit mask of common attributes applicable to individual acquisition readouts."""

measurement_uid: torch.Tensor
measurement_uid: LongTensor
"""Unique ID corresponding to the readout."""

number_of_samples: torch.Tensor
number_of_samples: LongTensor
"""Number of sample points per readout (readouts may have different number of sample points)."""

patient_table_position: SpatialDimension[torch.Tensor]
Expand All @@ -192,20 +247,20 @@ class AcqInfo(MoveDataMixin):
sample_time_us: torch.Tensor
"""Readout bandwidth, as time between samples [us]."""

scan_counter: torch.Tensor
scan_counter: LongTensor
"""Zero-indexed incrementing counter for readouts."""

slice_dir: SpatialDimension[torch.Tensor]
"""Directional cosine of slice normal, i.e. cross-product of read_dir and phase_dir."""

trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists.
trajectory_dimensions: LongTensor # =3. We only support 3D Trajectories: kz always exists.
"""Dimensionality of the k-space trajectory vector."""

user: UserValues
"""User-defined values."""
"""User-defined values for each readout."""

version: torch.Tensor
"""Major version number."""
version: LongTensor
"""Major version number"""

@classmethod
def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) -> Self:
Expand Down Expand Up @@ -236,25 +291,17 @@ def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition])
idx = headers['idx']

def tensor(data: np.ndarray) -> torch.Tensor:
"""Convert to tensor with shape (other=1, coil=1, k2=1, k1=n, k0=1)."""
# we have to convert first as pytoch cant create tensors from np.uint16 arrays
# we use int32 for uint16 and int64 for uint32 to fit largest values.
match data.dtype:
case np.uint16:
data = data.astype(np.int32)
case np.uint32 | np.uint64:
data = data.astype(np.int64)
# Remove any uncessary dimensions
return torch.tensor(np.squeeze(data))

def tensor_2d(data: np.ndarray) -> torch.Tensor:
# Convert tensor to torch dtypes and ensure it is atleast 2D
data_tensor = tensor(data)
# Ensure that data is (k1*k2*other, >=1)
if data_tensor.ndim == 1:
data_tensor = data_tensor[:, None]
elif data_tensor.ndim == 0:
data_tensor = data_tensor[None, None]
return data_tensor
tensor = torch.from_numpy(data)
tensor = einops.repeat(tensor, '... -> other coil k2 (...) k0', other=1, coil=1, k2=1, k0=1)
return tensor

def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]:
# Ensure spatial dimension is (k1*k2*other, 1, 3)
Expand Down Expand Up @@ -283,49 +330,48 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]:
user6=tensor(idx['user'][:, 6]),
user7=tensor(idx['user'][:, 7]),
)

user_values = UserValues(
float0=tensor_2d(headers['user_float'][:, 0]),
float1=tensor_2d(headers['user_float'][:, 1]),
float2=tensor_2d(headers['user_float'][:, 2]),
float3=tensor_2d(headers['user_float'][:, 3]),
float4=tensor_2d(headers['user_float'][:, 4]),
float5=tensor_2d(headers['user_float'][:, 5]),
float6=tensor_2d(headers['user_float'][:, 6]),
float7=tensor_2d(headers['user_float'][:, 7]),
int0=tensor_2d(headers['user_int'][:, 0]),
int1=tensor_2d(headers['user_int'][:, 1]),
int2=tensor_2d(headers['user_int'][:, 2]),
int3=tensor_2d(headers['user_int'][:, 3]),
int4=tensor_2d(headers['user_int'][:, 4]),
int5=tensor_2d(headers['user_int'][:, 5]),
int6=tensor_2d(headers['user_int'][:, 6]),
int7=tensor_2d(headers['user_int'][:, 7]),
float0=tensor(headers['user_float'][:, 0]),
float1=tensor(headers['user_float'][:, 1]),
float2=tensor(headers['user_float'][:, 2]),
float3=tensor(headers['user_float'][:, 3]),
float4=tensor(headers['user_float'][:, 4]),
float5=tensor(headers['user_float'][:, 5]),
float6=tensor(headers['user_float'][:, 6]),
float7=tensor(headers['user_float'][:, 7]),
int0=tensor(headers['user_int'][:, 0]),
int1=tensor(headers['user_int'][:, 1]),
int2=tensor(headers['user_int'][:, 2]),
int3=tensor(headers['user_int'][:, 3]),
int4=tensor(headers['user_int'][:, 4]),
int5=tensor(headers['user_int'][:, 5]),
int6=tensor(headers['user_int'][:, 6]),
int7=tensor(headers['user_int'][:, 7]),
)

acq_info = cls(
idx=acq_idx,
acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']),
active_channels=tensor_2d(headers['active_channels']),
available_channels=tensor_2d(headers['available_channels']),
center_sample=tensor_2d(headers['center_sample']),
channel_mask=tensor_2d(headers['channel_mask']),
discard_post=tensor_2d(headers['discard_post']),
discard_pre=tensor_2d(headers['discard_pre']),
encoding_space_ref=tensor_2d(headers['encoding_space_ref']),
flags=tensor_2d(headers['flags']),
measurement_uid=tensor_2d(headers['measurement_uid']),
number_of_samples=tensor_2d(headers['number_of_samples']),
acquisition_time_stamp=tensor(headers['acquisition_time_stamp']),
active_channels=tensor(headers['active_channels']),
available_channels=tensor(headers['available_channels']),
center_sample=tensor(headers['center_sample']),
channel_mask=tensor(headers['channel_mask']),
discard_post=tensor(headers['discard_post']),
discard_pre=tensor(headers['discard_pre']),
encoding_space_ref=tensor(headers['encoding_space_ref']),
flags=tensor(headers['flags']),
measurement_uid=tensor(headers['measurement_uid']),
number_of_samples=tensor(headers['number_of_samples']),
patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m),
phase_dir=spatialdimension_2d(headers['phase_dir']),
physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']),
physiology_time_stamp=tensor(headers['physiology_time_stamp']),
position=spatialdimension_2d(headers['position']).apply_(mm_to_m),
read_dir=spatialdimension_2d(headers['read_dir']),
sample_time_us=tensor_2d(headers['sample_time_us']),
scan_counter=tensor_2d(headers['scan_counter']),
sample_time_us=tensor(headers['sample_time_us']),
scan_counter=tensor(headers['scan_counter']),
slice_dir=spatialdimension_2d(headers['slice_dir']),
trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above
trajectory_dimensions=tensor(headers['trajectory_dimensions']).fill_(3), # see above
user=user_values,
version=tensor_2d(headers['version']),
version=tensor(headers['version']),
)
return acq_info

0 comments on commit 795a701

Please sign in to comment.