Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

acqinfo 5d #490

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 133 additions & 87 deletions src/mrpro/data/AcqInfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias

import einops
import ismrmrd
import numpy as np
import torch
Expand All @@ -14,6 +16,9 @@
# Conversion functions for units
T = TypeVar('T', float, torch.Tensor)

# We use this for runtime dtype checking in the dataclasses
LongTensor: TypeAlias = torch.Tensor | torch.LongTensor | torch.IntTensor


def ms_to_s(ms: T) -> T:
"""Convert ms to s."""
Expand All @@ -25,64 +30,114 @@ def mm_to_m(m: T) -> T:
return m / 1000


class _InvariantsAcqInfo:
__slots__ = '__broadcasted_shape', '__typehints'
__broadcasted_shape: torch.Size | None

def __post_init__(self):
self._check_invariants()

def _check_invariants(self):
shapes = []
for name in self.__slots__:
expected_type = self.__annotations__[name]
value = getattr(self, name)
if not isinstance(value, expected_type):
raise TypeError(f'{name} must be of type {expected_type}, got {type(value)} instead')
if hasattr(value, 'shape'):
# isinstance(value, torch.Tensor | SpatialDimension | Rotation):
shape = value.shape
if len(shape) < 5:
raise ValueError(f'{name} must have at least 5 dimensions')
if shape[-1] != 1:
raise ValueError(f'{name} must have a k0 dimension of size 1')
if shape[-4] != 1:
raise ValueError(f'{name} must have a coil dimension of size 1')
shapes.append(value.shape)

if hasattr(value, 'dtype') and value.dtype.is_complex:
raise ValueError(f'{name} must not be complex.')
if expected_type == LongTensor and value.dtype not in (
torch.int64,
torch.uint64,
torch.int32,
torch.uint32,
):
raise ValueError(f'{name} must be integer.')

elif hasattr(value, 'broadcasted_shape'):
shapes.append(value.broadcasted_shape)
try:
broadcasted_shape = torch.broadcast_shapes(*shapes)
except RuntimeError:
raise ValueError(f'The Acquisition information tensors {self.__slots__} must be broadcastable.') from None

self.__broadcasted_shape = broadcasted_shape

@property
def broadcasted_shape(self) -> torch.Size:
assert self.__broadcasted_shape is not None # noqa: S101 # mypy hint
return self.__broadcasted_shape


@dataclass(slots=True)
class AcqIdx(MoveDataMixin):
class AcqIdx(MoveDataMixin, _InvariantsAcqInfo):
"""Acquisition index for each readout."""

k1: torch.Tensor
k1: LongTensor
"""First phase encoding."""

k2: torch.Tensor
k2: LongTensor
"""Second phase encoding."""

average: torch.Tensor
average: LongTensor
"""Signal average."""

slice: torch.Tensor
slice: LongTensor
"""Slice number (multi-slice 2D)."""

contrast: torch.Tensor
contrast: LongTensor
"""Echo number in multi-echo."""

phase: torch.Tensor
phase: LongTensor
"""Cardiac phase."""

repetition: torch.Tensor
repetition: LongTensor
"""Counter in repeated/dynamic acquisitions."""

set: torch.Tensor
set: LongTensor
"""Sets of different preparation, e.g. flow encoding, diffusion weighting."""

segment: torch.Tensor
segment: LongTensor
"""Counter for segmented acquisitions."""

user0: torch.Tensor
user0: LongTensor
"""User index 0."""

user1: torch.Tensor
user1: LongTensor
"""User index 1."""

user2: torch.Tensor
user2: LongTensor
"""User index 2."""

user3: torch.Tensor
user3: LongTensor
"""User index 3."""

user4: torch.Tensor
user4: LongTensor
"""User index 4."""

user5: torch.Tensor
user5: LongTensor
"""User index 5."""

user6: torch.Tensor
user6: LongTensor
"""User index 6."""

user7: torch.Tensor
user7: LongTensor
"""User index 7."""


@dataclass(slots=True)
class UserValues(MoveDataMixin):
class UserValues(MoveDataMixin, _InvariantsAcqInfo):
"""User-defined values for each readout."""

float0: torch.Tensor
Expand All @@ -109,28 +164,28 @@ class UserValues(MoveDataMixin):
float7: torch.Tensor
"""User float 7."""

int0: torch.Tensor
int0: LongTensor
"""User int 0."""

int1: torch.Tensor
int1: LongTensor
"""User int 1."""

int2: torch.Tensor
int2: LongTensor
"""User int 2."""

int3: torch.Tensor
int3: LongTensor
"""User int 3."""

int4: torch.Tensor
int4: LongTensor
"""User int 4."""

int5: torch.Tensor
int5: LongTensor
"""User int 5."""

int6: torch.Tensor
int6: LongTensor
"""User int 6."""

int7: torch.Tensor
int7: LongTensor
"""User int 7."""


Expand All @@ -144,34 +199,34 @@ class AcqInfo(MoveDataMixin):
acquisition_time_stamp: torch.Tensor
"""Clock time stamp. Not in s but in vendor-specific time units (e.g. 2.5ms for Siemens)"""

active_channels: torch.Tensor
active_channels: LongTensor
"""Number of active receiver coil elements."""

available_channels: torch.Tensor
available_channels: LongTensor
"""Number of available receiver coil elements."""

center_sample: torch.Tensor
center_sample: LongTensor
"""Index of the readout sample corresponding to k-space center (zero indexed)."""

channel_mask: torch.Tensor
channel_mask: LongTensor
"""Bit mask indicating active coils (64*16 = 1024 bits)."""

discard_post: torch.Tensor
discard_post: LongTensor
"""Number of readout samples to be discarded at the end (e.g. if the ADC is active during gradient events)."""

discard_pre: torch.Tensor
discard_pre: LongTensor
"""Number of readout samples to be discarded at the beginning (e.g. if the ADC is active during gradient events)"""

encoding_space_ref: torch.Tensor
encoding_space_ref: LongTensor
"""Indexed reference to the encoding spaces enumerated in the MRD (xml) header."""

flags: torch.Tensor
flags: LongTensor
"""A bit mask of common attributes applicable to individual acquisition readouts."""

measurement_uid: torch.Tensor
measurement_uid: LongTensor
"""Unique ID corresponding to the readout."""

number_of_samples: torch.Tensor
number_of_samples: LongTensor
"""Number of sample points per readout (readouts may have different number of sample points)."""

patient_table_position: SpatialDimension[torch.Tensor]
Expand All @@ -192,20 +247,20 @@ class AcqInfo(MoveDataMixin):
sample_time_us: torch.Tensor
"""Readout bandwidth, as time between samples [us]."""

scan_counter: torch.Tensor
scan_counter: LongTensor
"""Zero-indexed incrementing counter for readouts."""

slice_dir: SpatialDimension[torch.Tensor]
"""Directional cosine of slice normal, i.e. cross-product of read_dir and phase_dir."""

trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists.
trajectory_dimensions: LongTensor # =3. We only support 3D Trajectories: kz always exists.
"""Dimensionality of the k-space trajectory vector."""

user: UserValues
"""User-defined values."""
"""User-defined values for each readout."""

version: torch.Tensor
"""Major version number."""
version: LongTensor
"""Major version number"""

@classmethod
def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition]) -> Self:
Expand Down Expand Up @@ -236,25 +291,17 @@ def from_ismrmrd_acquisitions(cls, acquisitions: Sequence[ismrmrd.Acquisition])
idx = headers['idx']

def tensor(data: np.ndarray) -> torch.Tensor:
"""Convert to tensor with shape (other=1, coil=1, k2=1, k1=n, k0=1)."""
# we have to convert first as pytoch cant create tensors from np.uint16 arrays
# we use int32 for uint16 and int64 for uint32 to fit largest values.
match data.dtype:
case np.uint16:
data = data.astype(np.int32)
case np.uint32 | np.uint64:
data = data.astype(np.int64)
# Remove any uncessary dimensions
return torch.tensor(np.squeeze(data))

def tensor_2d(data: np.ndarray) -> torch.Tensor:
# Convert tensor to torch dtypes and ensure it is atleast 2D
data_tensor = tensor(data)
# Ensure that data is (k1*k2*other, >=1)
if data_tensor.ndim == 1:
data_tensor = data_tensor[:, None]
elif data_tensor.ndim == 0:
data_tensor = data_tensor[None, None]
return data_tensor
tensor = torch.from_numpy(data)
tensor = einops.repeat(tensor, '... -> other coil k2 (...) k0', other=1, coil=1, k2=1, k0=1)
return tensor

def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]:
# Ensure spatial dimension is (k1*k2*other, 1, 3)
Expand Down Expand Up @@ -283,49 +330,48 @@ def spatialdimension_2d(data: np.ndarray) -> SpatialDimension[torch.Tensor]:
user6=tensor(idx['user'][:, 6]),
user7=tensor(idx['user'][:, 7]),
)

user_values = UserValues(
float0=tensor_2d(headers['user_float'][:, 0]),
float1=tensor_2d(headers['user_float'][:, 1]),
float2=tensor_2d(headers['user_float'][:, 2]),
float3=tensor_2d(headers['user_float'][:, 3]),
float4=tensor_2d(headers['user_float'][:, 4]),
float5=tensor_2d(headers['user_float'][:, 5]),
float6=tensor_2d(headers['user_float'][:, 6]),
float7=tensor_2d(headers['user_float'][:, 7]),
int0=tensor_2d(headers['user_int'][:, 0]),
int1=tensor_2d(headers['user_int'][:, 1]),
int2=tensor_2d(headers['user_int'][:, 2]),
int3=tensor_2d(headers['user_int'][:, 3]),
int4=tensor_2d(headers['user_int'][:, 4]),
int5=tensor_2d(headers['user_int'][:, 5]),
int6=tensor_2d(headers['user_int'][:, 6]),
int7=tensor_2d(headers['user_int'][:, 7]),
float0=tensor(headers['user_float'][:, 0]),
float1=tensor(headers['user_float'][:, 1]),
float2=tensor(headers['user_float'][:, 2]),
float3=tensor(headers['user_float'][:, 3]),
float4=tensor(headers['user_float'][:, 4]),
float5=tensor(headers['user_float'][:, 5]),
float6=tensor(headers['user_float'][:, 6]),
float7=tensor(headers['user_float'][:, 7]),
int0=tensor(headers['user_int'][:, 0]),
int1=tensor(headers['user_int'][:, 1]),
int2=tensor(headers['user_int'][:, 2]),
int3=tensor(headers['user_int'][:, 3]),
int4=tensor(headers['user_int'][:, 4]),
int5=tensor(headers['user_int'][:, 5]),
int6=tensor(headers['user_int'][:, 6]),
int7=tensor(headers['user_int'][:, 7]),
)

acq_info = cls(
idx=acq_idx,
acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']),
active_channels=tensor_2d(headers['active_channels']),
available_channels=tensor_2d(headers['available_channels']),
center_sample=tensor_2d(headers['center_sample']),
channel_mask=tensor_2d(headers['channel_mask']),
discard_post=tensor_2d(headers['discard_post']),
discard_pre=tensor_2d(headers['discard_pre']),
encoding_space_ref=tensor_2d(headers['encoding_space_ref']),
flags=tensor_2d(headers['flags']),
measurement_uid=tensor_2d(headers['measurement_uid']),
number_of_samples=tensor_2d(headers['number_of_samples']),
acquisition_time_stamp=tensor(headers['acquisition_time_stamp']),
active_channels=tensor(headers['active_channels']),
available_channels=tensor(headers['available_channels']),
center_sample=tensor(headers['center_sample']),
channel_mask=tensor(headers['channel_mask']),
discard_post=tensor(headers['discard_post']),
discard_pre=tensor(headers['discard_pre']),
encoding_space_ref=tensor(headers['encoding_space_ref']),
flags=tensor(headers['flags']),
measurement_uid=tensor(headers['measurement_uid']),
number_of_samples=tensor(headers['number_of_samples']),
patient_table_position=spatialdimension_2d(headers['patient_table_position']).apply_(mm_to_m),
phase_dir=spatialdimension_2d(headers['phase_dir']),
physiology_time_stamp=tensor_2d(headers['physiology_time_stamp']),
physiology_time_stamp=tensor(headers['physiology_time_stamp']),
position=spatialdimension_2d(headers['position']).apply_(mm_to_m),
read_dir=spatialdimension_2d(headers['read_dir']),
sample_time_us=tensor_2d(headers['sample_time_us']),
scan_counter=tensor_2d(headers['scan_counter']),
sample_time_us=tensor(headers['sample_time_us']),
scan_counter=tensor(headers['scan_counter']),
slice_dir=spatialdimension_2d(headers['slice_dir']),
trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above
trajectory_dimensions=tensor(headers['trajectory_dimensions']).fill_(3), # see above
user=user_values,
version=tensor_2d(headers['version']),
version=tensor(headers['version']),
)
return acq_info
Loading