Skip to content

Commit

Permalink
UserValues
Browse files Browse the repository at this point in the history
ghstack-source-id: 110db9a247941546b19e5797df087b11006a37ff
ghstack-comment-id: 2441645438
Pull Request resolved: #488
  • Loading branch information
fzimmermann89 committed Oct 28, 2024
1 parent bbdc974 commit ff9ade1
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 11 deletions.
82 changes: 75 additions & 7 deletions src/mrpro/data/AcqInfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,59 @@ class AcqIdx(MoveDataMixin):
"""User index 7."""


@dataclass(slots=True)
class UserValues(MoveDataMixin):
"""User-defined values for each readout."""

float0: torch.Tensor
"""User float 0."""

float1: torch.Tensor
"""User float 1."""

float2: torch.Tensor
"""User float 2."""

float3: torch.Tensor
"""User float 3."""

float4: torch.Tensor
"""User float 4."""

float5: torch.Tensor
"""User float 5."""

float6: torch.Tensor
"""User float 6."""

float7: torch.Tensor
"""User float 7."""

int0: torch.Tensor
"""User int 0."""

int1: torch.Tensor
"""User int 1."""

int2: torch.Tensor
"""User int 2."""

int3: torch.Tensor
"""User int 3."""

int4: torch.Tensor
"""User int 4."""

int5: torch.Tensor
"""User int 5."""

int6: torch.Tensor
"""User int 6."""

int7: torch.Tensor
"""User int 7."""


@dataclass(slots=True)
class AcqInfo(MoveDataMixin):
"""Acquisition information for each readout."""
Expand Down Expand Up @@ -148,11 +201,8 @@ class AcqInfo(MoveDataMixin):
trajectory_dimensions: torch.Tensor # =3. We only support 3D Trajectories: kz always exists.
"""Dimensionality of the k-space trajectory vector."""

user_float: torch.Tensor
"""User-defined float parameters."""

user_int: torch.Tensor
"""User-defined int parameters."""
user: UserValues
"""User-defined values."""

version: torch.Tensor
"""Major version number."""
Expand Down Expand Up @@ -238,6 +288,25 @@ def spatialdimension_2d(
user7=tensor(idx['user'][:, 7]),
)

user_values = UserValues(
float0=tensor_2d(headers['user_float'][:, 0]),
float1=tensor_2d(headers['user_float'][:, 1]),
float2=tensor_2d(headers['user_float'][:, 2]),
float3=tensor_2d(headers['user_float'][:, 3]),
float4=tensor_2d(headers['user_float'][:, 4]),
float5=tensor_2d(headers['user_float'][:, 5]),
float6=tensor_2d(headers['user_float'][:, 6]),
float7=tensor_2d(headers['user_float'][:, 7]),
int0=tensor_2d(headers['user_int'][:, 0]),
int1=tensor_2d(headers['user_int'][:, 1]),
int2=tensor_2d(headers['user_int'][:, 2]),
int3=tensor_2d(headers['user_int'][:, 3]),
int4=tensor_2d(headers['user_int'][:, 4]),
int5=tensor_2d(headers['user_int'][:, 5]),
int6=tensor_2d(headers['user_int'][:, 6]),
int7=tensor_2d(headers['user_int'][:, 7]),
)

acq_info = cls(
idx=acq_idx,
acquisition_time_stamp=tensor_2d(headers['acquisition_time_stamp']),
Expand All @@ -260,8 +329,7 @@ def spatialdimension_2d(
scan_counter=tensor_2d(headers['scan_counter']),
slice_dir=spatialdimension_2d(headers['slice_dir']),
trajectory_dimensions=tensor_2d(headers['trajectory_dimensions']).fill_(3), # see above
user_float=tensor_2d(headers['user_float']),
user_int=tensor_2d(headers['user_int']),
user=user_values,
version=tensor_2d(headers['version']),
)
return acq_info
8 changes: 4 additions & 4 deletions tests/data/test_kdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ def test_KData_to_complex128_header(ismrmrd_cart):
"""Change KData dtype complex128: test header"""
kdata = KData.from_file(ismrmrd_cart.filename, DummyTrajectory())
kdata_complex128 = kdata.to(dtype=torch.complex128)
assert kdata_complex128.header.acq_info.user_float.dtype == torch.float64
assert kdata_complex128.header.acq_info.user_int.dtype == torch.int32
assert kdata_complex128.header.acq_info.user.float0.dtype == torch.float64
assert kdata_complex128.header.acq_info.user.int0.dtype == torch.int32


@pytest.mark.cuda
Expand All @@ -254,7 +254,7 @@ def test_KData_cuda(ismrmrd_cart):
assert kdata_cuda.traj.kz.is_cuda
assert kdata_cuda.traj.ky.is_cuda
assert kdata_cuda.traj.kx.is_cuda
assert kdata_cuda.header.acq_info.user_int.is_cuda
assert kdata_cuda.header.acq_info.user.int0.is_cuda
assert kdata_cuda.device == torch.device(torch.cuda.current_device())
assert kdata_cuda.header.acq_info.device == torch.device(torch.cuda.current_device())
assert kdata_cuda.is_cuda
Expand All @@ -270,7 +270,7 @@ def test_KData_cpu(ismrmrd_cart):
assert kdata_cpu.traj.kz.is_cpu
assert kdata_cpu.traj.ky.is_cpu
assert kdata_cpu.traj.kx.is_cpu
assert kdata_cpu.header.acq_info.user_int.is_cpu
assert kdata_cpu.header.acq_info.user.int0.is_cpu
assert kdata_cpu.device == torch.device('cpu')
assert kdata_cpu.header.acq_info.device == torch.device('cpu')

Expand Down

0 comments on commit ff9ade1

Please sign in to comment.