Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Take NonUniformFastFourierOp out of FourierOp #463

Open
wants to merge 42 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
38f33fb
tests added
ckolbPTB Oct 22, 2024
74e6250
doc strings
ckolbPTB Oct 22, 2024
60fca9e
Update src/mrpro/operators/NonUniformFastFourierOp.py
fzimmermann89 Oct 22, 2024
fed8954
merge main
ckolbPTB Oct 23, 2024
2659402
review
ckolbPTB Oct 23, 2024
cc1472d
input pars and doc string adapted
ckolbPTB Oct 23, 2024
0c29b01
Merge branch 'main' into separate_nufft_op
ckolbPTB Oct 23, 2024
393c5a9
Merge branch 'main' into separate_nufft_op
ckolbPTB Oct 23, 2024
276873e
test for nufft output added
ckolbPTB Oct 23, 2024
89e36bc
Merge branch 'main' into separate_nufft_op
ckolbPTB Oct 24, 2024
3c6d873
empty dim
ckolbPTB Oct 24, 2024
d2fb391
Update tests/operators/test_non_uniform_fast_fourier_op.py
ckolbPTB Dec 5, 2024
a4b8a64
merge main
ckolbPTB Dec 5, 2024
2a90533
fix merge problems
ckolbPTB Dec 5, 2024
b67bef6
fix more merge problems
ckolbPTB Dec 5, 2024
9b60be2
gram and cart_samp fixed
ckolbPTB Dec 5, 2024
281e5bc
spatial dims and test for unsupported direction added
ckolbPTB Dec 5, 2024
4b8d8fb
nufft dim automatically detected
ckolbPTB Dec 6, 2024
aba850d
fix for single shot traj and further tests added
ckolbPTB Dec 11, 2024
6ba3fb4
remove superfluous tests
ckolbPTB Dec 11, 2024
b7c83e6
first try
ckolbPTB Dec 11, 2024
8270de5
forward adapted
ckolbPTB Dec 11, 2024
e73ace5
adjoint adapted
ckolbPTB Dec 11, 2024
e586360
add _nufft_type1 and _nufft_type2
ckolbPTB Dec 12, 2024
9977205
clean up
ckolbPTB Dec 12, 2024
4b70389
sep dims and joint dims for forward and adjoint
ckolbPTB Dec 12, 2024
01ca1dc
gram started
ckolbPTB Dec 12, 2024
87f51c2
gram finished for nufft
ckolbPTB Dec 12, 2024
6585ab9
tests adapted and bug fix
ckolbPTB Dec 13, 2024
2330e3d
add rpe to conftest
ckolbPTB Dec 13, 2024
739d76c
conftest update
ckolbPTB Dec 13, 2024
d8217e8
use given kshape
ckolbPTB Dec 14, 2024
1bccb41
conftest error fixed
ckolbPTB Dec 14, 2024
b11acf0
misalignment k210 and kzyx still a problem
ckolbPTB Dec 14, 2024
b2d58fa
tidy up
ckolbPTB Dec 14, 2024
d027e4a
Merge branch 'main' into separate_nufft_op
ckolbPTB Dec 14, 2024
302782c
joint dims zyx
ckolbPTB Dec 16, 2024
4382d68
nufft gram separated out
ckolbPTB Dec 16, 2024
f364526
Merge branch 'main' into separate_nufft_op
ckolbPTB Dec 16, 2024
2d9c51a
merge main
ckolbPTB Jan 10, 2025
3cd6c8e
gram adj_nufft separate, test fix and speed up
ckolbPTB Jan 10, 2025
9e229da
fix cart traj calc
ckolbPTB Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 28 additions & 179 deletions src/mrpro/operators/FourierOp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
"""Fourier Operator."""

from collections.abc import Sequence
from itertools import product

import numpy as np
import torch
from torchkbnufft import KbNufft, KbNufftAdjoint
from typing_extensions import Self

from mrpro.data.enums import TrajType
Expand All @@ -15,6 +12,7 @@
from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp
from mrpro.operators.FastFourierOp import FastFourierOp
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.operators.NonUniformFastFourierOp import NonUniformFastFourierOp


class FourierOp(LinearOperator, adjoint_as_backward=True):
Expand All @@ -26,8 +24,6 @@ def __init__(
encoding_matrix: SpatialDimension[int],
traj: KTrajectory,
nufft_oversampling: float = 2.0,
nufft_numpoints: int = 6,
nufft_kbwidth: float = 2.34,
) -> None:
"""Fourier Operator class.

Expand All @@ -40,11 +36,10 @@ def __init__(
traj
the k-space trajectories where the frequencies are sampled
nufft_oversampling
oversampling used for interpolation in non-uniform FFTs
nufft_numpoints
number of neighbors for interpolation in non-uniform FFTs
nufft_kbwidth
size of the Kaiser-Bessel kernel interpolation in non-uniform FFTs
oversampling used for interpolation in non-uniform FFTs. The oversampling of the interpolation grid, which
is needed during the non-uniform FFT, ensures that there is no foldover due to the finite gridding kernel.
It can be reduced (e.g. to 1.25) to speed up the non-uniform FFT but this might lead to poorer image
quality.
"""
super().__init__()

Expand All @@ -55,9 +50,6 @@ def get_spatial_dims(spatial_dims: SpatialDimension, dims: Sequence[int]):
if i in dims
]

def get_traj(traj: KTrajectory, dims: Sequence[int]):
return [k for k, i in zip((traj.kz, traj.ky, traj.kx), (-3, -2, -1), strict=True) if i in dims]

self._ignore_dims, self._fft_dims, self._nufft_dims = [], [], []
for dim, type_ in zip((-3, -2, -1), traj.type_along_kzyx, strict=True):
if type_ & TrajType.SINGLEVALUE:
Expand All @@ -80,6 +72,7 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
else:
self._fast_fourier_op = None
self._cart_sampling_op = None

# Find dimensions which require NUFFT
if self._nufft_dims:
fft_dims_k210 = [
Expand All @@ -91,41 +84,18 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
if self._fft_dims != fft_dims_k210:
raise NotImplementedError(
'If both FFT and NUFFT dims are present, Cartesian FFT dims need to be aligned with the '
'k-space dimension, i.e. kx along k0, ky along k1 and kz along k2',
)

self._nufft_im_size = get_spatial_dims(recon_matrix, self._nufft_dims)
grid_size = [int(size * nufft_oversampling) for size in self._nufft_im_size]
omega = [
k * 2 * torch.pi / ks
for k, ks in zip(
get_traj(traj, self._nufft_dims),
get_spatial_dims(encoding_matrix, self._nufft_dims),
strict=True,
'k-space dimension, i.e. kx along k0, ky along k1 and kz along k2.',
)
]

# Broadcast shapes not always needed but also does not hurt
omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction
numpoints = [min(img_size, nufft_numpoints) for img_size in self._nufft_im_size]
self._fwd_nufft_op: KbNufftAdjoint | None = KbNufft(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=numpoints,
kbwidth=nufft_kbwidth,
)
self._adj_nufft_op: KbNufftAdjoint | None = KbNufftAdjoint(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=numpoints,
kbwidth=nufft_kbwidth,
self._non_uniform_fast_fourier_op: NonUniformFastFourierOp | None = NonUniformFastFourierOp(
direction=tuple(self._nufft_dims), # type: ignore[arg-type]
recon_matrix=get_spatial_dims(recon_matrix, self._nufft_dims),
encoding_matrix=get_spatial_dims(encoding_matrix, self._nufft_dims),
traj=traj,
nufft_oversampling=nufft_oversampling,
)
else:
self._omega: torch.Tensor | None = None
self._fwd_nufft_op = None
self._adj_nufft_op = None
self._kshape = traj.broadcasted_shape
self._non_uniform_fast_fourier_op = None

@classmethod
def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self:
Expand Down Expand Up @@ -156,34 +126,12 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
-------
coil k-space data with shape: (... coils k2 k1 k0)
"""
if self._fwd_nufft_op is not None and self._omega is not None:
# NUFFT Type 2
# we need to move the nufft-dimensions to the end and flatten all other dimensions
# so the new shape will be (... non_nufft_dims) coils nufft_dims
# we could move the permute to __init__ but then we still would need to prepend if len(other)>1
keep_dims = [-4, *self._nufft_dims] # -4 is always coil
permute = [i for i in range(-x.ndim, 0) if i not in keep_dims] + keep_dims
unpermute = np.argsort(permute)

x = x.permute(*permute)
permuted_x_shape = x.shape
x = x.flatten(end_dim=-len(keep_dims) - 1)

# omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims)
omega = self._omega.permute(*permute)
omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :])
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)

x = self._fwd_nufft_op(x, omega, norm='ortho')

shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims]
x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils
x = x.permute(*unpermute)

if self._fast_fourier_op is not None and self._cart_sampling_op is not None:
# FFT
(x,) = self._cart_sampling_op(self._fast_fourier_op(x)[0])
# NUFFT Type 2 followed by FFT
if self._non_uniform_fast_fourier_op:
(x,) = self._non_uniform_fast_fourier_op(x)

if self._fast_fourier_op and self._cart_sampling_op:
(x,) = self._cart_sampling_op(self._fast_fourier_op(x)[0])
return (x,)

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
Expand All @@ -198,30 +146,12 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
-------
coil image data with shape: (... coils z y x)
"""
if self._fast_fourier_op is not None and self._cart_sampling_op is not None:
# IFFT
# FFT followed by NUFFT Type 1
if self._fast_fourier_op and self._cart_sampling_op:
(x,) = self._fast_fourier_op.adjoint(self._cart_sampling_op.adjoint(x)[0])

if self._adj_nufft_op is not None and self._omega is not None:
# NUFFT Type 1
# we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions
# so the new shape will be (... non_nufft_dims) coils (nufft_dims)
keep_dims = [-4, *self._nufft_dims] # -4 is coil
permute = [i for i in range(-x.ndim, 0) if i not in keep_dims] + keep_dims
unpermute = np.argsort(permute)

x = x.permute(*permute)
permuted_x_shape = x.shape
x = x.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)

omega = self._omega.permute(*permute)
omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :])
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)

x = self._adj_nufft_op(x, omega, norm='ortho')

x = x.reshape(*permuted_x_shape[: -len(keep_dims)], *x.shape[-len(keep_dims) :])
x = x.permute(*unpermute)
if self._non_uniform_fast_fourier_op:
(x,) = self._non_uniform_fast_fourier_op.adjoint(x)

return (x,)

Expand All @@ -231,69 +161,6 @@ def gram(self) -> LinearOperator:
return FourierGramOp(self)


def symmetrize(kernel: torch.Tensor, rank: int) -> torch.Tensor:
"""Enforce hermitian symmetry on the kernel. Returns only half of the kernel."""
flipped = kernel.clone()
for d in range(-rank, 0):
flipped = flipped.index_select(d, -1 * torch.arange(flipped.shape[d], device=flipped.device) % flipped.size(d))
kernel = (kernel + flipped.conj()) / 2
last_len = kernel.shape[-1]
return kernel[..., : last_len // 2 + 1]


def gram_nufft_kernel(weight: torch.Tensor, trajectory: torch.Tensor, recon_shape: Sequence[int]) -> torch.Tensor:
"""Calculate the convolution kernel for the NUFFT gram operator.

Parameters
----------
weight
either ones or density compensation weights
trajectory
k-space trajectory
recon_shape
shape of the reconstructed image

Returns
-------
kernel
real valued convolution kernel for the NUFFT gram operator, already in Fourier space
"""
rank = trajectory.shape[-2]
if rank != len(recon_shape):
raise ValueError('Rank of trajectory and image size must match.')
# Instead of doing one adjoint nufft with double the recon size in all dimensions,
# we do two adjoint nuffts per dimensions, saving a lot of memory.
adjnufft_ob = KbNufftAdjoint(im_size=recon_shape, n_shift=[0] * rank).to(trajectory)

kernel = adjnufft_ob(weight, trajectory) # this will be the top left ... corner block
pad = []
for s in kernel.shape[: -rank - 1 : -1]:
pad.extend([0, s])
kernel = torch.nn.functional.pad(kernel, pad) # twice the size in all dimensions

for flips in list(product([1, -1], repeat=rank)):
if all(flip == 1 for flip in flips):
# top left ... block already processed before padding
continue
flipped_trajectory = trajectory * torch.tensor(flips).to(trajectory).unsqueeze(-1)
kernel_part = adjnufft_ob(weight, flipped_trajectory)
slices = [] # which part of the kernel to is currently being processed
for dim, flip in zip(range(-rank, 0), flips, strict=True):
if flip > 0: # first half in the dimension
slices.append(slice(0, kernel_part.size(dim)))
else: # second half in the dimension
slices.append(slice(kernel_part.size(dim) + 1, None))
kernel_part = kernel_part.index_select(dim, torch.arange(kernel_part.size(dim) - 1, 0, -1)) # flip

kernel[[..., *slices]] = kernel_part

kernel = symmetrize(kernel, rank)
kernel = torch.fft.hfftn(kernel, dim=list(range(-rank, 0)), norm='backward')
kernel /= kernel.shape[-rank:].numel()
kernel = torch.fft.fftshift(kernel, dim=list(range(-rank, 0)))
return kernel


class FourierGramOp(LinearOperator):
"""Gram operator for the Fourier operator.

Expand Down Expand Up @@ -325,30 +192,12 @@ def __init__(self, fourier_op: FourierOp) -> None:

"""
super().__init__()
if fourier_op._nufft_dims and fourier_op._omega is not None:
weight = torch.ones_like(fourier_op._omega[..., :1, :, :, :])
keep_dims = [-4, *fourier_op._nufft_dims] # -4 is coil
permute = [i for i in range(-weight.ndim, 0) if i not in keep_dims] + keep_dims
unpermute = np.argsort(permute)
weight = weight.permute(*permute)
weight_unflattend_shape = weight.shape
weight = weight.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)
weight = weight + 0j
omega = fourier_op._omega.permute(*permute)
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)
kernel = gram_nufft_kernel(weight, omega, fourier_op._nufft_im_size)
kernel = kernel.reshape(*weight_unflattend_shape[: -len(keep_dims)], *kernel.shape[-len(keep_dims) :])
kernel = kernel.permute(*unpermute)
fft = FastFourierOp(
dim=fourier_op._nufft_dims,
encoding_matrix=[2 * s for s in fourier_op._nufft_im_size],
recon_matrix=fourier_op._nufft_im_size,
)
self.nufft_gram: None | LinearOperator = fft.H * kernel @ fft
if fourier_op._non_uniform_fast_fourier_op:
self.nufft_gram: None | LinearOperator = fourier_op._non_uniform_fast_fourier_op.gram
else:
self.nufft_gram = None

if fourier_op._fast_fourier_op is not None and fourier_op._cart_sampling_op is not None:
if fourier_op._fast_fourier_op and fourier_op._cart_sampling_op:
self.fast_fourier_gram: None | LinearOperator = (
fourier_op._fast_fourier_op.H @ fourier_op._cart_sampling_op.gram @ fourier_op._fast_fourier_op
)
Expand All @@ -363,10 +212,10 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
x
input tensor, shape (..., coils, z, y, x)
"""
if self.nufft_gram is not None:
if self.nufft_gram:
(x,) = self.nufft_gram(x)

if self.fast_fourier_gram is not None:
if self.fast_fourier_gram:
(x,) = self.fast_fourier_gram(x)
return (x,)

Expand Down
Loading
Loading