diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7803441eb..95d14317a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-added-large-files - id: check-docstring-first @@ -15,14 +15,14 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.2 hooks: - id: ruff # linter args: [--fix] - id: ruff-format # formatter - repo: https://github.com/crate-ci/typos - rev: v1.25.0 + rev: v1.27.0 hooks: - id: typos @@ -34,7 +34,7 @@ repos: exclude: ^tests/ - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy pass_filenames: false diff --git a/examples/direct_reconstruction.ipynb b/examples/direct_reconstruction.ipynb index 3b6dc930e..1e4e74c9c 100644 --- a/examples/direct_reconstruction.ipynb +++ b/examples/direct_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { diff --git a/examples/direct_reconstruction.py b/examples/direct_reconstruction.py index 7672aa7e7..5d55812c9 100644 --- a/examples/direct_reconstruction.py +++ b/examples/direct_reconstruction.py @@ -11,10 +11,10 @@ import requests -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/iterative_sense_reconstruction.ipynb b/examples/iterative_sense_reconstruction.ipynb index 87249b2fb..f612d7522 100644 --- a/examples/iterative_sense_reconstruction.ipynb +++ b/examples/iterative_sense_reconstruction.ipynb @@ -37,10 +37,10 @@ "\n", "import requests\n", "\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { diff --git a/examples/iterative_sense_reconstruction.py b/examples/iterative_sense_reconstruction.py index 6d0bc49a5..ba5e6a01a 100644 --- a/examples/iterative_sense_reconstruction.py +++ b/examples/iterative_sense_reconstruction.py @@ -11,10 +11,10 @@ import requests -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction diff --git a/examples/pulseq_2d_radial_golden_angle.ipynb b/examples/pulseq_2d_radial_golden_angle.ipynb index 52e0310bb..bcb4482a1 100644 --- a/examples/pulseq_2d_radial_golden_angle.ipynb +++ b/examples/pulseq_2d_radial_golden_angle.ipynb @@ -33,13 +33,14 @@ "cell_type": "code", "execution_count": null, "id": "d16f41f1", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [], "source": [ "# define zenodo records URL and create a temporary directory and h5-file\n", "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", - "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'\n", - "data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5')" + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" ] }, { @@ -50,9 +51,10 @@ "outputs": [], "source": [ "# Download raw data using requests\n", - "response = requests.get(zenodo_url + fname, timeout=30)\n", - "data_file.write(response.content)\n", - "data_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" ] }, { @@ -125,10 +127,10 @@ "# download the sequence file from zenodo\n", "zenodo_url = 'https://zenodo.org/records/10868061/files/'\n", "seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq'\n", - "seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq')\n", - "response = requests.get(zenodo_url + seq_fname, timeout=30)\n", - "seq_file.write(response.content)\n", - "seq_file.flush()" + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file:\n", + " response = requests.get(zenodo_url + seq_fname, timeout=30)\n", + " seq_file.write(response.content)\n", + " seq_file.flush()" ] }, { diff --git a/examples/pulseq_2d_radial_golden_angle.py b/examples/pulseq_2d_radial_golden_angle.py index 3f857c382..f4db5217a 100644 --- a/examples/pulseq_2d_radial_golden_angle.py +++ b/examples/pulseq_2d_radial_golden_angle.py @@ -19,13 +19,14 @@ # define zenodo records URL and create a temporary directory and h5-file zenodo_url = 'https://zenodo.org/records/10854057/files/' fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' -data_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') + # %% # Download raw data using requests -response = requests.get(zenodo_url + fname, timeout=30) -data_file.write(response.content) -data_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() # %% [markdown] # ### Image reconstruction using KTrajectoryIsmrmrd @@ -62,10 +63,10 @@ # download the sequence file from zenodo zenodo_url = 'https://zenodo.org/records/10868061/files/' seq_fname = 'pulseq_radial_2D_402spokes_golden_angle.seq' -seq_file = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') -response = requests.get(zenodo_url + seq_fname, timeout=30) -seq_file.write(response.content) -seq_file.flush() +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.seq') as seq_file: + response = requests.get(zenodo_url + seq_fname, timeout=30) + seq_file.write(response.content) + seq_file.flush() # %% # Read raw data and calculate trajectory using KTrajectoryPulseq diff --git a/examples/regularized_iterative_sense_reconstruction.ipynb b/examples/regularized_iterative_sense_reconstruction.ipynb new file mode 100644 index 000000000..0a6743161 --- /dev/null +++ b/examples/regularized_iterative_sense_reconstruction.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "af432293", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# Regularized Iterative SENSE Reconstruction of 2D golden angle radial data\n", + "Here we use the RegularizedIterativeSENSEReconstruction class to reconstruct images from ISMRMRD 2D radial data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a7a6ce3", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# define zenodo URL of the example ismrmd data\n", + "zenodo_url = 'https://zenodo.org/records/10854057/files/'\n", + "fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cd8486b", + "metadata": {}, + "outputs": [], + "source": [ + "# Download raw data\n", + "import tempfile\n", + "\n", + "import requests\n", + "\n", + "with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file:\n", + " response = requests.get(zenodo_url + fname, timeout=30)\n", + " data_file.write(response.content)\n", + " data_file.flush()" + ] + }, + { + "cell_type": "markdown", + "id": "6a9defa1", + "metadata": {}, + "source": [ + "### Image reconstruction\n", + "We use the RegularizedIterativeSENSEReconstruction class to reconstruct images from 2D radial data.\n", + "RegularizedIterativeSENSEReconstruction solves the following reconstruction problem:\n", + "\n", + "Let's assume we have obtained the k-space data $y$ from an image $x$ with an acquisition model (Fourier transforms,\n", + "coil sensitivity maps...) $A$ then we can formulate the forward problem as:\n", + "\n", + "$ y = Ax + n $\n", + "\n", + "where $n$ describes complex Gaussian noise. The image $x$ can be obtained by minimizing the functionl $F$\n", + "\n", + "$ F(x) = ||W^{\\frac{1}{2}}(Ax - y)||_2^2 $\n", + "\n", + "where $W^\\frac{1}{2}$ is the square root of the density compensation function (which corresponds to a diagonal\n", + "operator). Because this is an ill-posed problem, we can add a regularization term to stabilize the problem and obtain\n", + "a solution with certain properties:\n", + "\n", + "$ F(x) = ||W^{\\frac{1}{2}}(Ax - y)||_2^2 + l||Bx - x_{reg}||_2^2$\n", + "\n", + "where $l$ is the strength of the regularization, $B$ is a linear operator and $x_{reg}$ is a regularization image.\n", + "With this functional $F$ we obtain a solution which is close to $x_{reg}$ and to the acquired data $y$.\n", + "\n", + "Setting the derivative of the functional $F$ to zero and rearranging yields\n", + "\n", + "$ (A^H W A + l B) x = A^H W y + l x_{reg}$\n", + "\n", + "which is a linear system $Hx = b$ that needs to be solved for $x$.\n", + "\n", + "One important question of course is, what to use for $x_{reg}$. For dynamic images (e.g. cine MRI) low-resolution\n", + "dynamic images or high-quality static images have been proposed. In recent years, also the output of neural-networks\n", + "has been used as an image regulariser.\n", + "\n", + "In this example we are going to use a high-quality image to regularize the reconstruction of an undersampled image.\n", + "Both images are obtained from the same data acquisition (one using all the acquired data ($x_{reg}$) and one using\n", + "only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the\n", + "regularization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4da15c2", + "metadata": {}, + "outputs": [], + "source": [ + "import mrpro" + ] + }, + { + "cell_type": "markdown", + "id": "de055070", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "##### Read-in the raw data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ac1d89f", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.data import KData\n", + "from mrpro.data.traj_calculators import KTrajectoryIsmrmrd\n", + "\n", + "# Load in the Data and the trajectory from the ISMRMRD file\n", + "kdata = KData.from_file(data_file.name, KTrajectoryIsmrmrd())\n", + "kdata.header.recon_matrix.x = 256\n", + "kdata.header.recon_matrix.y = 256" + ] + }, + { + "cell_type": "markdown", + "id": "1f389140", + "metadata": {}, + "source": [ + "##### Image $x_{reg}$ from fully sampled data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "212b915c", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.algorithms.reconstruction import DirectReconstruction, IterativeSENSEReconstruction\n", + "from mrpro.data import CsmData\n", + "\n", + "# Estimate coil maps\n", + "direct_reconstruction = DirectReconstruction(kdata, csm=None)\n", + "img_coilwise = direct_reconstruction(kdata)\n", + "csm = CsmData.from_idata_walsh(img_coilwise)\n", + "\n", + "# Iterative SENSE reconstruction\n", + "iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata, csm=csm, n_iterations=3)\n", + "img_iterative_sense = iterative_sense_reconstruction(kdata)" + ] + }, + { + "cell_type": "markdown", + "id": "bec6b712", + "metadata": {}, + "source": [ + "##### Image $x$ from undersampled data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6740447", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# Data undersampling, i.e. take only the first 20 radial lines\n", + "idx_us = torch.arange(0, 20)[None, :]\n", + "kdata_us = kdata.split_k1_into_other(idx_us, other_label='repetition')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fbfd664", + "metadata": {}, + "outputs": [], + "source": [ + "# Iterativ SENSE reconstruction\n", + "iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata_us, csm=csm, n_iterations=6)\n", + "img_us_iterative_sense = iterative_sense_reconstruction(kdata_us)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "041ffe72", + "metadata": {}, + "outputs": [], + "source": [ + "# Regularized iterativ SENSE reconstruction\n", + "from mrpro.algorithms.reconstruction import RegularizedIterativeSENSEReconstruction\n", + "\n", + "regularization_weight = 1.0\n", + "n_iterations = 6\n", + "regularized_iterative_sense_reconstruction = RegularizedIterativeSENSEReconstruction(\n", + " kdata_us,\n", + " csm=csm,\n", + " n_iterations=n_iterations,\n", + " regularization_data=img_iterative_sense.data,\n", + " regularization_weight=regularization_weight,\n", + ")\n", + "img_us_regularized_iterative_sense = regularized_iterative_sense_reconstruction(kdata_us)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d5bbec1", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "vis_im = [img_iterative_sense.rss(), img_us_iterative_sense.rss(), img_us_regularized_iterative_sense.rss()]\n", + "vis_title = ['Fully sampled', 'Iterative SENSE R=20', 'Regularized Iterative SENSE R=20']\n", + "fig, ax = plt.subplots(1, 3, squeeze=False, figsize=(12, 4))\n", + "for ind in range(3):\n", + " ax[0, ind].imshow(vis_im[ind][0, 0, ...])\n", + " ax[0, ind].set_title(vis_title[ind])" + ] + }, + { + "cell_type": "markdown", + "id": "2bd49c87", + "metadata": {}, + "source": [ + "### Behind the scenes" + ] + }, + { + "cell_type": "markdown", + "id": "53779251", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "##### Set-up the density compensation operator $W$ and acquisition model $A$\n", + "\n", + "This is very similar to the iterative SENSE reconstruction. For more detail please look at the\n", + "iterative_sense_reconstruction notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e985a4f3", + "metadata": {}, + "outputs": [], + "source": [ + "dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_us.traj).as_operator()\n", + "fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_us)\n", + "csm_operator = csm.as_operator()\n", + "acquisition_operator = fourier_operator @ csm_operator" + ] + }, + { + "cell_type": "markdown", + "id": "2daa0fee", + "metadata": {}, + "source": [ + "##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac1d5fb4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "right_hand_side = (\n", + " acquisition_operator.H(dcf_operator(kdata_us.data)[0])[0] + regularization_weight * img_iterative_sense.data\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "76a0b153", + "metadata": {}, + "source": [ + "##### Set-up the linear self-adjoint operator $H = A^H W A + l$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5effb592", + "metadata": {}, + "outputs": [], + "source": [ + "from mrpro.operators import IdentityOp\n", + "\n", + "operator = acquisition_operator.H @ dcf_operator @ acquisition_operator + IdentityOp() * torch.as_tensor(\n", + " regularization_weight\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f24a8588", + "metadata": {}, + "source": [ + "##### Run conjugate gradient" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96827838", + "metadata": {}, + "outputs": [], + "source": [ + "img_manual = mrpro.algorithms.optimizers.cg(\n", + " operator, right_hand_side, initial_value=right_hand_side, max_iterations=n_iterations, tolerance=0.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18c065a7", + "metadata": {}, + "outputs": [], + "source": [ + "# Display the reconstructed image\n", + "vis_im = [img_us_regularized_iterative_sense.rss(), img_manual.abs()[:, 0, ...]]\n", + "vis_title = ['Regularized Iterative SENSE R=20', '\"Manual\" Regularized Iterative SENSE R=20']\n", + "fig, ax = plt.subplots(1, 2, squeeze=False, figsize=(8, 4))\n", + "for ind in range(2):\n", + " ax[0, ind].imshow(vis_im[ind][0, 0, ...])\n", + " ax[0, ind].set_title(vis_title[ind])" + ] + }, + { + "cell_type": "markdown", + "id": "d6d7efdf", + "metadata": {}, + "source": [ + "### Check for equal results\n", + "The two versions should result in the same image data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f59b6015", + "metadata": {}, + "outputs": [], + "source": [ + "# If the assert statement did not raise an exception, the results are equal.\n", + "assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual)" + ] + }, + { + "cell_type": "markdown", + "id": "6ecd6e70", + "metadata": {}, + "source": [ + "### Next steps\n", + "Play around with the regularization_weight to see how it effects the final image quality.\n", + "\n", + "Of course we are cheating here because we used the fully sampled image as a regularization. In real world applications\n", + "we would not have that. One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the\n", + "streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality\n", + "compared to the unregularised images." + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/regularized_iterative_sense_reconstruction.py b/examples/regularized_iterative_sense_reconstruction.py new file mode 100644 index 000000000..2ab7ba033 --- /dev/null +++ b/examples/regularized_iterative_sense_reconstruction.py @@ -0,0 +1,193 @@ +# %% [markdown] +# # Regularized Iterative SENSE Reconstruction of 2D golden angle radial data +# Here we use the RegularizedIterativeSENSEReconstruction class to reconstruct images from ISMRMRD 2D radial data +# %% +# define zenodo URL of the example ismrmd data +zenodo_url = 'https://zenodo.org/records/10854057/files/' +fname = 'pulseq_radial_2D_402spokes_golden_angle_with_traj.h5' +# %% +# Download raw data +import tempfile + +import requests + +with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.h5') as data_file: + response = requests.get(zenodo_url + fname, timeout=30) + data_file.write(response.content) + data_file.flush() + +# %% [markdown] +# ### Image reconstruction +# We use the RegularizedIterativeSENSEReconstruction class to reconstruct images from 2D radial data. +# RegularizedIterativeSENSEReconstruction solves the following reconstruction problem: +# +# Let's assume we have obtained the k-space data $y$ from an image $x$ with an acquisition model (Fourier transforms, +# coil sensitivity maps...) $A$ then we can formulate the forward problem as: +# +# $ y = Ax + n $ +# +# where $n$ describes complex Gaussian noise. The image $x$ can be obtained by minimizing the functionl $F$ +# +# $ F(x) = ||W^{\frac{1}{2}}(Ax - y)||_2^2 $ +# +# where $W^\frac{1}{2}$ is the square root of the density compensation function (which corresponds to a diagonal +# operator). Because this is an ill-posed problem, we can add a regularization term to stabilize the problem and obtain +# a solution with certain properties: +# +# $ F(x) = ||W^{\frac{1}{2}}(Ax - y)||_2^2 + l||Bx - x_{reg}||_2^2$ +# +# where $l$ is the strength of the regularization, $B$ is a linear operator and $x_{reg}$ is a regularization image. +# With this functional $F$ we obtain a solution which is close to $x_{reg}$ and to the acquired data $y$. +# +# Setting the derivative of the functional $F$ to zero and rearranging yields +# +# $ (A^H W A + l B) x = A^H W y + l x_{reg}$ +# +# which is a linear system $Hx = b$ that needs to be solved for $x$. +# +# One important question of course is, what to use for $x_{reg}$. For dynamic images (e.g. cine MRI) low-resolution +# dynamic images or high-quality static images have been proposed. In recent years, also the output of neural-networks +# has been used as an image regulariser. +# +# In this example we are going to use a high-quality image to regularize the reconstruction of an undersampled image. +# Both images are obtained from the same data acquisition (one using all the acquired data ($x_{reg}$) and one using +# only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the +# regularization. + +# %% +import mrpro + +# %% [markdown] +# ##### Read-in the raw data +# %% +from mrpro.data import KData +from mrpro.data.traj_calculators import KTrajectoryIsmrmrd + +# Load in the Data and the trajectory from the ISMRMRD file +kdata = KData.from_file(data_file.name, KTrajectoryIsmrmrd()) +kdata.header.recon_matrix.x = 256 +kdata.header.recon_matrix.y = 256 + +# %% [markdown] +# ##### Image $x_{reg}$ from fully sampled data + +# %% +from mrpro.algorithms.reconstruction import DirectReconstruction, IterativeSENSEReconstruction +from mrpro.data import CsmData + +# Estimate coil maps +direct_reconstruction = DirectReconstruction(kdata, csm=None) +img_coilwise = direct_reconstruction(kdata) +csm = CsmData.from_idata_walsh(img_coilwise) + +# Iterative SENSE reconstruction +iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata, csm=csm, n_iterations=3) +img_iterative_sense = iterative_sense_reconstruction(kdata) + +# %% [markdown] +# ##### Image $x$ from undersampled data + +# %% +import torch + +# Data undersampling, i.e. take only the first 20 radial lines +idx_us = torch.arange(0, 20)[None, :] +kdata_us = kdata.split_k1_into_other(idx_us, other_label='repetition') + +# %% +# Iterativ SENSE reconstruction +iterative_sense_reconstruction = IterativeSENSEReconstruction(kdata_us, csm=csm, n_iterations=6) +img_us_iterative_sense = iterative_sense_reconstruction(kdata_us) + +# %% +# Regularized iterativ SENSE reconstruction +from mrpro.algorithms.reconstruction import RegularizedIterativeSENSEReconstruction + +regularization_weight = 1.0 +n_iterations = 6 +regularized_iterative_sense_reconstruction = RegularizedIterativeSENSEReconstruction( + kdata_us, + csm=csm, + n_iterations=n_iterations, + regularization_data=img_iterative_sense.data, + regularization_weight=regularization_weight, +) +img_us_regularized_iterative_sense = regularized_iterative_sense_reconstruction(kdata_us) + +# %% +import matplotlib.pyplot as plt + +vis_im = [img_iterative_sense.rss(), img_us_iterative_sense.rss(), img_us_regularized_iterative_sense.rss()] +vis_title = ['Fully sampled', 'Iterative SENSE R=20', 'Regularized Iterative SENSE R=20'] +fig, ax = plt.subplots(1, 3, squeeze=False, figsize=(12, 4)) +for ind in range(3): + ax[0, ind].imshow(vis_im[ind][0, 0, ...]) + ax[0, ind].set_title(vis_title[ind]) + + +# %% [markdown] +# ### Behind the scenes + +# %% [markdown] +# ##### Set-up the density compensation operator $W$ and acquisition model $A$ +# +# This is very similar to the iterative SENSE reconstruction. For more detail please look at the +# iterative_sense_reconstruction notebook. +# %% +dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_us.traj).as_operator() +fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_us) +csm_operator = csm.as_operator() +acquisition_operator = fourier_operator @ csm_operator + +# %% [markdown] +# ##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$ + +# %% +right_hand_side = ( + acquisition_operator.H(dcf_operator(kdata_us.data)[0])[0] + regularization_weight * img_iterative_sense.data +) + + +# %% [markdown] +# ##### Set-up the linear self-adjoint operator $H = A^H W A + l$ + +# %% +from mrpro.operators import IdentityOp + +operator = acquisition_operator.H @ dcf_operator @ acquisition_operator + IdentityOp() * torch.as_tensor( + regularization_weight +) + +# %% [markdown] +# ##### Run conjugate gradient + +# %% +img_manual = mrpro.algorithms.optimizers.cg( + operator, right_hand_side, initial_value=right_hand_side, max_iterations=n_iterations, tolerance=0.0 +) + +# %% +# Display the reconstructed image +vis_im = [img_us_regularized_iterative_sense.rss(), img_manual.abs()[:, 0, ...]] +vis_title = ['Regularized Iterative SENSE R=20', '"Manual" Regularized Iterative SENSE R=20'] +fig, ax = plt.subplots(1, 2, squeeze=False, figsize=(8, 4)) +for ind in range(2): + ax[0, ind].imshow(vis_im[ind][0, 0, ...]) + ax[0, ind].set_title(vis_title[ind]) + +# %% [markdown] +# ### Check for equal results +# The two versions should result in the same image data. + +# %% +# If the assert statement did not raise an exception, the results are equal. +assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual) + +# %% [markdown] +# ### Next steps +# Play around with the regularization_weight to see how it effects the final image quality. +# +# Of course we are cheating here because we used the fully sampled image as a regularization. In real world applications +# we would not have that. One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the +# streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality +# compared to the unregularised images. diff --git a/pyproject.toml b/pyproject.toml index 8d2780589..31798a35c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ testpaths = ["tests"] filterwarnings = [ "error", "ignore:'write_like_original':DeprecationWarning:pydicom:", + "ignore:Anomaly Detection has been enabled:UserWarning", #torch.autograd ] addopts = "-n auto" markers = ["cuda : Tests only to be run when cuda device is available"] diff --git a/src/mrpro/VERSION b/src/mrpro/VERSION index 0a67c464f..0f6ae6fb6 100644 --- a/src/mrpro/VERSION +++ b/src/mrpro/VERSION @@ -1 +1 @@ -0.241015 +0.241029 diff --git a/src/mrpro/algorithms/optimizers/cg.py b/src/mrpro/algorithms/optimizers/cg.py index 54cb1d782..b879796b0 100644 --- a/src/mrpro/algorithms/optimizers/cg.py +++ b/src/mrpro/algorithms/optimizers/cg.py @@ -81,11 +81,6 @@ def cg( if torch.vdot(residual.flatten(), residual.flatten()) == 0: return solution - # squared tolerance; - # (we will check ||residual||^2 < tolerance^2 instead of ||residual|| < tol - # to avoid the computation of the root for the norm) - tolerance_squared = tolerance**2 - # dummy value. new value will be set in loop before first usage residual_norm_squared_previous = None @@ -95,7 +90,7 @@ def cg( residual_norm_squared = torch.vdot(residual_flat, residual_flat).real # check if the solution is already accurate enough - if tolerance != 0 and (residual_norm_squared < tolerance_squared): + if tolerance != 0 and (residual_norm_squared < tolerance**2): return solution if iteration > 0: @@ -105,8 +100,8 @@ def cg( # update estimates of the solution and the residual (operator_conjugate_vector,) = operator(conjugate_vector) alpha = residual_norm_squared / (torch.vdot(conjugate_vector.flatten(), operator_conjugate_vector.flatten())) - solution += alpha * conjugate_vector - residual -= alpha * operator_conjugate_vector + solution = solution + alpha * conjugate_vector + residual = residual - alpha * operator_conjugate_vector residual_norm_squared_previous = residual_norm_squared diff --git a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py index 04d37128c..32785a91a 100644 --- a/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py +++ b/src/mrpro/algorithms/reconstruction/IterativeSENSEReconstruction.py @@ -4,20 +4,17 @@ from collections.abc import Callable -import torch - -from mrpro.algorithms.optimizers.cg import cg -from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace -from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import ( + RegularizedIterativeSENSEReconstruction, +) from mrpro.data._kdata.KData import KData from mrpro.data.CsmData import CsmData from mrpro.data.DcfData import DcfData -from mrpro.data.IData import IData from mrpro.data.KNoise import KNoise from mrpro.operators.LinearOperator import LinearOperator -class IterativeSENSEReconstruction(DirectReconstruction): +class IterativeSENSEReconstruction(RegularizedIterativeSENSEReconstruction): r"""Iterative SENSE reconstruction. This algorithm solves the problem :math:`min_x \frac{1}{2}||W^\frac{1}{2} (Ax - y)||_2^2` @@ -49,6 +46,8 @@ def __init__( ) -> None: """Initialize IterativeSENSEReconstruction. + For a regularized version of the iterative SENSE algorithm please see RegularizedIterativeSENSEReconstruction. + Parameters ---------- kdata @@ -74,76 +73,4 @@ def __init__( ValueError If the kdata and fourier_op are None or if csm is a Callable but kdata is None. """ - super().__init__(kdata, fourier_op, csm, noise, dcf) - self.n_iterations = n_iterations - - def _self_adjoint_operator(self) -> LinearOperator: - """Create the self-adjoint operator. - - Create the acquisition model as :math:`A = F S` if the CSM :math:`S` is defined otherwise :math:`A = F` with - the Fourier operator :math:`F`. - - Create the self-adjoint operator as :math:`H = A^H W A` if the DCF is not None otherwise as :math:`H = A^H A`. - """ - operator = self.fourier_op @ self.csm.as_operator() if self.csm is not None else self.fourier_op - - if self.dcf is not None: - dcf_operator = self.dcf.as_operator() - # Create H = A^H W A - operator = operator.H @ dcf_operator @ operator - else: - # Create H = A^H A - operator = operator.H @ operator - - return operator - - def _right_hand_side(self, kdata: KData) -> torch.Tensor: - """Calculate the right-hand-side of the normal equation. - - Create the acquisition model as :math:`A = F S` if the CSM :math:`S` is defined otherwise :math:`A = F` with - the Fourier operator :math:`F`. - - Calculate the right-hand-side as :math:`b = A^H W y` if the DCF is not None otherwise as :math:`b = A^H y`. - - Parameters - ---------- - kdata - k-space data to reconstruct. - """ - device = kdata.data.device - operator = self.fourier_op @ self.csm.as_operator() if self.csm is not None else self.fourier_op - - if self.dcf is not None: - dcf_operator = self.dcf.as_operator() - # Calculate b = A^H W y - (right_hand_side,) = operator.to(device).H(dcf_operator(kdata.data)[0]) - else: - # Calculate b = A^H y - (right_hand_side,) = operator.to(device).H(kdata.data) - - return right_hand_side - - def forward(self, kdata: KData) -> IData: - """Apply the reconstruction. - - Parameters - ---------- - kdata - k-space data to reconstruct. - - Returns - ------- - the reconstruced image. - """ - device = kdata.data.device - if self.noise is not None: - kdata = prewhiten_kspace(kdata, self.noise.to(device)) - - operator = self._self_adjoint_operator().to(device) - right_hand_side = self._right_hand_side(kdata) - - img_tensor = cg( - operator, right_hand_side, initial_value=right_hand_side, max_iterations=self.n_iterations, tolerance=0.0 - ) - img = IData.from_tensor_and_kheader(img_tensor, kdata.header) - return img + super().__init__(kdata, fourier_op, csm, noise, dcf, n_iterations=n_iterations, regularization_weight=0) diff --git a/src/mrpro/algorithms/reconstruction/Reconstruction.py b/src/mrpro/algorithms/reconstruction/Reconstruction.py index 127be8d5f..c4208157e 100644 --- a/src/mrpro/algorithms/reconstruction/Reconstruction.py +++ b/src/mrpro/algorithms/reconstruction/Reconstruction.py @@ -101,15 +101,13 @@ def direct_reconstruction(self, kdata: KData) -> IData: ------- image data """ - device = kdata.data.device if self.noise is not None: - kdata = prewhiten_kspace(kdata, self.noise.to(device)) + kdata = prewhiten_kspace(kdata, self.noise) operator = self.fourier_op if self.csm is not None: operator = operator @ self.csm.as_operator() if self.dcf is not None: operator = self.dcf.as_operator() @ operator - operator = operator.to(device) (img_tensor,) = operator.H(kdata.data) img = IData.from_tensor_and_kheader(img_tensor, kdata.header) return img diff --git a/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py new file mode 100644 index 000000000..c9a307ebe --- /dev/null +++ b/src/mrpro/algorithms/reconstruction/RegularizedIterativeSENSEReconstruction.py @@ -0,0 +1,139 @@ +"""Regularized Iterative SENSE Reconstruction by adjoint Fourier transform.""" + +from __future__ import annotations + +from collections.abc import Callable + +import torch + +from mrpro.algorithms.optimizers.cg import cg +from mrpro.algorithms.prewhiten_kspace import prewhiten_kspace +from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.data._kdata.KData import KData +from mrpro.data.CsmData import CsmData +from mrpro.data.DcfData import DcfData +from mrpro.data.IData import IData +from mrpro.data.KNoise import KNoise +from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.LinearOperator import LinearOperator + + +class RegularizedIterativeSENSEReconstruction(DirectReconstruction): + r"""Regularized iterative SENSE reconstruction. + + This algorithm solves the problem :math:`min_x \frac{1}{2}||W^\frac{1}{2} (Ax - y)||_2^2 + + \frac{1}{2}L||Bx - x_0||_2^2` + by using a conjugate gradient algorithm to solve + :math:`H x = b` with :math:`H = A^H W A + L B` and :math:`b = A^H W y + L x_0` where :math:`A` + is the acquisition model (coil sensitivity maps, Fourier operator, k-space sampling), :math:`y` is the acquired + k-space data, :math:`W` describes the density compensation, :math:`L` is the strength of the regularization and + :math:`x_0` is the regularization image (i.e. the prior). :math:`B` is a linear operator applied to :math:`x`. + """ + + n_iterations: int + """Number of CG iterations.""" + + regularization_data: torch.Tensor + """Regularization data (i.e. prior) :math:`x_0`.""" + + regularization_weight: torch.Tensor + """Strength of the regularization :math:`L`.""" + + regularization_op: LinearOperator + """Linear operator :math:`B` applied to the current estimate in the regularization term.""" + + def __init__( + self, + kdata: KData | None = None, + fourier_op: LinearOperator | None = None, + csm: Callable | CsmData | None = CsmData.from_idata_walsh, + noise: KNoise | None = None, + dcf: DcfData | None = None, + *, + n_iterations: int = 5, + regularization_data: float | torch.Tensor = 0.0, + regularization_weight: float | torch.Tensor, + regularization_op: LinearOperator | None = None, + ) -> None: + """Initialize RegularizedIterativeSENSEReconstruction. + + For a unregularized version of the iterative SENSE algorithm the regularization_weight can be set to 0 or + IterativeSENSEReconstruction algorithm can be used. + + Parameters + ---------- + kdata + KData. If kdata is provided and fourier_op or dcf are None, then fourier_op and dcf are estimated based on + kdata. Otherwise fourier_op and dcf are used as provided. + fourier_op + Instance of the FourierOperator used for reconstruction. If None, set up based on kdata. + csm + Sensitivity maps for coil combination. If None, no coil combination is carried out, i.e. images for each + coil are returned. If a callable is provided, coil images are reconstructed using the adjoint of the + FourierOperator (including density compensation) and then sensitivity maps are calculated using the + callable. For this, kdata needs also to be provided. For examples have a look at the CsmData class + e.g. from_idata_walsh or from_idata_inati. + noise + KNoise used for prewhitening. If None, no prewhitening is performed + dcf + K-space sampling density compensation. If None, set up based on kdata. + n_iterations + Number of CG iterations + regularization_data + Regularization data, e.g. a reference image (:math:`x_0`). + regularization_weight + Strength of the regularization (:math:`L`). + regularization_op + Linear operator :math:`B` applied to the current estimate in the regularization term. If None, nothing is + applied to the current estimate. + + + Raises + ------ + ValueError + If the kdata and fourier_op are None or if csm is a Callable but kdata is None. + """ + super().__init__(kdata, fourier_op, csm, noise, dcf) + self.n_iterations = n_iterations + self.regularization_data = torch.as_tensor(regularization_data) + self.regularization_weight = torch.as_tensor(regularization_weight) + self.regularization_op = regularization_op if regularization_op is not None else IdentityOp() + + def forward(self, kdata: KData) -> IData: + """Apply the reconstruction. + + Parameters + ---------- + kdata + k-space data to reconstruct. + + Returns + ------- + the reconstruced image. + """ + if self.noise is not None: + kdata = prewhiten_kspace(kdata, self.noise) + + # Create the normal operator as H = A^H W A if the DCF is not None otherwise as H = A^H A. + # The acquisition model is A = F S if the CSM S is defined otherwise A = F with the Fourier operator F + csm_op = self.csm.as_operator() if self.csm is not None else IdentityOp() + precondition_op = self.dcf.as_operator() if self.dcf is not None else IdentityOp() + operator = (self.fourier_op @ csm_op).H @ precondition_op @ (self.fourier_op @ csm_op) + + # Calculate the right-hand-side as b = A^H W y if the DCF is not None otherwise as b = A^H y. + (right_hand_side,) = (self.fourier_op @ csm_op).H(precondition_op(kdata.data)[0]) + + # Add regularization + if not torch.all(self.regularization_weight == 0): + operator = operator + IdentityOp() @ (self.regularization_weight * self.regularization_op) + right_hand_side += self.regularization_weight * self.regularization_data + + img_tensor = cg( + operator, + right_hand_side, + initial_value=right_hand_side, + max_iterations=self.n_iterations, + tolerance=0.0, + ) + img = IData.from_tensor_and_kheader(img_tensor, kdata.header) + return img diff --git a/src/mrpro/algorithms/reconstruction/__init__.py b/src/mrpro/algorithms/reconstruction/__init__.py index 180186dd4..38b539f8b 100644 --- a/src/mrpro/algorithms/reconstruction/__init__.py +++ b/src/mrpro/algorithms/reconstruction/__init__.py @@ -1,8 +1,10 @@ from mrpro.algorithms.reconstruction.Reconstruction import Reconstruction from mrpro.algorithms.reconstruction.DirectReconstruction import DirectReconstruction +from mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction import RegularizedIterativeSENSEReconstruction from mrpro.algorithms.reconstruction.IterativeSENSEReconstruction import IterativeSENSEReconstruction __all__ = [ "DirectReconstruction", "IterativeSENSEReconstruction", - "Reconstruction" -] \ No newline at end of file + "Reconstruction", + "RegularizedIterativeSENSEReconstruction" +] diff --git a/src/mrpro/operators/LinearOperator.py b/src/mrpro/operators/LinearOperator.py index f136d9d51..d919e63c6 100644 --- a/src/mrpro/operators/LinearOperator.py +++ b/src/mrpro/operators/LinearOperator.py @@ -102,7 +102,7 @@ def operator_norm( max_iterations: int = 20, relative_tolerance: float = 1e-4, absolute_tolerance: float = 1e-5, - callback: Callable | None = None, + callback: Callable[[torch.Tensor], None] | None = None, ) -> torch.Tensor: """Power iteration for computing the operator norm of the linear operator. @@ -294,6 +294,20 @@ def __rmul__(self, other: torch.Tensor | complex) -> LinearOperator: else: return NotImplemented # type: ignore[unreachable] + def __and__(self, other: LinearOperator) -> mrpro.operators.LinearOperatorMatrix: + """Vertical stacking of two LinearOperators.""" + if not isinstance(other, LinearOperator): + return NotImplemented # type: ignore[unreachable] + operators = [[self], [other]] + return mrpro.operators.LinearOperatorMatrix(operators) + + def __or__(self, other: LinearOperator) -> mrpro.operators.LinearOperatorMatrix: + """Horizontal stacking of two LinearOperators.""" + if not isinstance(other, LinearOperator): + return NotImplemented # type: ignore[unreachable] + operators = [[self, other]] + return mrpro.operators.LinearOperatorMatrix(operators) + @property def gram(self) -> LinearOperator: """Gram operator. diff --git a/src/mrpro/operators/LinearOperatorMatrix.py b/src/mrpro/operators/LinearOperatorMatrix.py new file mode 100644 index 000000000..ab0673398 --- /dev/null +++ b/src/mrpro/operators/LinearOperatorMatrix.py @@ -0,0 +1,363 @@ +"""Linear Operator Matrix class.""" + +from __future__ import annotations + +import operator +from collections.abc import Callable, Iterator, Sequence +from functools import reduce +from types import EllipsisType +from typing import cast + +import torch +from typing_extensions import Self, Unpack + +from mrpro.operators.LinearOperator import LinearOperator, LinearOperatorSum +from mrpro.operators.Operator import Operator +from mrpro.operators.ZeroOp import ZeroOp + +_SingleIdxType = int | slice | EllipsisType | Sequence[int] +_IdxType = _SingleIdxType | tuple[_SingleIdxType, _SingleIdxType] + + +class LinearOperatorMatrix(Operator[Unpack[tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]]): + r"""Matrix of Linear Operators. + + A matrix of Linear Operators, where each element is a Linear Operator. + + This matrix can be applied to a sequence of tensors, where the number of tensors should match + the number of columns of the matrix. The output will be a sequence of tensors, where the number + of tensors will match the number of rows of the matrix. + The i-th output tensor is calculated as + :math:`\sum_j \text{operators}[i][j](x[j])` where :math:`\text{operators}[i][j]` is the linear operator + in the i-th row and j-th column and :math:`x[j]` is the j-th input tensor. + + The matrix can be indexed and sliced like a regular matrix to get submatrices. + If indexing returns a single element, it is returned as a Linear Operator. + + Basic arithmetic operations are supported with Linear Operators and Tensors. + + """ + + _operators: list[list[LinearOperator]] + + def __init__(self, operators: Sequence[Sequence[LinearOperator]]): + """Initialize Linear Operator Matrix. + + Create a matrix of LinearOperators from a sequence of rows, where each row is a sequence + of LinearOperators that represent the columns of the matrix. + + Parameters + ---------- + operators + A sequence of rows, which are sequences of Linear Operators. + """ + if not all(isinstance(op, LinearOperator) for row in operators for op in row): + raise ValueError('All elements should be LinearOperators.') + if not all(len(row) == len(operators[0]) for row in operators): + raise ValueError('All rows should have the same length.') + super().__init__() + self._operators = cast( # cast because ModuleList is not recognized as a list + list[list[LinearOperator]], torch.nn.ModuleList(torch.nn.ModuleList(row) for row in operators) + ) + self._shape = (len(operators), len(operators[0]) if operators else 0) + + @property + def shape(self) -> tuple[int, int]: + """Shape of the Operator Matrix (rows, columns).""" + return self._shape + + def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the operator to the input. + + Parameters + ---------- + x + Input tensors. Requires the same number of tensors as the operator has columns. + + Returns + ------- + Output tensors. The same number of tensors as the operator has rows. + """ + if len(x) != self.shape[1]: + raise ValueError('Input should be the same number of tensors as the LinearOperatorMatrix has columns.') + return tuple( + reduce(operator.add, (op(xi)[0] for op, xi in zip(row, x, strict=True))) for row in self._operators + ) + + def __getitem__(self, idx: _IdxType) -> Self | LinearOperator: + """Index the Operator Matrix. + + Parameters + ---------- + idx + Index or slice to select rows and columns. + + Returns + ------- + Subset LinearOperatorMatrix or Linear Operator. + """ + idxs: tuple[_SingleIdxType, _SingleIdxType] = idx if isinstance(idx, tuple) else (idx, slice(None)) + if len(idxs) > 2: + raise IndexError('Too many indices for LinearOperatorMatrix') + + def _to_numeric_index(idx: slice | int | Sequence[int] | EllipsisType, length: int) -> Sequence[int]: + """Convert index to a sequence of integers or raise an error.""" + if isinstance(idx, slice): + if (idx.start is not None and (idx.start < -length or idx.start >= length)) or ( + idx.stop is not None and (idx.stop < -length or idx.stop > length) + ): + raise IndexError('Index out of range') + return range(*idx.indices(length)) + if isinstance(idx, int): + if idx < -length or idx >= length: + raise IndexError('Index out of range') + return (idx,) + if idx is Ellipsis: + return range(length) + if isinstance(idx, Sequence): + if min(idx) < -length or max(idx) >= length: + raise IndexError('Index out of range') + return idx + else: + raise IndexError('Invalid index type') + + row_numbers = _to_numeric_index(idxs[0], self._shape[0]) + col_numbers = _to_numeric_index(idxs[1], self._shape[1]) + + sliced_operators = [ + [row[col_number] for col_number in col_numbers] + for row in [self._operators[row_number] for row_number in row_numbers] + ] + + # Return a single operator if only one row and column is selected + if len(row_numbers) == 1 and len(col_numbers) == 1: + return sliced_operators[0][0] + else: + return self.__class__(sliced_operators) + + def __iter__(self) -> Iterator[Sequence[LinearOperator]]: + """Iterate over the rows of the Operator Matrix.""" + return iter(self._operators) + + def __repr__(self): + """Representation of the Operator Matrix.""" + return f'LinearOperatorMatrix(shape={self._shape}, operators={self._operators})' + + # Note: The type ignores are needed because we currently cannot do arithmetic operations with non-linear operators. + def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: ignore[override] + """Addition.""" + operators: list[list[LinearOperator]] = [] + if isinstance(other, LinearOperatorMatrix): + if self.shape != other.shape: + raise ValueError('OperatorMatrix shapes do not match.') + for self_row, other_row in zip(self._operators, other._operators, strict=False): + operators.append([s + o for s, o in zip(self_row, other_row, strict=False)]) + elif isinstance(other, LinearOperator | torch.Tensor): + if not self.shape[0] == self.shape[1]: + raise NotImplementedError('Cannot add a LinearOperator to a non-square OperatorMatrix.') + for i, self_row in enumerate(self._operators): + operators.append([op + other if i == j else op for j, op in enumerate(self_row)]) + else: + return NotImplemented # type: ignore[unreachable] + return self.__class__(operators) + + def __radd__(self, other: Self | LinearOperator | torch.Tensor) -> Self: + """Right addition.""" + return self.__add__(other) + + def __mul__(self, other: torch.Tensor | Sequence[torch.Tensor | complex] | complex) -> Self: + """LinearOperatorMatrix*Tensor multiplication. + + Example: ([A,B]*c)(x) = [A*c, B*c](x) = A(c*x) + B(c*x) + """ + if isinstance(other, torch.Tensor | complex | float | int): + other_: Sequence[torch.Tensor | complex] = (other,) * self.shape[1] + elif len(other) != self.shape[1]: + raise ValueError('Other should have the same length as the operator has columns.') + else: + other_ = other + operators = [] + for row in self._operators: + operators.append([op * o for op, o in zip(row, other_, strict=True)]) + return self.__class__(operators) + + def __rmul__(self, other: torch.Tensor | Sequence[torch.Tensor] | complex) -> Self: + """Tensor*LinearOperatorMatrix multiplication. + + Example: (c*[A,B])(x) = [c*A, c*B](x) = c*A(x) + c*B(x) + """ + if isinstance(other, torch.Tensor | complex | float | int): + other_: Sequence[torch.Tensor | complex] = (other,) * self.shape[0] + elif len(other) != self.shape[0]: + raise ValueError('Other should have the same length as the operator has rows.') + else: + other_ = other + operators = [] + for row, o in zip(self._operators, other_, strict=True): + operators.append([cast(LinearOperator, o * op) for op in row]) + return self.__class__(operators) + + def __matmul__(self, other: LinearOperator | Self) -> Self: # type: ignore[override] + """Composition of operators.""" + if isinstance(other, LinearOperator): + return self._binary_operation(other, '__matmul__') + elif isinstance(other, LinearOperatorMatrix): + if self.shape[1] != other.shape[0]: + raise ValueError('OperatorMatrix shapes do not match.') + new_operators = [] + for row in self._operators: + new_row = [] + for other_col in zip(*other._operators, strict=True): + elements = [s @ o for s, o in zip(row, other_col, strict=True)] + new_row.append(LinearOperatorSum(*elements)) + new_operators.append(new_row) + return self.__class__(new_operators) + return NotImplemented # type: ignore[unreachable] + + @property + def H(self) -> Self: # noqa N802 + """Adjoints of the operators.""" + return self.__class__([[op.H for op in row] for row in zip(*self._operators, strict=True)]) + + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply the adjoint of the operator to the input. + + Parameters + ---------- + x + Input tensors. Requires the same number of tensors as the operator has rows. + + Returns + ------- + Output tensors. The same number of tensors as the operator has columns. + """ + return self.H(*x) + + @classmethod + def from_diagonal(cls, *operators: LinearOperator): + """Create a diagonal LinearOperatorMatrix. + + Create a square LinearOperatorMatrix with the given Linear Operators on the diagonal, + resulting in a block-diagonal linear operator. + + Parameters + ---------- + operators + Sequence of Linear Operators to be placed on the diagonal. + """ + operator_matrix: list[list[LinearOperator]] = [ + [op if i == j else ZeroOp(False) for j in range(len(operators))] for i, op in enumerate(operators) + ] + return cls(operator_matrix) + + def operator_norm( + self, + *initial_value: torch.Tensor, + dim: Sequence[int] | None = None, + max_iterations: int = 20, + relative_tolerance: float = 1e-4, + absolute_tolerance: float = 1e-5, + callback: Callable[[torch.Tensor], None] | None = None, + ) -> torch.Tensor: + """Upper bound of operator norm of the Matrix. + + Uses the bounds :math:`||[A, B]^T|||<=sqrt(||A||^2 + ||B||^2)` and :math:`||[A, B]|||<=max(||A||,||B||)` + to estimate the operator norm of the matrix. + First, operator_norm is called on each element of the matrix. + Next, the norm is estimated for each column using the first bound. + Finally, the norm of the full matrix of linear operators is calculated using the second bound. + + Parameters + ---------- + initial_value + Initial value(s) for the power iteration, length should match the number of columns + of the operator matrix. + dim + Dimensions to calculate the operator norm over. Other dimensions are assumed to be + batch dimensions. None means all dimensions. + max_iterations + Maximum number of iterations used in the power iteration. + relative_tolerance + Relative tolerance for convergence. + absolute_tolerance + Absolute tolerance for convergence. + callback + Callback function to be called with the current estimate of the operator norm. + + + Returns + ------- + Estimated operator norm upper bound. + """ + + def _singlenorm(op: LinearOperator, initial_value: torch.Tensor): + return op.operator_norm( + initial_value, + dim=dim, + max_iterations=max_iterations, + relative_tolerance=relative_tolerance, + absolute_tolerance=absolute_tolerance, + callback=callback, + ) + + if len(initial_value) != self.shape[1]: + raise ValueError('Initial value should have the same length as the operator has columns.') + norms = torch.tensor( + [[_singlenorm(op, iv) for op, iv in zip(row, initial_value, strict=True)] for row in self._operators] + ) + norm = norms.square().sum(-2).sqrt().amax(-1).unsqueeze(-1) + return norm + + def __or__(self, other: LinearOperator | LinearOperatorMatrix) -> Self: + """Vertical stacking.""" + if isinstance(other, LinearOperator): + if rows := self.shape[0] > 1: + raise ValueError( + f'Shape mismatch in vertical stacking : cannot stack LinearOperator and matrix with {rows} rows.' + ) + operators = [[*self._operators[0], other]] + return self.__class__(operators) + else: + if (rows_self := self.shape[0]) != (rows_other := other.shape[0]): + raise ValueError( + f'Shape mismatch in vertical stacking: cannot stack matrices with {rows_self} and {rows_other}.' + ) + operators = [[*self_row, *other_row] for self_row, other_row in zip(self, other, strict=True)] + return self.__class__(operators) + + def __ror__(self, other: LinearOperator) -> Self: + """Vertical stacking.""" + if rows := self.shape[0] > 1: + raise ValueError( + f'Shape mismatch in vertical stacking: cannot stack LinearOperator and matrix with {rows} rows.' + ) + operators = [[other, *self._operators[0]]] + return self.__class__(operators) + + def __and__(self, other: LinearOperator | LinearOperatorMatrix) -> Self: + """Horizontal stacking.""" + if isinstance(other, LinearOperator): + if cols := self.shape[1] > 1: + raise ValueError( + 'Shape mismatch in horizontal stacking:' + f'cannot stack LinearOperator and matrix with {cols} columns.' + ) + operators = [*self._operators, [other]] + return self.__class__(operators) + else: + if (cols_self := self.shape[1]) != (cols_other := other.shape[1]): + raise ValueError( + 'Shape mismatch in horizontal stacking:' + f'cannot stack matrices with {cols_self} and {cols_other} columns.' + ) + operators = [*self._operators, *other] + return self.__class__(operators) + + def __rand__(self, other: LinearOperator) -> Self: + """Horizontal stacking.""" + if cols := self.shape[1] > 1: + raise ValueError( + f'Shape mismatch in horizontal stacking: cannot stack LinearOperator and matrix with {cols} columns.' + ) + operators = [[other], *self._operators] + return self.__class__(operators) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index 691c88f16..c22f386cd 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -11,6 +11,7 @@ from mrpro.operators.FourierOp import FourierOp from mrpro.operators.GridSamplingOp import GridSamplingOp from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp from mrpro.operators.PCACompressionOp import PCACompressionOp @@ -37,6 +38,7 @@ "GridSamplingOp", "IdentityOp", "LinearOperator", + "LinearOperatorMatrix", "MagnitudeOp", "MultiIdentityOp", "Operator", diff --git a/tests/algorithms/test_cg.py b/tests/algorithms/test_cg.py index 16c16e78b..8a4434e2a 100644 --- a/tests/algorithms/test_cg.py +++ b/tests/algorithms/test_cg.py @@ -145,3 +145,13 @@ def callback(cg_status: CGStatus) -> None: assert True cg(h_operator, right_hand_side, callback=callback) + + +def test_autograd(system): + """Test autograd through cg""" + h_operator, right_hand_side, _ = system + right_hand_side.requires_grad_(True) + with torch.autograd.detect_anomaly(): + result = cg(h_operator, right_hand_side, tolerance=0, max_iterations=5) + result.abs().sum().backward() + assert right_hand_side.grad is not None diff --git a/tests/operators/test_linearoperatormatrix.py b/tests/operators/test_linearoperatormatrix.py new file mode 100644 index 000000000..7ba87d715 --- /dev/null +++ b/tests/operators/test_linearoperatormatrix.py @@ -0,0 +1,322 @@ +from typing import Any + +import pytest +import torch +from mrpro.operators import EinsumOp, LinearOperator, MagnitudeOp +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix + +from tests import RandomGenerator +from tests.helper import dotproduct_adjointness_test + + +def random_linearop(size, rng): + """Create a random LinearOperator.""" + return EinsumOp(rng.complex64_tensor(size), '... i j, ... j -> ... i') + + +def random_linearoperatormatrix(size, inner_size, rng): + """Create a random LinearOperatorMatrix.""" + operators = [[random_linearop(inner_size, rng) for i in range(size[1])] for j in range(size[0])] + return LinearOperatorMatrix(operators) + + +def test_linearoperatormatrix_shape(): + """Test creation and shape of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + assert matrix.shape == (5, 3) + + +def test_linearoperatormatrix_add_matrix(): + """Test addition of two LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + matrix2 = random_linearoperatormatrix((5, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 + matrix2)(*vector) + expected = tuple(a + b for a, b in zip(matrix1(*vector), matrix2(*vector), strict=False)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_add_tensor_nonsquare(): + """Test failure of addition of tensor to non-square matrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(3) + with pytest.raises(NotImplementedError, match='square'): + (matrix1 + other) + + +def test_linearoperatormatrix_add_tensor_square(): + """Add tensor to square matrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((3, 3), (2, 2), rng) + other = rng.complex64_tensor(2) + vector = rng.complex64_tensor((3, 2)) + result = (matrix1 + other)(*vector) + expected = tuple((mv + other * v for mv, v in zip(matrix1(*vector), vector, strict=True))) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_rmul(): + """Test post multiplication with tensor.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(3) + vector = rng.complex64_tensor((3, 10)) + result = (other * matrix1)(*vector) + expected = tuple(other * el for el in matrix1(*vector)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_mul(): + """Test pre multiplication with tensor.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((5, 3), (3, 10), rng) + other = rng.complex64_tensor(10) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 * other)(*vector) + expected = matrix1(*(other * el for el in vector)) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_composition(): + """Test composition of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((1, 5), (2, 3), rng) + matrix2 = random_linearoperatormatrix((5, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = (matrix1 @ matrix2)(*vector) + expected = matrix1(*(matrix2(*vector))) + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_composition_mismatch(): + """Test composition with mismatching shapes.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((1, 5), (2, 3), rng) + matrix2 = random_linearoperatormatrix((4, 3), (3, 10), rng) + vector = rng.complex64_tensor((4, 10)) + with pytest.raises(ValueError, match='shapes do not match'): + (matrix1 @ matrix2)(*vector) + + +def test_linearoperatormatrix_adjoint(): + """Test adjointness of Adjoint.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + + class Wrapper(LinearOperator): + """Stack the output of the matrix operator.""" + + def forward(self, x): + return (torch.stack(matrix(*x), 0),) + + def adjoint(self, x): + return (torch.stack(matrix.adjoint(*x), 0),) + + dotproduct_adjointness_test(Wrapper(), rng.complex64_tensor((3, 10)), rng.complex64_tensor((5, 3))) + + +def test_linearoperatormatrix_repr(): + """Test repr of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((5, 3), (3, 10), rng) + assert 'LinearOperatorMatrix(shape=(5, 3)' in repr(matrix) + + +def test_linearoperatormatrix_getitem(): + """Test slicing of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((12, 6), (3, 10), rng) + + def check(actual, expected): + assert tuple(tuple(row) for row in actual) == tuple(tuple(row) for row in expected) + + sliced = matrix[1:3, 2] + assert sliced.shape == (2, 1) + check(sliced._operators, [row[2:3] for row in matrix._operators[1:3]]) + + sliced = matrix[0] + assert sliced.shape == (1, 6) + check(sliced._operators, matrix._operators[:1]) + + sliced = matrix[..., 0] + assert sliced.shape == (12, 1) + check(sliced._operators, [row[:1] for row in matrix._operators]) + + sliced = matrix[1:6:2, (3, 4)] + assert sliced.shape == (3, 2) + check(sliced._operators, [[matrix._operators[i][j] for j in (3, 4)] for i in range(1, 6, 2)]) + + sliced = matrix[-2:-4:-1, -1] + assert sliced.shape == (2, 1) + check(sliced._operators, [row[-1:] for row in matrix._operators[-2:-4:-1]]) + + sliced = matrix[5, 5] + assert isinstance(sliced, LinearOperator) + assert sliced == matrix._operators[5][5] + + +def test_linearoperatormatrix_getitem_error(): + """Test error when slicing with wrong indices.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((12, 6), (3, 10), rng) + + with pytest.raises(IndexError, match='Too many indices'): + matrix[1, 1, 1] + + with pytest.raises(IndexError, match='out of range'): + matrix[20] + with pytest.raises(IndexError, match='out of range'): + matrix[-20] + with pytest.raises(IndexError, match='out of range'): + matrix[1:100] + with pytest.raises(IndexError, match='out of range'): + matrix[(100, 1)] + with pytest.raises(IndexError, match='out of range'): + matrix[..., 20] + with pytest.raises(IndexError, match='out of range'): + matrix[..., -20] + with pytest.raises(IndexError, match='out of range'): + matrix[..., 1:100] + with pytest.raises(IndexError, match='index type'): + matrix[..., 1.0] + + +def test_linearoperatormatrix_norm_rows(): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((3, 1), (3, 10), rng) + vector = rng.complex64_tensor((1, 10)) + result = matrix.operator_norm(*vector) + expected = sum(row[0].operator_norm(vector[0], dim=None) ** 2 for row in matrix._operators) ** 0.5 + torch.testing.assert_close(result, expected) + + +def test_linearoperatormatrix_norm_cols(): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((1, 3), (3, 10), rng) + vector = rng.complex64_tensor((3, 10)) + result = matrix.operator_norm(*vector) + expected = max(op.operator_norm(v, dim=None) for op, v in zip(matrix._operators[0], vector, strict=False)) + torch.testing.assert_close(result, expected) + + +@pytest.mark.parametrize('seed', [0, 1, 2, 3]) +def test_linearoperatormatrix_norm(seed): + """Test norm of LinearOperatorMatrix.""" + rng = RandomGenerator(seed) + matrix = random_linearoperatormatrix((4, 2), (3, 10), rng) + vector = rng.complex64_tensor((2, 10)) + result = matrix.operator_norm(*vector) + + class Wrapper(LinearOperator): + """Stack the output of the matrix operator.""" + + def forward(self, x): + return (torch.stack(matrix(*x), 0),) + + def adjoint(self, x): + return (torch.stack(matrix.adjoint(*x), 0),) + + real = Wrapper().operator_norm(vector, dim=None) + + assert result >= real + + +def test_linearoperatormatrix_shorthand_vertical(): + """Test shorthand for vertical stacking.""" + rng = RandomGenerator(0) + op1 = random_linearop((3, 10), rng) + op2 = random_linearop((4, 10), rng) + x1 = rng.complex64_tensor((10,)) + + matrix1 = op1 & op2 + assert matrix1.shape == (2, 1) + + actual = matrix1(x1) + expected = (*op1(x1), *op2(x1)) + torch.testing.assert_close(actual, expected) + + matrix2 = op2 & (matrix1 & op1) + assert matrix2.shape == (4, 1) + + matrix3 = matrix2 & matrix2 + assert matrix3.shape == (8, 1) + + actual = matrix3(x1) + expected = 2 * (*op2(x1), *matrix1(x1), *op1(x1)) + torch.testing.assert_close(actual, expected) + + +def test_linearoperatormatrix_shorthand_horizontal(): + """Test shorthand for horizontal stacking.""" + rng = RandomGenerator(0) + op1 = random_linearop((3, 4), rng) + op2 = random_linearop((3, 2), rng) + x1 = rng.complex64_tensor((4,)) + x2 = rng.complex64_tensor((2,)) + x3 = rng.complex64_tensor((4,)) + x4 = rng.complex64_tensor((2,)) + + matrix1 = op1 | op2 + assert matrix1.shape == (1, 2) + + actual1 = matrix1(x1, x2) + expected1 = (op1(x1)[0] + op2(x2)[0],) + torch.testing.assert_close(actual1, expected1) + + matrix2 = op2 | (matrix1 | op1) + assert matrix2.shape == (1, 4) + + matrix3 = matrix2 | matrix2 + assert matrix3.shape == (1, 8) + + expected3 = (2 * (op2(x2)[0] + (matrix1(x3, x4)[0] + op1(x1)[0])),) + actual3 = matrix3(x2, x3, x4, x1, x2, x3, x4, x1) + torch.testing.assert_close(actual3, expected3) + + +def test_linearoperatormatrix_stacking_error(): + """Test error when stacking matrix operators with different shapes.""" + rng = RandomGenerator(0) + matrix1 = random_linearoperatormatrix((3, 4), (3, 10), rng) + matrix2 = random_linearoperatormatrix((3, 2), (3, 10), rng) + matrix3 = random_linearoperatormatrix((2, 4), (3, 10), rng) + op = random_linearop((3, 10), rng) + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 & matrix2 + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 | matrix3 + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 | op + with pytest.raises(ValueError, match='Shape mismatch'): + matrix1 & op + + +def test_linearoperatormatrix_error_nonlinearop(): + """Test error if trying to create a LinearOperatorMatrix with non linear operator.""" + op: Any = [[MagnitudeOp()]] # Any is used to hide this error from mypy + with pytest.raises(ValueError, match='LinearOperator'): + LinearOperatorMatrix(op) + + +def test_linearoperatormatrix_error_inconsistent_shapes(): + """Test error if trying to create a LinearOperatorMatrix with inonsistent row lengths.""" + rng = RandomGenerator(0) + op = random_linearop((3, 4), rng) + with pytest.raises(ValueError, match='same length'): + LinearOperatorMatrix([[op, op], [op]]) + + +def test_linearoperatormatrix_from_diagonal(): + """Test creation of LinearOperatorMatrix from diagonal.""" + rng = RandomGenerator(0) + ops = [random_linearop((2, 4), rng) for _ in range(3)] + matrix = LinearOperatorMatrix.from_diagonal(*ops) + xs = rng.complex64_tensor((3, 4)) + actual = matrix(*xs) + expected = tuple(op(x)[0] for op, x in zip(ops, xs, strict=False)) + torch.testing.assert_close(actual, expected) diff --git a/tests/operators/test_operator_norm.py b/tests/operators/test_operator_norm.py index d5ae39873..c6e13dcf9 100644 --- a/tests/operators/test_operator_norm.py +++ b/tests/operators/test_operator_norm.py @@ -12,9 +12,8 @@ def test_power_iteration_uses_stopping_criterion(): """Test if the power iteration stops if the absolute and relative tolerance are chosen high.""" - # callback function that should not be called because the power iteration - # should stop if the tolerances are set high - def callback(): + def callback(_): + """Callback function that should not be called, because the power iteration should stop.""" pytest.fail('The power iteration did not stop despite high atol and rtol!') random_generator = RandomGenerator(seed=0)