diff --git a/src/mrpro/operators/CartesianSamplingOp.py b/src/mrpro/operators/CartesianSamplingOp.py
index 07f8aba6..47c71c77 100644
--- a/src/mrpro/operators/CartesianSamplingOp.py
+++ b/src/mrpro/operators/CartesianSamplingOp.py
@@ -212,3 +212,69 @@ def _broadcast_and_scatter_along_last_dim(
         ).scatter_(dim=-1, index=idx_expanded, src=data_to_scatter)

         return data_scattered
+
+    @property
+    def gram(self) -> 'CartesianSamplingGramOp':
+        """Return the Gram operator for this Cartesian Sampling Operator.
+
+        Returns
+        -------
+        Gram operator for this Cartesian Sampling Operator
+        """
+        return CartesianSamplingGramOp(self)
+
+
+class CartesianSamplingGramOp(LinearOperator):
+    """Gram operator for the Cartesian Sampling Operator.
+
+    The Gram operator is the composition CartesianSamplingOp.H @ CartesianSamplingOp.
+    """
+
+    def __init__(self, sampling_op: CartesianSamplingOp):
+        """Initialize Cartesian Sampling Gram Operator class.
+
+        This should not be used directly, but rather through the `gram` method of a
+        :class:`mrpro.operators.CartesianSamplingOp` object.
+
+        Parameters
+        ----------
+        sampling_op
+            The Cartesian Sampling Operator for which to create the Gram operator.
+        """
+        super().__init__()
+        if sampling_op._needs_indexing:
+            ones = torch.ones(*sampling_op._trajectory_shape[:-3], 1, *sampling_op._sorted_grid_shape.zyx)
+            (mask,) = sampling_op.adjoint(*sampling_op.forward(ones))
+            self.register_buffer('_mask', mask)
+        else:
+            self._mask: torch.Tensor | None = None
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
+        """Apply the Gram operator.
+
+        Parameters
+        ----------
+        x
+            Input data
+
+        Returns
+        -------
+        Output data
+        """
+        if self._mask is None:
+            return (x,)
+        return (x * self._mask,)
+
+    def adjoint(self, y: torch.Tensor) -> tuple[torch.Tensor,]:
+        """Apply the adjoint of the Gram operator.
+
+        Parameters
+        ----------
+        y
+            Input data
+
+        Returns
+        -------
+        Output data
+        """
+        return self.forward(y)
diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py
index cacdda1d..a3e81aba 100644
--- a/src/mrpro/operators/FourierOp.py
+++ b/src/mrpro/operators/FourierOp.py
@@ -1,6 +1,7 @@
 """Fourier Operator."""

 from collections.abc import Sequence
+from itertools import product

 import numpy as np
 import torch
@@ -223,3 +224,163 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
             x = x.permute(*unpermute)

         return (x,)
+
+    @property
+    def gram(self) -> LinearOperator:
+        """Return the gram operator."""
+        return FourierGramOp(self)
+
+
+def symmetrize(kernel: torch.Tensor, rank: int) -> torch.Tensor:
+    """Enforce Hermitian symmetry on the kernel. Returns only half of the kernel."""
+    flipped = kernel.clone()
+    for d in range(-rank, 0):
+        flipped = flipped.index_select(d, -1 * torch.arange(flipped.shape[d], device=flipped.device) % flipped.size(d))
+    kernel = (kernel + flipped.conj()) / 2
+    last_len = kernel.shape[-1]
+    return kernel[..., : last_len // 2 + 1]
+
+
+def gram_nufft_kernel(weight: torch.Tensor, trajectory: torch.Tensor, recon_shape: Sequence[int]) -> torch.Tensor:
+    """Calculate the convolution kernel for the NUFFT gram operator.
+
+    Parameters
+    ----------
+    weight
+        either ones or density compensation weights
+    trajectory
+        k-space trajectory
+    recon_shape
+        shape of the reconstructed image
+
+    Returns
+    -------
+    kernel
+        real valued convolution kernel for the NUFFT gram operator, already in Fourier space
+    """
+    rank = trajectory.shape[-2]
+    if rank != len(recon_shape):
+        raise ValueError('Rank of trajectory and image size must match.')
+    # Instead of doing one adjoint nufft with double the recon size in all dimensions,
+    # we do two adjoint nuffts per dimension, saving a lot of memory.
+    adjnufft_ob = KbNufftAdjoint(im_size=recon_shape, n_shift=[0] * rank).to(trajectory)
+
+    kernel = adjnufft_ob(weight, trajectory)  # this will be the top left ... corner block
+    pad = []
+    for s in kernel.shape[: -rank - 1 : -1]:
+        pad.extend([0, s])
+    kernel = torch.nn.functional.pad(kernel, pad)  # twice the size in all dimensions
+
+    for flips in list(product([1, -1], repeat=rank)):
+        if all(flip == 1 for flip in flips):
+            # top left ... block already processed before padding
+            continue
+        flipped_trajectory = trajectory * torch.tensor(flips).to(trajectory).unsqueeze(-1)
+        kernel_part = adjnufft_ob(weight, flipped_trajectory)
+        slices = []  # which part of the kernel is currently being processed
+        for dim, flip in zip(range(-rank, 0), flips, strict=True):
+            if flip > 0:  # first half in the dimension
+                slices.append(slice(0, kernel_part.size(dim)))
+            else:  # second half in the dimension
+                slices.append(slice(kernel_part.size(dim) + 1, None))
+                kernel_part = kernel_part.index_select(dim, torch.arange(kernel_part.size(dim) - 1, 0, -1))  # flip
+
+        kernel[[..., *slices]] = kernel_part
+
+    kernel = symmetrize(kernel, rank)
+    kernel = torch.fft.hfftn(kernel, dim=list(range(-rank, 0)), norm='backward')
+    kernel /= kernel.shape[-rank:].numel()
+    kernel = torch.fft.fftshift(kernel, dim=list(range(-rank, 0)))
+    return kernel
+
+
+class FourierGramOp(LinearOperator):
+    """Gram operator for the Fourier operator.
+
+    Implements the composition of the adjoint with the forward operator of the Fourier operator,
+    i.e. the gram operator `F.H@F`.
+
+    Uses a convolution, implemented as multiplication in Fourier space, to calculate the gram operator
+    for the Toeplitz NUFFT operator.
+
+    Uses a multiplication with a binary mask in Fourier space to calculate the gram operator for
+    the Cartesian FFT operator.
+
+    This operator is only used internally and should not be used directly.
+    Instead, consider using the `gram` property of :class:`mrpro.operators.FourierOp`.
+    """
+
+    _kernel: torch.Tensor | None
+
+    def __init__(self, fourier_op: FourierOp) -> None:
+        """Initialize the gram operator.
+
+        If density compensation weights are provided, the operator
+        F.H@dcf@F is calculated.
+
+        Parameters
+        ----------
+        fourier_op
+            the Fourier operator to calculate the gram operator for
+
+        """
+        super().__init__()
+        if fourier_op._nufft_dims and fourier_op._omega is not None:
+            weight = torch.ones_like(fourier_op._omega[..., :1, :, :, :])
+            keep_dims = [-4, *fourier_op._nufft_dims]  # -4 is coil
+            permute = [i for i in range(-weight.ndim, 0) if i not in keep_dims] + keep_dims
+            unpermute = np.argsort(permute)
+            weight = weight.permute(*permute)
+            weight_unflattend_shape = weight.shape
+            weight = weight.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)
+            weight = weight + 0j
+            omega = fourier_op._omega.permute(*permute)
+            omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)
+            kernel = gram_nufft_kernel(weight, omega, fourier_op._nufft_im_size)
+            kernel = kernel.reshape(*weight_unflattend_shape[: -len(keep_dims)], *kernel.shape[-len(keep_dims) :])
+            kernel = kernel.permute(*unpermute)
+            fft = FastFourierOp(
+                dim=fourier_op._nufft_dims,
+                encoding_matrix=[2 * s for s in fourier_op._nufft_im_size],
+                recon_matrix=fourier_op._nufft_im_size,
+            )
+            self.nufft_gram: None | LinearOperator = fft.H * kernel @ fft
+        else:
+            self.nufft_gram = None
+
+        if fourier_op._fast_fourier_op is not None and fourier_op._cart_sampling_op is not None:
+            self.fast_fourier_gram: None | LinearOperator = (
+                fourier_op._fast_fourier_op.H @ fourier_op._cart_sampling_op.gram @ fourier_op._fast_fourier_op
+            )
+        else:
+            self.fast_fourier_gram = None
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
+        """Apply the operator to the input tensor.
+
+        Parameters
+        ----------
+        x
+            input tensor, shape (..., coils, z, y, x)
+        """
+        if self.nufft_gram is not None:
+            (x,) = self.nufft_gram(x)
+
+        if self.fast_fourier_gram is not None:
+            (x,) = self.fast_fourier_gram(x)
+        return (x,)
+
+    def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
+        """Apply the adjoint operator to the input tensor.
+
+        Parameters
+        ----------
+        x
+            input tensor, shape (..., coils, k2, k1, k0)
+        """
+        return self.forward(x)
+
+    @property
+    def H(self) -> Self:  # noqa: N802
+        """Adjoint operator of the gram operator."""
+        return self
diff --git a/src/mrpro/utils/zero_pad_or_crop.py b/src/mrpro/utils/zero_pad_or_crop.py
index 42adda43..23fb3959 100644
--- a/src/mrpro/utils/zero_pad_or_crop.py
+++ b/src/mrpro/utils/zero_pad_or_crop.py
@@ -35,7 +35,7 @@ def zero_pad_or_crop(
     new_shape: Sequence[int] | torch.Size,
     dim: None | Sequence[int] = None,
 ) -> torch.Tensor:
-    """Change shape of data by cropping or zero-padding.
+    """Change shape of data by center cropping or symmetric zero-padding.

     Parameters
     ----------
diff --git a/tests/operators/test_cartesian_sampling_op.py b/tests/operators/test_cartesian_sampling_op.py
index 49959f91..0fd32021 100644
--- a/tests/operators/test_cartesian_sampling_op.py
+++ b/tests/operators/test_cartesian_sampling_op.py
@@ -5,6 +5,7 @@
 from einops import rearrange
 from mrpro.data import KTrajectory, SpatialDimension
 from mrpro.operators import CartesianSamplingOp
+from typing_extensions import Unpack

 from tests import RandomGenerator, dotproduct_adjointness_test
 from tests.conftest import create_traj
@@ -50,33 +51,11 @@ def test_cart_sampling_op_data_match():
     torch.testing.assert_close(kdata[:, :, ::2, ::4, ::3], k_sub[:, :, ::2, ::4, ::3])


-@pytest.mark.parametrize(
-    'sampling',
-    [
-        'random',
-        'partial_echo',
-        'partial_fourier',
-        'regular_undersampling',
-        'random_undersampling',
-        'different_random_undersampling',
-        'cartesian_and_non_cartesian',
-        'kx_ky_along_k0',
-        'kx_ky_along_k0_undersampling',
-    ],
-)
-def test_cart_sampling_op_fwd_adj(sampling):
-    """Test adjoint property of Cartesian sampling operator."""
-
-    # Create 3D uniform trajectory
-    k_shape = (2, 5, 20, 40, 60)
-    nkx = (2, 1, 1, 60)
-    nky = (2, 1, 40, 1)
-    nkz = (2, 20, 1, 1)
-    type_kx = 'uniform'
-    type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
-    type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
-    trajectory_tensor = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz).as_tensor()
-
+def subsample_traj(
+    trajectory: KTrajectory, sampling: str, k_shape: tuple[int, int, int, Unpack[tuple[int, ...]]]
+) -> KTrajectory:
+    """Subsample trajectory based on sampling type."""
+    trajectory_tensor = trajectory.as_tensor()
     # Subsample data and trajectory
     match sampling:
         case 'random':
@@ -108,6 +87,36 @@ def test_cart_sampling_op_fwd_adj(sampling):
             trajectory = KTrajectory.from_tensor(trajectory_tensor[..., random_idx[: trajectory_tensor.shape[-1] // 2]])
         case _:
             raise NotImplementedError(f'Test {sampling} not implemented.')
+    return trajectory
+
+
+@pytest.mark.parametrize(
+    'sampling',
+    [
+        'random',
+        'partial_echo',
+        'partial_fourier',
+        'regular_undersampling',
+        'random_undersampling',
+        'different_random_undersampling',
+        'cartesian_and_non_cartesian',
+        'kx_ky_along_k0',
+        'kx_ky_along_k0_undersampling',
+    ],
+)
+def test_cart_sampling_op_fwd_adj(sampling):
+    """Test adjoint property of Cartesian sampling operator."""
+
+    # Create 3D uniform trajectory
+    k_shape = (2, 5, 20, 40, 60)
+    nkx = (2, 1, 1, 60)
+    nky = (2, 1, 40, 1)
+    nkz = (2, 20, 1, 1)
+    type_kx = 'uniform'
+    type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
+    type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
+    trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz)
+    trajectory = subsample_traj(trajectory, sampling, k_shape)

     encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1])
     sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory)
@@ -119,6 +128,42 @@ def test_cart_sampling_op_fwd_adj(sampling):
     dotproduct_adjointness_test(sampling_op, u, v)


+@pytest.mark.parametrize(
+    'sampling',
+    [
+        'random',
+        'partial_echo',
+        'partial_fourier',
+        'regular_undersampling',
+        'random_undersampling',
+        'different_random_undersampling',
+        'cartesian_and_non_cartesian',
+        'kx_ky_along_k0',
+        'kx_ky_along_k0_undersampling',
+    ],
+)
+def test_cart_sampling_op_gram(sampling):
+    """Test the Gram operator of the Cartesian sampling operator."""
+
+    # Create 3D uniform trajectory
+    k_shape = (2, 5, 20, 40, 60)
+    nkx = (2, 1, 1, 60)
+    nky = (2, 1, 40, 1)
+    nkz = (2, 20, 1, 1)
+    type_kx = 'uniform'
+    type_ky = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
+    type_kz = 'non-uniform' if sampling == 'cartesian_and_non_cartesian' else 'uniform'
+    trajectory = create_traj(k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz)
+    trajectory = subsample_traj(trajectory, sampling, k_shape)
+
+    encoding_matrix = SpatialDimension(k_shape[-3], k_shape[-2], k_shape[-1])
+    sampling_op = CartesianSamplingOp(encoding_matrix=encoding_matrix, traj=trajectory)
+    u = RandomGenerator(seed=0).complex64_tensor(size=k_shape)
+    (expected,) = (sampling_op.H @ sampling_op)(u)
+    (actual,) = sampling_op.gram(u)
+    torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)
+
+
 @pytest.mark.parametrize(('k2_min', 'k2_max'), [(-1, 21), (-21, 1)])
 @pytest.mark.parametrize(('k0_min', 'k0_max'), [(-6, 13), (-13, 6)])
 def test_cart_sampling_op_oversampling(k0_min, k0_max, k2_min, k2_max):
diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py
index 826ea24f..c7c58c26 100644
--- a/tests/operators/test_fourier_op.py
+++ b/tests/operators/test_fourier_op.py
@@ -48,6 +48,25 @@ def test_fourier_op_fwd_adj_property(
     dotproduct_adjointness_test(fourier_op, u, v)


+@COMMON_MR_TRAJECTORIES
+def test_fourier_op_gram(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2):
+    """Test gram of Fourier operator."""
+    img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz)
+
+    recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1])
+    encoding_matrix = SpatialDimension(
+        int(trajectory.kz.max() - trajectory.kz.min() + 1),
+        int(trajectory.ky.max() - trajectory.ky.min() + 1),
+        int(trajectory.kx.max() - trajectory.kx.min() + 1),
+    )
+    fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory)
+
+    (expected,) = (fourier_op.H @ fourier_op)(img)
+    (actual,) = fourier_op.gram(img)
+
+    torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)
+
+
 @pytest.mark.parametrize(
     ('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'type_kx', 'type_ky', 'type_kz'),  # parameter names
     [
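Note (reviewer sketch, not part of the patch): below is a minimal, self-contained illustration of the idea behind CartesianSamplingGramOp above. For a sampling operator S that picks a subset of Cartesian k-space points, the composition S.H @ S reduces to a multiplication with a binary mask, which is exactly the `_mask` buffer computed in `__init__` by pushing a ones tensor through forward and adjoint. The names `forward`, `adjoint`, and `idx` are illustrative only and not part of the mrpro API.

import torch

n, n_sampled = 32, 20
idx = torch.randperm(n)[:n_sampled]  # sampled k-space positions (no duplicates)


def forward(k: torch.Tensor) -> torch.Tensor:
    # S: pick the sampled k-space points
    return k[..., idx]


def adjoint(y: torch.Tensor) -> torch.Tensor:
    # S.H: scatter the samples back onto the full grid, zeros elsewhere
    k = torch.zeros(*y.shape[:-1], n, dtype=y.dtype)
    k[..., idx] = y
    return k


# Same trick as in CartesianSamplingGramOp.__init__: the mask is S.H(S(1)).
mask = adjoint(forward(torch.ones(n)))

k = torch.randn(n, dtype=torch.complex64)
torch.testing.assert_close(adjoint(forward(k)), mask * k)

The analogous statement for the non-Cartesian part is that F.H @ F acts as a convolution with a Toeplitz kernel; gram_nufft_kernel precomputes that kernel in Fourier space so the gram operator can be applied with two FFTs on a doubled grid instead of a forward/adjoint NUFFT pair.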