Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Gram Shortcut for FourierOp #503

Merged
merged 8 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/mrpro/operators/CartesianSamplingOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,69 @@ def _broadcast_and_scatter_along_last_dim(
).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter)

return data_scattered

@property
def gram(self) -> 'CartesianSamplingGramOp':
"""Return the Gram operator for this Cartesian Sampling Operator.

Returns
-------
Gram operator for this Cartesian Sampling Operator
"""
return CartesianSamplingGramOp(self)


class CartesianSamplingGramOp(LinearOperator):
"""Gram operator for Cartesian Sampling Operator.

The Gram operator is the composition CartesianSamplingOp.H @ CartesianSamplingOp.
"""

def __init__(self, sampling_op: CartesianSamplingOp):
"""Initialize Cartesian Sampling Gram Operator class.

This should not be used directly, but rather through the `gram` method of a
:class:`mrpro.operator.CartesianSamplingOp` object.

Parameters
----------
sampling_op
The Cartesian Sampling Operator for which to create the Gram operator.
"""
super().__init__()
if sampling_op._needs_indexing:
ones = torch.ones(*sampling_op._trajectory_shape[:-3], 1, *sampling_op._sorted_grid_shape.zyx)
(mask,) = sampling_op.adjoint(*sampling_op.forward(ones))
self.register_buffer('_mask', mask)
else:
self._mask: torch.Tensor | None = None

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply the Gram operator.

Parameters
----------
x
Input data

Returns
-------
Output data
"""
if self._mask is None:
return (x,)
return (x * self._mask,)

def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply the adjoint of the Gram operator.

Parameters
----------
y
Input data

Returns
-------
Output data
"""
return self.forward(y)
161 changes: 161 additions & 0 deletions src/mrpro/operators/FourierOp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Fourier Operator."""

from collections.abc import Sequence
from itertools import product

import numpy as np
import torch
Expand Down Expand Up @@ -223,3 +224,163 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
x = x.permute(*unpermute)

return (x,)

@property
def gram(self) -> LinearOperator:
"""Return the gram operator."""
return FourierGramOp(self)


def symmetrize(kernel: torch.Tensor, rank: int) -> torch.Tensor:
"""Enforce hermitian symmetry on the kernel. Returns only half of the kernel."""
flipped = kernel.clone()
for d in range(-rank, 0):
flipped = flipped.index_select(d, -1 * torch.arange(flipped.shape[d], device=flipped.device) % flipped.size(d))
kernel = (kernel + flipped.conj()) / 2
last_len = kernel.shape[-1]
return kernel[..., : last_len // 2 + 1]


def gram_nufft_kernel(weight: torch.Tensor, trajectory: torch.Tensor, recon_shape: Sequence[int]) -> torch.Tensor:
"""Calculate the convolution kernel for the NUFFT gram operator.

Parameters
----------
weight
either ones or density compensation weights
trajectory
k-space trajectory
recon_shape
shape of the reconstructed image

Returns
-------
kernel
real valued convolution kernel for the NUFFT gram operator, already in Fourier space
"""
rank = trajectory.shape[-2]
if rank != len(recon_shape):
raise ValueError('Rank of trajectory and image size must match.')
# Instead of doing one adjoint nufft with double the recon size in all dimensions,
# we do two adjoint nuffts per dimensions, saving a lot of memory.
adjnufft_ob = KbNufftAdjoint(im_size=recon_shape, n_shift=[0] * rank).to(trajectory)

kernel = adjnufft_ob(weight, trajectory) # this will be the top left ... corner block
pad = []
for s in kernel.shape[: -rank - 1 : -1]:
pad.extend([0, s])
kernel = torch.nn.functional.pad(kernel, pad) # twice the size in all dimensions

for flips in list(product([1, -1], repeat=rank)):
if all(flip == 1 for flip in flips):
# top left ... block already processed before padding
continue
flipped_trajectory = trajectory * torch.tensor(flips).to(trajectory).unsqueeze(-1)
kernel_part = adjnufft_ob(weight, flipped_trajectory)
slices = [] # which part of the kernel to is currently being processed
for dim, flip in zip(range(-rank, 0), flips, strict=True):
if flip > 0: # first half in the dimension
slices.append(slice(0, kernel_part.size(dim)))
else: # second half in the dimension
slices.append(slice(kernel_part.size(dim) + 1, None))
kernel_part = kernel_part.index_select(dim, torch.arange(kernel_part.size(dim) - 1, 0, -1)) # flip

kernel[[..., *slices]] = kernel_part

kernel = symmetrize(kernel, rank)
kernel = torch.fft.hfftn(kernel, dim=list(range(-rank, 0)), norm='backward')
kernel /= kernel.shape[-rank:].numel()
kernel = torch.fft.fftshift(kernel, dim=list(range(-rank, 0)))
return kernel


class FourierGramOp(LinearOperator):
"""Gram operator for the Fourier operator.

Implements the adjoint of the forward operator of the Fourier operator, i.e. the gram operator
fzimmermann89 marked this conversation as resolved.
Show resolved Hide resolved
`F.H@F.

Uses a convolution, implemented as multiplication in Fourier space, to calculate the gram operator
for the toeplitz NUFFT operator.

Uses a multiplication with a binary mask in Fourier space to calculate the gram operator for
the Cartesian FFT operator

This Operator is only used internally and should not be used directly.
Instead, consider using the `gram` property of :class: `mrpro.operators.FourierOp`.
"""

_kernel: torch.Tensor | None

def __init__(self, fourier_op: FourierOp) -> None:
"""Initialize the gram operator.

If density compensation weights are provided, they the operator
F.H@dcf@F is calculated.
fzimmermann89 marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
fourier_op
the Fourier operator to calculate the gram operator for

"""
super().__init__()
if fourier_op._nufft_dims and fourier_op._omega is not None:
weight = torch.ones_like(fourier_op._omega[..., :1, :, :, :])
keep_dims = [-4, *fourier_op._nufft_dims] # -4 is coil
permute = [i for i in range(-weight.ndim, 0) if i not in keep_dims] + keep_dims
unpermute = np.argsort(permute)
weight = weight.permute(*permute)
weight_unflattend_shape = weight.shape
weight = weight.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)
weight = weight + 0j
omega = fourier_op._omega.permute(*permute)
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)
kernel = gram_nufft_kernel(weight, omega, fourier_op._nufft_im_size)
kernel = kernel.reshape(*weight_unflattend_shape[: -len(keep_dims)], *kernel.shape[-len(keep_dims) :])
kernel = kernel.permute(*unpermute)
fft = FastFourierOp(
dim=fourier_op._nufft_dims,
encoding_matrix=[2 * s for s in fourier_op._nufft_im_size],
recon_matrix=fourier_op._nufft_im_size,
)
self.nufft_gram: None | LinearOperator = fft.H * kernel @ fft
else:
self.nufft_gram = None

if fourier_op._fast_fourier_op is not None and fourier_op._cart_sampling_op is not None:
self.fast_fourier_gram: None | LinearOperator = (
fourier_op._fast_fourier_op.H @ fourier_op._cart_sampling_op.gram @ fourier_op._fast_fourier_op
)
else:
self.fast_fourier_gram = None

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply the operator to the input tensor.

Parameters
----------
x
input tensor, shape (..., coils, z, y, x)
"""
if self.nufft_gram is not None:
(x,) = self.nufft_gram(x)

if self.fast_fourier_gram is not None:
(x,) = self.fast_fourier_gram(x)
return (x,)

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply the adjoint operator to the input tensor.

Parameters
----------
x
input tensor, shape (..., coils, k2, k1, k0)
"""
return self.forward(x)

@property
def H(self) -> Self: # noqa: N802
"""Adjoint operator of the gram operator."""
return self
2 changes: 1 addition & 1 deletion src/mrpro/utils/zero_pad_or_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def zero_pad_or_crop(
new_shape: Sequence[int] | torch.Size,
dim: None | Sequence[int] = None,
) -> torch.Tensor:
"""Change shape of data by cropping or zero-padding.
"""Change shape of data by center cropping or symmetric zero-padding.

Parameters
----------
Expand Down
99 changes: 72 additions & 27 deletions tests/operators/test_cartesian_sampling_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from einops import rearrange
from mrpro.data import KTrajectory, SpatialDimension
from mrpro.operators import CartesianSamplingOp
from typing_extensions import Unpack

from tests import RandomGenerator, dotproduct_adjointness_test
from tests.conftest import create_traj
Expand Down Expand Up @@ -50,33 +51,11 @@ def test_cart_sampling_op_data_match():
torch.testing.assert_close(kdata[:, :, ::2, ::4, ::3], k_sub[:, :, ::2, ::4, ::3])


@pytest.mark.parametrize(
'sampling',
[
'random',
'partial_echo',
'partial_fourier',
'regular_undersampling',
'random_undersampling',
'different_random_undersampling',
'cartesian_and_non_cartesian',
'kx_ky_along_k0',
'kx_ky_along_k0_undersampling',
],
)
def test_cart_sampling_op_fwd_adj(sampling):
"""Test adjoint property of Cartesian sampling operator."""

# Create 3D uniform trajectory
k_shape = (2, 5, 20, 40, 60)
nkx = (2, 1, 1, 60)
nky = (2, 1, 40, 1)
nkz = (2, 20, 1, 1)
type_kx = 'uniform'
type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz).as_tensor()

def subsample_traj(
trajectory: KTrajectory, sampling: str, k_shape: tuple[int, int, int, Unpack[tuple[int, ...]]]
) -> KTrajectory:
"""Subsample trajectory based on sampling type."""
trajectory_tensor = trajectory.as_tensor()
# Subsample data and trajectory
match sampling:
case 'random':
Expand Down Expand Up @@ -108,6 +87,36 @@ def test_cart_sampling_op_fwd_adj(sampling):
trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]])
case _:
raise NotImplementedError(f'Test {sampling} not implemented.')
return trajectory


@pytest.mark.parametrize(
'sampling',
[
'random',
'partial_echo',
'partial_fourier',
'regular_undersampling',
'random_undersampling',
'different_random_undersampling',
'cartesian_and_non_cartesian',
'kx_ky_along_k0',
'kx_ky_along_k0_undersampling',
],
)
def test_cart_sampling_op_fwd_adj(sampling):
"""Test adjoint property of Cartesian sampling operator."""

# Create 3D uniform trajectory
k_shape = (2, 5, 20, 40, 60)
nkx = (2, 1, 1, 60)
nky = (2, 1, 40, 1)
nkz = (2, 20, 1, 1)
type_kx = 'uniform'
type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz)
trajectory = subsample_traj(trajectory, sampling, k_shape)

encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1])
sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory)
Expand All @@ -119,6 +128,42 @@ def test_cart_sampling_op_fwd_adj(sampling):
dotproduct_adjointness_test(sampling_op, u, v)


@pytest.mark.parametrize(
'sampling',
[
'random',
'partial_echo',
'partial_fourier',
'regular_undersampling',
'random_undersampling',
'different_random_undersampling',
'cartesian_and_non_cartesian',
'kx_ky_along_k0',
'kx_ky_along_k0_undersampling',
],
)
def test_cart_sampling_op_gram(sampling):
"""Test adjoint gram of Cartesian sampling operator."""

# Create 3D uniform trajectory
k_shape = (2, 5, 20, 40, 60)
nkx = (2, 1, 1, 60)
nky = (2, 1, 40, 1)
nkz = (2, 20, 1, 1)
type_kx = 'uniform'
type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz)
trajectory = subsample_traj(trajectory, sampling, k_shape)

encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1])
sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory)
u = RandomGenerator(seed=0).complex64_tensor(size=k_shape)
(expected,) = (sampling_op.H @ sampling_op)(u)
(actual,) = sampling_op.gram(u)
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize(('k2_min', 'k2_max'), [(-1, 21), (-21, 1)])
@pytest.mark.parametrize(('k0_min', 'k0_max'), [(-6, 13), (-13, 6)])
def test_cart_sampling_op_oversampling(k0_min, k0_max, k2_min, k2_max):
Expand Down
Loading