Merge branch 'main' into OpNormIssue
Stef-Martin authored Nov 13, 2024
2 parents 89a4b5a + 798f1e2 commit 627300f
Showing 16 changed files with 493 additions and 240 deletions.
25 changes: 24 additions & 1 deletion src/mrpro/data/_kdata/KData.py
@@ -17,7 +17,7 @@
from mrpro.data._kdata.KDataRemoveOsMixin import KDataRemoveOsMixin
from mrpro.data._kdata.KDataSelectMixin import KDataSelectMixin
from mrpro.data._kdata.KDataSplitMixin import KDataSplitMixin
from mrpro.data.acq_filters import is_image_acquisition
from mrpro.data.acq_filters import has_n_coils, is_image_acquisition
from mrpro.data.AcqInfo import AcqInfo, rearrange_acq_info_fields
from mrpro.data.EncodingLimits import Limits
from mrpro.data.KHeader import KHeader
@@ -110,6 +110,29 @@ def from_file(
modification_time = datetime.datetime.fromtimestamp(mtime)

acquisitions = [acq for acq in acquisitions if acquisition_filter_criterion(acq)]

# we need the same number of receiver coils for all acquisitions
n_coils_available = {acq.data.shape[0] for acq in acquisitions}
if len(n_coils_available) > 1:
if (
ismrmrd_header.acquisitionSystemInformation is not None
and ismrmrd_header.acquisitionSystemInformation.receiverChannels is not None
):
n_coils = int(ismrmrd_header.acquisitionSystemInformation.receiverChannels)
else:
# most likely, the highest number of coil elements corresponds to the coils used for imaging
n_coils = int(max(n_coils_available))

warnings.warn(
f'Acquisitions with different numbers of receiver coil elements detected: {n_coils_available}. '
f'Only data with {n_coils} receiver coil elements will be used.',
stacklevel=1,
)
acquisitions = [acq for acq in acquisitions if has_n_coils(n_coils, acq)]

if not acquisitions:
raise ValueError('No acquisitions meeting the given filter criteria were found.')

kdata = torch.stack([torch.as_tensor(acq.data, dtype=torch.complex64) for acq in acquisitions])

acqinfo = AcqInfo.from_ismrmrd_acquisitions(acquisitions)
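
For illustration, a minimal standalone sketch of the coil-consistency idea implemented above. This is not the mrpro API: it is a variant that keeps the most frequent coil count instead of the header value or the maximum, and it assumes each acquisition exposes a data array whose first dimension is the receiver-coil dimension.

from collections import Counter

def filter_consistent_coils(acquisitions):
    """Keep only acquisitions that share the most common number of receiver coils."""
    counts = Counter(acq.data.shape[0] for acq in acquisitions)
    n_coils = counts.most_common(1)[0][0]  # most frequent coil count (assumed to be the imaging coils)
    return [acq for acq in acquisitions if acq.data.shape[0] == n_coils]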
17 changes: 17 additions & 0 deletions src/mrpro/data/acq_filters.py
@@ -61,3 +61,20 @@ def is_coil_calibration_acquisition(acquisition: ismrmrd.Acquisition) -> bool:
"""
coil_calibration_flag = AcqFlags.ACQ_IS_PARALLEL_CALIBRATION | AcqFlags.ACQ_IS_PARALLEL_CALIBRATION_AND_IMAGING
return coil_calibration_flag.value & acquisition.flags


def has_n_coils(n_coils: int, acquisition: ismrmrd.Acquisition) -> bool:
"""Test if acquisitions was obtained with a certain number of receiver coils.
Parameters
----------
n_coils
number of receiver coils
acquisition
ISMRMRD acquisition
Returns
-------
True if the acquisition was obtained with n_coils receiver coils
"""
return acquisition.data.shape[0] == n_coils
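
A hypothetical usage sketch of the new filter. The keyword name acquisition_filter_criterion follows the from_file code above; the file name, trajectory argument, and coil count are made up.

from functools import partial

from mrpro.data.acq_filters import has_n_coils, is_image_acquisition

has_32_coils = partial(has_n_coils, 32)  # acquisition -> bool

def imaging_with_32_coils(acquisition) -> bool:
    """Accept only imaging acquisitions recorded with 32 receiver coils."""
    return bool(is_image_acquisition(acquisition)) and has_32_coils(acquisition)

# kdata = KData.from_file('scan.mrd', trajectory_calculator, acquisition_filter_criterion=imaging_with_32_coils)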
98 changes: 84 additions & 14 deletions src/mrpro/operators/CartesianSamplingOp.py
@@ -1,12 +1,15 @@
"""Cartesian Sampling Operator."""

import warnings

import torch
from einops import rearrange, repeat

from mrpro.data.enums import TrajType
from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.utils.reshape import unsqueeze_left


class CartesianSamplingOp(LinearOperator):
@@ -64,10 +67,35 @@ def __init__(self, encoding_matrix: SpatialDimension[int], traj: KTrajectory) ->
# 1D indices into a flattened tensor.
kidx = kz_idx * sorted_grid_shape.y * sorted_grid_shape.x + ky_idx * sorted_grid_shape.x + kx_idx
kidx = rearrange(kidx, '... kz ky kx -> ... 1 (kz ky kx)')

# check that all points are inside the encoding matrix
inside_encoding_matrix = (
((kx_idx >= 0) & (kx_idx < sorted_grid_shape.x))
& ((ky_idx >= 0) & (ky_idx < sorted_grid_shape.y))
& ((kz_idx >= 0) & (kz_idx < sorted_grid_shape.z))
)
if not torch.all(inside_encoding_matrix):
warnings.warn(
'K-space points lie outside of the encoding_matrix and will be ignored.'
' Increase the encoding_matrix to include these points.',
stacklevel=2,
)

inside_encoding_matrix = rearrange(inside_encoding_matrix, '... kz ky kx -> ... 1 (kz ky kx)')
inside_encoding_matrix_idx = inside_encoding_matrix.nonzero(as_tuple=True)[-1]
inside_encoding_matrix_idx = torch.reshape(inside_encoding_matrix_idx, (*kidx.shape[:-1], -1))
self.register_buffer('_inside_encoding_matrix_idx', inside_encoding_matrix_idx)
kidx = torch.take_along_dim(kidx, inside_encoding_matrix_idx, dim=-1)
else:
self._inside_encoding_matrix_idx: torch.Tensor | None = None

self.register_buffer('_fft_idx', kidx)

# we can skip the indexing if the data is already sorted
self._needs_indexing = (
not torch.all(torch.diff(kidx) == 1) or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx
not torch.all(torch.diff(kidx) == 1)
or traj.broadcasted_shape[-3:] != sorted_grid_shape.zyx
or self._inside_encoding_matrix_idx is not None
)

self._trajectory_shape = traj.broadcasted_shape
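
For illustration, a standalone sketch of the index bookkeeping above with made-up shapes: (kz, ky, kx) grid coordinates are flattened to 1D indices, and points outside an encoding matrix of size (nz, ny, nx) are masked out.

import torch

nz, ny, nx = 8, 16, 16                                # assumed encoding matrix size
kz_idx = torch.randint(-2, nz + 2, (4, 4, 4))         # deliberately contains out-of-range values
ky_idx = torch.randint(-2, ny + 2, (4, 4, 4))
kx_idx = torch.randint(-2, nx + 2, (4, 4, 4))

inside = (
    (kz_idx >= 0) & (kz_idx < nz)
    & (ky_idx >= 0) & (ky_idx < ny)
    & (kx_idx >= 0) & (kx_idx < nx)
)
flat_idx = kz_idx * ny * nx + ky_idx * nx + kx_idx    # 1D index into the flattened (nz, ny, nx) grid
valid_idx = flat_idx[inside]                          # only points inside the encoding matrix survive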
@@ -93,8 +121,21 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
return (x,)

x_kflat = rearrange(x, '... coil k2_enc k1_enc k0_enc -> ... coil (k2_enc k1_enc k0_enc)')
# take_along_dim does broadcast, so no need for extending here
x_indexed = torch.take_along_dim(x_kflat, self._fft_idx, dim=-1)
# take_along_dim broadcasts, but needs the same number of dimensions
idx = unsqueeze_left(self._fft_idx, x_kflat.ndim - self._fft_idx.ndim)
x_inside_encoding_matrix = torch.take_along_dim(x_kflat, idx, dim=-1)

if self._inside_encoding_matrix_idx is None:
# all trajectory points are inside the encoding matrix
x_indexed = x_inside_encoding_matrix
else:
# we need to add zeros for k-space points outside the encoding matrix
x_indexed = self._broadcast_and_scatter_along_last_dim(
x_inside_encoding_matrix,
self._trajectory_shape[-1] * self._trajectory_shape[-2] * self._trajectory_shape[-3],
self._inside_encoding_matrix_idx,
)

# reshape to (... other coil, k2, k1, k0)
x_reshaped = x_indexed.reshape(x.shape[:-3] + self._trajectory_shape[-3:])
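
A minimal sketch of the broadcasting trick used above, with made-up shapes: torch.take_along_dim broadcasts its index, but only if index and data have the same number of dimensions, hence the unsqueeze on the left.

import torch

data = torch.arange(2 * 3 * 10, dtype=torch.float32).reshape(2, 3, 10)    # (other, coil, k)
idx = torch.tensor([[0, 2, 4, 6]])                                        # fewer dimensions than data

idx_left = idx.reshape((1,) * (data.ndim - idx.ndim) + tuple(idx.shape))  # same effect as unsqueeze_left
gathered = torch.take_along_dim(data, idx_left, dim=-1)                   # broadcasts to shape (2, 3, 4)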

@@ -120,18 +161,13 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:

y_kflat = rearrange(y, '... coil k2 k1 k0 -> ... coil (k2 k1 k0)')

# scatter does not broadcast, so we need to manually broadcast the indices
broadcast_shape = torch.broadcast_shapes(self._fft_idx.shape[:-1], y_kflat.shape[:-1])
idx_expanded = torch.broadcast_to(self._fft_idx, (*broadcast_shape, self._fft_idx.shape[-1]))
if self._inside_encoding_matrix_idx is not None:
idx = unsqueeze_left(self._inside_encoding_matrix_idx, y_kflat.ndim - self._inside_encoding_matrix_idx.ndim)
y_kflat = torch.take_along_dim(y_kflat, idx, dim=-1)

# although scatter_ is inplace, this will not cause issues with autograd, as self
# is always constant zero and gradients w.r.t. src work as expected.
y_scattered = torch.zeros(
*broadcast_shape,
self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x,
dtype=y.dtype,
device=y.device,
).scatter_(dim=-1, index=idx_expanded, src=y_kflat)
y_scattered = self._broadcast_and_scatter_along_last_dim(
y_kflat, self._sorted_grid_shape.z * self._sorted_grid_shape.y * self._sorted_grid_shape.x, self._fft_idx
)

# reshape to ..., other, coil, k2_enc, k1_enc, k0_enc
y_reshaped = y_scattered.reshape(
@@ -142,3 +178,37 @@ def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
)

return (y_reshaped,)

@staticmethod
def _broadcast_and_scatter_along_last_dim(
data_to_scatter: torch.Tensor, n_last_dim: int, scatter_index: torch.Tensor
) -> torch.Tensor:
"""Broadcast scatter index and scatter into zero tensor.
Parameters
----------
data_to_scatter
Data to be scattered at indices scatter_index
n_last_dim
Number of data points in last dimension
scatter_index
Indices describing where to scatter data
Returns
-------
Data scattered into tensor along scatter_index
"""
# scatter does not broadcast, so we need to manually broadcast the indices
broadcast_shape = torch.broadcast_shapes(scatter_index.shape[:-1], data_to_scatter.shape[:-1])
idx_expanded = torch.broadcast_to(scatter_index, (*broadcast_shape, scatter_index.shape[-1]))

# although scatter_ is inplace, this will not cause issues with autograd, as self
# is always constant zero and gradients w.r.t. src work as expected.
data_scattered = torch.zeros(
*broadcast_shape,
n_last_dim,
dtype=data_to_scatter.dtype,
device=data_to_scatter.device,
).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter)

return data_scattered
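
A standalone sketch of what this helper does and why it is the adjoint of the gather used in forward. All shapes are made up; scatter_ does not broadcast its index, so the index is expanded by hand, and for unique indices the gather/scatter pair satisfies the adjoint relation <Gx, y> = <x, Gᴴy>.

import torch

n_grid, n_samples = 12, 5
idx = torch.randperm(n_grid)[:n_samples].reshape(1, 1, n_samples)   # unique flat indices, shape (1, 1, k)
x = torch.randn(2, 3, n_grid, dtype=torch.complex64)                # (other, coil, grid points)
y = torch.randn(2, 3, n_samples, dtype=torch.complex64)             # (other, coil, sampled points)

broadcast_shape = torch.broadcast_shapes(idx.shape[:-1], y.shape[:-1])
idx_expanded = torch.broadcast_to(idx, (*broadcast_shape, n_samples))

gathered = torch.take_along_dim(x, idx_expanded, dim=-1)            # forward: pick the sampled points
scattered = torch.zeros(*broadcast_shape, n_grid, dtype=y.dtype).scatter_(
    dim=-1, index=idx_expanded, src=y
)                                                                    # adjoint: zero-fill, then place samples

lhs = torch.vdot(gathered.flatten(), y.flatten())
rhs = torch.vdot(x.flatten(), scattered.flatten())
assert torch.allclose(lhs, rhs, atol=1e-4)                           # adjointness (dot-product test)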
43 changes: 27 additions & 16 deletions src/mrpro/operators/FourierOp.py
@@ -11,11 +11,12 @@
from mrpro.data.enums import TrajType
from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp
from mrpro.operators.FastFourierOp import FastFourierOp
from mrpro.operators.LinearOperator import LinearOperator


class FourierOp(LinearOperator):
class FourierOp(LinearOperator, adjoint_as_backward=True):
"""Fourier Operator class."""

def __init__(
@@ -67,12 +68,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
self._nufft_dims.append(dim)

if self._fft_dims:
self._fast_fourier_op = FastFourierOp(
self._fast_fourier_op: FastFourierOp | None = FastFourierOp(
dim=tuple(self._fft_dims),
recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims),
encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims),
)

self._cart_sampling_op: CartesianSamplingOp | None = CartesianSamplingOp(
encoding_matrix=encoding_matrix, traj=traj
)
else:
self._fast_fourier_op = None
self._cart_sampling_op = None
# Find dimensions which require NUFFT
if self._nufft_dims:
fft_dims_k210 = [
@@ -102,20 +108,23 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction

self._fwd_nufft_op = KbNufft(
self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=nufft_numpoints,
kbwidth=nufft_kbwidth,
)
self._adj_nufft_op = KbNufftAdjoint(
self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=nufft_numpoints,
kbwidth=nufft_kbwidth,
)

self._kshape = traj.broadcasted_shape
else:
self._omega: torch.Tensor | None = None
self._fwd_nufft_op = None
self._adj_nufft_op = None
self._kshape = traj.broadcasted_shape

@classmethod
def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self:
@@ -146,11 +155,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
-------
coil k-space data with shape: (... coils k2 k1 k0)
"""
if len(self._fft_dims):
# FFT
(x,) = self._fast_fourier_op(x)

if self._nufft_dims:
if self._fwd_nufft_op is not None and self._omega is not None:
# NUFFT Type 2
# we need to move the nufft-dimensions to the end and flatten all other dimensions
# so the new shape will be (... non_nufft_dims) coils nufft_dims
# we could move the permute to __init__ but then we still would need to prepend if len(other)>1
@@ -163,7 +169,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
x = x.flatten(end_dim=-len(keep_dims) - 1)

# omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims)
# TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape).
omega = self._omega.permute(*permute)
omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :])
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)
@@ -173,6 +178,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims]
x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils
x = x.permute(*unpermute)

if self._fast_fourier_op is not None and self._cart_sampling_op is not None:
# FFT
(x,) = self._cart_sampling_op(self._fast_fourier_op(x)[0])

return (x,)
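
In the updated forward, the Cartesian part of the operator is now a composition: FFT onto the full encoding grid followed by the new CartesianSamplingOp, with the adjoint (further below) applying the two adjoints in reverse order. A conceptual sketch with hypothetical operator instances:

def cartesian_forward(x, fast_fourier_op, cart_sampling_op):
    """Image -> k-space: FFT onto the full encoding grid, then pick the sampled points."""
    (k_on_grid,) = fast_fourier_op(x)
    (k_sampled,) = cart_sampling_op(k_on_grid)
    return (k_sampled,)

def cartesian_adjoint(y, fast_fourier_op, cart_sampling_op):
    """k-space -> image: put samples back onto the grid, then apply the adjoint FFT."""
    (k_on_grid,) = cart_sampling_op.adjoint(y)
    (x,) = fast_fourier_op.adjoint(k_on_grid)
    return (x,)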

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
@@ -187,11 +197,12 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
-------
coil image data with shape: (... coils z y x)
"""
if self._fft_dims:
if self._fast_fourier_op is not None and self._cart_sampling_op is not None:
# IFFT
(x,) = self._fast_fourier_op.adjoint(x)
(x,) = self._fast_fourier_op.adjoint(self._cart_sampling_op.adjoint(x)[0])

if self._nufft_dims:
if self._adj_nufft_op is not None and self._omega is not None:
# NUFFT Type 1
# we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions
# so the new shape will be (... non_nufft_dims) coils (nufft_dims)
keep_dims = [-4, *self._nufft_dims] # -4 is coil
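
The FourierOp class declaration near the top of this file's diff now passes adjoint_as_backward=True to LinearOperator. A minimal sketch of the general idea behind such a flag, assuming a linear operator op that returns one-element tuples; this is an illustration, not the mrpro implementation:

import torch

class _ForwardWithAdjointBackward(torch.autograd.Function):
    """Apply op in the forward pass and use op.adjoint as the backward pass."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, op) -> torch.Tensor:
        ctx.op = op           # op is assumed to be a linear operator returning a 1-tuple
        return op(x)[0]

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # for a linear operator, the vector-Jacobian product is the adjoint applied to the incoming gradient
        return ctx.op.adjoint(grad_output)[0], None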
1 change: 1 addition & 0 deletions src/mrpro/utils/__init__.py
@@ -6,6 +6,7 @@
from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop
from mrpro.utils.split_idx import split_idx
from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view
import mrpro.utils.unit_conversion

__all__ = [
"broadcast_right",
3 changes: 2 additions & 1 deletion tests/algorithms/test_cg.py
@@ -16,7 +16,8 @@
(1, 32, False),
(4, 32, True),
(4, 32, False),
]
],
ids=['complex_single', 'real_single', 'complex_batch', 'real_batch'],
)
def system(request):
"""Generate data for creating a system Hx=b with linear and self-adjoint
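
A minimal standalone sketch of the fixture pattern above with readable ids. The first parameter tuple and the meaning of the three entries are assumptions, since that part of the file is collapsed in the diff; the fixture body here is illustrative, not the test's actual system.

import pytest
import torch

@pytest.fixture(
    params=[(1, 32, True), (1, 32, False), (4, 32, True), (4, 32, False)],
    ids=['complex_single', 'real_single', 'complex_batch', 'real_batch'],
)
def toy_system(request):
    """Return a random self-adjoint, positive-definite system H x = b."""
    batch, n, complex_valued = request.param
    dtype = torch.complex64 if complex_valued else torch.float32
    a = torch.randn(batch, n, n, dtype=dtype)
    h = a @ a.mH + n * torch.eye(n, dtype=dtype)   # self-adjoint and positive definite
    x = torch.randn(batch, n, dtype=dtype)
    b = (h @ x.unsqueeze(-1)).squeeze(-1)
    return h, x, b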