From 84537989a17873c3bb62d5aba0218213b3ad6ccb Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 13 Jan 2025 00:19:52 +0100 Subject: [PATCH] update --- .../comparison_trajectory_calculators.ipynb | 32 +-- .../notebooks/direct_reconstruction.ipynb | 6 +- ...e_reconstruction_with_regularization.ipynb | 176 ++++++-------- .../notebooks/qmri_sg_challenge_2024_t1.ipynb | 221 +++++++++++------- .../comparison_trajectory_calculators.py | 10 +- examples/scripts/direct_reconstruction.py | 6 +- ...ense_reconstruction_with_regularization.py | 61 +++-- examples/scripts/qmri_sg_challenge_2024_t1.py | 160 ++++++++----- .../scripts/qmri_sg_challenge_2024_t2_star.py | 146 ------------ 9 files changed, 357 insertions(+), 461 deletions(-) delete mode 100644 examples/scripts/qmri_sg_challenge_2024_t2_star.py diff --git a/examples/notebooks/comparison_trajectory_calculators.ipynb b/examples/notebooks/comparison_trajectory_calculators.ipynb index 94d965d3..604d2fe7 100644 --- a/examples/notebooks/comparison_trajectory_calculators.ipynb +++ b/examples/notebooks/comparison_trajectory_calculators.ipynb @@ -35,7 +35,7 @@ }, "source": [ "# Different ways to obtain the Trajectory\n", - "This example builds upon the example and demonstrates three ways\n", + "This example builds upon the example and demonstrates three ways\n", "to obtain the trajectory information required for image reconstruction:\n", "- using the trajectory that is stored in the ISMRMRD file\n", "- calculating the trajectory using the radial 2D trajectory calculator\n", @@ -107,7 +107,7 @@ "id": "6", "metadata": {}, "source": [ - "### Using KTrajectoryRadial2D - Trajectory\n", + "### Using KTrajectoryRadial2D - Specific trajectory calculator\n", "For some common trajectories, we provide specific trajectory calculators.\n", "These calculators often require only a few parameters to be specified,\n", "such as the angle between spokes in the radial trajectory. Other parameters\n", @@ -143,7 +143,7 @@ "id": "8", "metadata": {}, "source": [ - "### Using KTrajectoryPulseq\n", + "### Using KTrajectoryPulseq - Trajectory from pulseq sequence file\n", "This will calculate the trajectory from the pulseq sequence file\n", "using the PyPulseq trajectory calculator. This method\n", "requires the pulseq sequence file that was used to acquire the data.\n", @@ -223,20 +223,26 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "13", "metadata": {}, - "outputs": [], "source": [ - "# Tada! We have successfully reconstructed images using three different trajectory calculators.\n", - "# ```{note}\n", - "# Which of these three methods is the best depends on the specific use case:\n", - "# If a trajectory is already stored in the ISMRMRD file, it is the most convenient to use.\n", - "# If a pulseq sequence file is available, the trajectory can be calculated using the PyPulseq trajectory calculator.\n", - "# Otherwise, a trajectory calculator needs to be implemented for the specific trajectory used.\n", - "# ```" + "Tada! 
We have successfully reconstructed images using three different trajectory calculators.\n", + "```{note}\n", + "Which of these three methods is the best depends on the specific use case:\n", + "If a trajectory is already stored in the ISMRMRD file, it is the most convenient to use.\n", + "If a pulseq sequence file is available, the trajectory can be calculated using the PyPulseq trajectory calculator.\n", + "Otherwise, a trajectory calculator needs to be implemented for the specific trajectory used.\n", + "```" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/examples/notebooks/direct_reconstruction.ipynb b/examples/notebooks/direct_reconstruction.ipynb index 34bec560..ca8c00e7 100644 --- a/examples/notebooks/direct_reconstruction.ipynb +++ b/examples/notebooks/direct_reconstruction.ipynb @@ -53,7 +53,6 @@ "import tempfile\n", "from pathlib import Path\n", "\n", - "import mrpro.algorithms.dcf\n", "import zenodo_get\n", "\n", "dataset = '14617082'\n", @@ -115,12 +114,13 @@ "source": [ "## Setup the DirectReconstruction instance\n", "We create a `~mrpro.algorithms.reconstruction.DirectReconstruction` and supply ``kdata``.\n", - "`~mrpro.algorithms.reconstruction.DirectReconstruction` uses the information in `kdata` to\n", + "`~mrpro.algorithms.reconstruction.DirectReconstruction` uses the information in ``kdata`` to\n", " setup a Fourier transfrm, density compensation factors, and estimate coil sensitivity maps.\n", "(See the *Behind the scenes* section for more details.)\n", "\n", "```{note}\n", - "You can also directly set the Fourier operator, coil sensitivity maps, dcf, etc. of the reconstruction instance.\n", + "You can also directly set the Fourier operator, coil sensitivity maps, density compensation factors, etc.\n", + "of the reconstruction instance.\n", "```" ] }, diff --git a/examples/notebooks/iterative_sense_reconstruction_with_regularization.ipynb b/examples/notebooks/iterative_sense_reconstruction_with_regularization.ipynb index 4401105a..e97f61a2 100644 --- a/examples/notebooks/iterative_sense_reconstruction_with_regularization.ipynb +++ b/examples/notebooks/iterative_sense_reconstruction_with_regularization.ipynb @@ -67,7 +67,9 @@ { "cell_type": "markdown", "id": "4", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "source": [ "### Image reconstruction\n", "We use the `~mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction` class to reconstruct images\n", @@ -107,81 +109,29 @@ "regularization." 
]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "5",
-   "metadata": {
-    "lines_to_next_cell": 0
-   },
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "markdown",
-   "id": "6",
-   "metadata": {
-    "lines_to_next_cell": 0
-   },
-   "source": [
-    "##### Read-in the raw data"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "7",
-   "metadata": {
-    "lines_to_next_cell": 0
-   },
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "8",
-   "metadata": {
-    "tags": [
-     "hide-cell"
-    ]
-   },
-   "outputs": [],
-   "source": [
-    "# Download raw data from Zenodo\n",
-    "import tempfile\n",
-    "from pathlib import Path\n",
-    "\n",
-    "import mrpro\n",
-    "import torch\n",
-    "import zenodo_get\n",
-    "\n",
-    "dataset = '14617082'\n",
-    "\n",
-    "tmp = tempfile.TemporaryDirectory()  # RAII, automatically cleaned up\n",
-    "data_folder = Path(tmp.name)\n",
-    "zenodo_get.zenodo_get([dataset, '-r', 5, '-o', data_folder])  # r: retries"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "9",
+   "id": "5",
    "metadata": {},
    "source": [
     "### Reading of both fully sampled and undersampled data\n",
-    "This will use the trajectory that is stored in the ISMRMRD file."
+    "We read the raw data and the trajectory from the ISMRMRD file.\n",
+    "We load both the fully sampled and the undersampled data.\n",
+    "The fully sampled data will be used to estimate the coil sensitivity maps and as a regularization image.\n",
+    "The undersampled data will be used to reconstruct the image."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "10",
+   "id": "6",
    "metadata": {
     "lines_to_next_cell": 0
    },
    "outputs": [],
    "source": [
     "# Read the raw data and the trajectory from ISMRMRD file\n",
+    "import mrpro\n",
     "\n",
     "kdata_fullysampled = mrpro.data.KData.from_file(\n",
     "    data_folder / 'radial2D_402spokes_golden_angle_with_traj.h5',\n",
@@ -195,7 +145,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "11",
+   "id": "7",
    "metadata": {},
    "source": [
     "##### Image $x_{reg}$ from fully sampled data\n",
@@ -207,7 +157,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "12",
+   "id": "8",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -227,22 +177,22 @@
   },
   {
    "cell_type": "markdown",
-   "id": "13",
-   "metadata": {
-    "lines_to_next_cell": 2
-   },
+   "id": "9",
+   "metadata": {},
    "source": [
-    "##### Image $x$ from undersampled data"
+    "##### Image $x$ from undersampled data\n",
+    "We now reconstruct the undersampled image, first without regularization\n",
+    "and then with the fully sampled image as regularization."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "14",
+   "id": "10",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# unregularized iterative SENSE reconstruction of the undersampled data\n",
+    "# Unregularized iterative SENSE reconstruction of the undersampled data\n",
     "iterative_sense_reconstruction = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(\n",
     "    kdata_undersampled, csm=csm, n_iterations=6\n",
     ")\n",
@@ -252,7 +202,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "15",
+   "id": "11",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -270,18 +220,20 @@
   },
   {
    "cell_type": "markdown",
-   "id": "16",
+   "id": "12",
    "metadata": {
     "lines_to_next_cell": 0
    },
    "source": [
-    "##### Display the results"
+    "##### Display the results\n",
+    "Besides the fully sampled image, we display two undersampled images:\n",
+    "the first one is obtained by unregularized iterative SENSE, the second one using regularization."
] }, { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "13", "metadata": { "tags": [ "hide-cell" @@ -290,6 +242,7 @@ "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", + "import torch\n", "\n", "\n", "def show_images(*images: torch.Tensor, titles: list[str] | None = None) -> None:\n", @@ -307,7 +260,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -321,29 +274,32 @@ }, { "cell_type": "markdown", - "id": "19", + "id": "15", "metadata": {}, "source": [ - "### Behind the scenes" + "### Behind the scenes\n", + "We now investigate the steps that are done in the regularized iterative SENSE reconstruction and\n", + "perform them manually. This also demonstrates how to use the `mrpro` operators and algorithms\n", + "to build your own reconstruction pipeline." ] }, { "cell_type": "markdown", - "id": "20", + "id": "16", "metadata": { "lines_to_next_cell": 0 }, "source": [ "##### Set-up the density compensation operator $W$ and acquisition model $A$\n", "\n", - "This is very similar to .\n", + "This is very similar to .\n", "For more details, please refer to that notebook." ] }, { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -355,16 +311,18 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "18", "metadata": {}, "source": [ - "##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$" + "##### Calculate the right-hand-side of the linear system\n", + "We calculated $b = A^H W y + l x_{reg}$.\n", + "Here, we make use of operator composition using ``@``." ] }, { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -377,20 +335,22 @@ }, { "cell_type": "markdown", - "id": "24", + "id": "20", "metadata": {}, "source": [ - "##### Set-up the linear self-adjoint operator $H = A^H W A + l$" + "##### Set-up the linear self-adjoint operator $H$\n", + "We define $H= A^H W A + l$. We use the `~mrpro.operators.IdentityOp` and make\n", + "use of operator composition using ``@``, addition using ``+`` and multiplication using ``*``.\n", + "The resulting operator is a `~mrpro.operators.LinearOperator` object." ] }, { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "21", "metadata": {}, "outputs": [], "source": [ - "\n", "operator = (\n", " acquisition_operator.H @ dcf_operator @ acquisition_operator + mrpro.operators.IdentityOp() * regularization_weight\n", ")" @@ -398,16 +358,21 @@ }, { "cell_type": "markdown", - "id": "26", - "metadata": {}, + "id": "22", + "metadata": { + "lines_to_next_cell": 0 + }, "source": [ - "##### Run conjugate gradient" + "##### Run conjugate gradient\n", + "We solve the linear system $Hx = b$ using the conjugate gradient method.\n", + "Here, we use early stopping after 8 iterations. Instead, we could also use a tolerance to stop the iterations when\n", + "the residual is small enough." ] }, { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "23", "metadata": { "lines_to_next_cell": 0 }, @@ -420,16 +385,18 @@ }, { "cell_type": "markdown", - "id": "28", + "id": "24", "metadata": {}, "source": [ - "##### Display the reconstructed image" + "##### Display the reconstructed image\n", + "We can now compare our 'manual' reconstruction with the regularized iterative SENSE reconstruction\n", + "obtained using `~mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction`." 
] }, { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "25", "metadata": { "lines_to_next_cell": 0 }, @@ -444,27 +411,28 @@ }, { "cell_type": "markdown", - "id": "30", - "metadata": {}, + "id": "26", + "metadata": { + "lines_to_next_cell": 0 + }, "source": [ - "### Check for equal results\n", - "The two versions should result in the same image data." + "We can also check if the results are equal by comparing the actual image data.\n", + "If the assert statement does not raise an exception, the results are equal." ] }, { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "27", "metadata": {}, "outputs": [], "source": [ - "# If the assert statement did not raise an exception, the results are equal.\n", "assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual)" ] }, { "cell_type": "markdown", - "id": "32", + "id": "28", "metadata": {}, "source": [ "### Next steps\n", @@ -474,14 +442,6 @@ "streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality\n", "compared to the unregularised images." ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/examples/notebooks/qmri_sg_challenge_2024_t1.ipynb b/examples/notebooks/qmri_sg_challenge_2024_t1.ipynb index 040f1a16..8872919d 100644 --- a/examples/notebooks/qmri_sg_challenge_2024_t1.ipynb +++ b/examples/notebooks/qmri_sg_challenge_2024_t1.ipynb @@ -30,9 +30,19 @@ { "cell_type": "markdown", "id": "2", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 0 + }, "source": [ - "# QMRI Challenge ISMRM 2024 - $T_1$ mapping" + "# QMRI Challenge ISMRM 2024 - $T_1$ mapping\n", + "In the 2024 ISMRM QMRI Challenge, the goal is to estimate $T_1$ maps from a set of inversion recovery images.\n", + "The dataset consists of images obtained at 10 different inversion times using a turbo spin echo sequence. Each\n", + "inversion time is saved in a separate DICOM file. 
In order to obtain a $T_1$ map, we are going to:\n",
+    "- download the data from Zenodo\n",
+    "- read in the DICOM files (one for each inversion time) and combine them in an IData object\n",
+    "- define a signal model and data loss (mean-squared error) function\n",
+    "- find good starting values for each pixel\n",
+    "- carry out a fit using ADAM from PyTorch"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "3",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Imports\n",
-    "import shutil\n",
-    "import tempfile\n",
-    "import zipfile\n",
-    "from pathlib import Path\n",
-    "\n",
-    "import matplotlib.pyplot as plt\n",
-    "import torch\n",
-    "import zenodo_get\n",
-    "from einops import rearrange\n",
-    "from mpl_toolkits.axes_grid1 import make_axes_locatable  # type: ignore [import-untyped]\n",
-    "from mrpro.algorithms.optimizers import adam\n",
-    "from mrpro.data import IData\n",
-    "from mrpro.operators import MagnitudeOp\n",
-    "from mrpro.operators.functionals import MSE\n",
-    "from mrpro.operators.models import InversionRecovery"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "4",
    "metadata": {},
-   "source": [
-    "### Overview\n",
-    "The dataset consists of images obtained at 10 different inversion times using a turbo spin echo sequence. Each\n",
-    "inversion time is saved in a separate DICOM file. In order to obtain a $T_1$ map, we are going to:\n",
-    "- download the data from Zenodo\n",
-    "- read in the DICOM files (one for each inversion time) and combine them in an IData object\n",
-    "- define a signal model and data loss (mean-squared error) function\n",
-    "- find good starting values for each pixel\n",
-    "- carry out a fit using ADAM from PyTorch"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "5",
-   "metadata": {},
    "source": [
     "### Get data from Zenodo"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "6",
-   "metadata": {},
+   "id": "5",
+   "metadata": {
+    "tags": [
+     "hide-output"
+    ]
+   },
    "outputs": [],
    "source": [
-    "data_folder = Path(tempfile.mkdtemp())\n",
+    "import tempfile\n",
+    "import zipfile\n",
+    "from pathlib import Path\n",
+    "\n",
+    "import zenodo_get\n",
+    "\n",
     "dataset = '10868350'\n",
+    "\n",
+    "tmp = tempfile.TemporaryDirectory()  # RAII, automatically cleaned up\n",
+    "data_folder = Path(tmp.name)\n",
     "zenodo_get.zenodo_get([dataset, '-r', 5, '-o', data_folder])  # r: retries\n",
     "with zipfile.ZipFile(data_folder / Path('T1 IR.zip'), 'r') as zip_ref:\n",
     "    zip_ref.extractall(data_folder)"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "7",
+   "id": "6",
    "metadata": {
     "lines_to_next_cell": 0
    },
    "source": [
-    "### Create image data (IData) object with different inversion times"
+    "### Create image data (IData) object with different inversion times\n",
+    "We read in the DICOM files and combine them in an `mrpro.data.IData` object.\n",
+    "The inversion times are stored in the DICOM files and are available in the header of the `~mrpro.data.IData` object."
] }, { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "7", "metadata": {}, "outputs": [], "source": [ + "import mrpro\n", + "\n", "ti_dicom_files = data_folder.glob('**/*.dcm')\n", - "idata_multi_ti = IData.from_dicom_files(ti_dicom_files)\n", + "idata_multi_ti = mrpro.data.IData.from_dicom_files(ti_dicom_files)\n", "\n", "if idata_multi_ti.header.ti is None:\n", " raise ValueError('Inversion times need to be defined in the DICOM files.')" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": { + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import torch\n", + "\n", + "\n", + "def show_images(*images: torch.Tensor, titles: list[str] | None = None) -> None:\n", + " \"\"\"Plot images.\"\"\"\n", + " n_images = len(images)\n", + " _, axes = plt.subplots(1, n_images, squeeze=False, figsize=(n_images * 3, 3))\n", + " for i in range(n_images):\n", + " axes[0][i].imshow(images[i], cmap='gray')\n", + " axes[0][i].axis('off')\n", + " if titles:\n", + " axes[0][i].set_title(titles[i])\n", + " plt.show()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -129,10 +167,10 @@ "outputs": [], "source": [ "# Let's have a look at some of the images\n", - "fig, axes = plt.subplots(1, 3, squeeze=False)\n", - "for idx, ax in enumerate(axes.flatten()):\n", - " ax.imshow(torch.abs(idata_multi_ti.data[idx, 0, 0, :, :]))\n", - " ax.set_title(f'TI = {idata_multi_ti.header.ti[idx]:.3f}s')" + "show_images(\n", + " *idata_multi_ti.data[:, 0, 0].abs(),\n", + " titles=[f'TI = {ti:.3f}s' for ti in idata_multi_ti.header.ti.squeeze()],\n", + ")" ] }, { @@ -156,7 +194,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = MagnitudeOp() @ InversionRecovery(ti=idata_multi_ti.header.ti)" + "model = mrpro.operators.MagnitudeOp() @ mrpro.operators.models.InversionRecovery(ti=idata_multi_ti.header.ti)" ] }, { @@ -177,7 +215,7 @@ "metadata": {}, "outputs": [], "source": [ - "mse = MSE(idata_multi_ti.data.abs())" + "mse = mrpro.operators.functionals.MSE(idata_multi_ti.data.abs())" ] }, { @@ -237,14 +275,18 @@ "# just a scaling factor and we are going to normalize the signal curves.\n", "(signal_dictionary,) = model(torch.ones(1), t1_dictionary)\n", "signal_dictionary = signal_dictionary.to(dtype=torch.complex64)\n", - "vector_norm = torch.linalg.vector_norm(signal_dictionary, dim=0)\n", - "signal_dictionary /= vector_norm\n", + "signal_dictionary /= torch.linalg.vector_norm(signal_dictionary, dim=0)\n", "\n", "# Calculate the dot-product and select for each voxel the T1 values that correspond to the maximum of the dot-product\n", - "n_y, n_x = idata_multi_ti.data.shape[-2:]\n", - "dot_product = torch.mm(rearrange(idata_multi_ti.data, 'other 1 z y x->(z y x) other'), signal_dictionary)\n", - "idx_best_match = torch.argmax(torch.abs(dot_product), dim=1)\n", - "t1_start = rearrange(t1_dictionary[idx_best_match], '(y x)->1 1 y x', y=n_y, x=n_x)" + "import einops\n", + "\n", + "dot_product = einops.einsum(\n", + " idata_multi_ti.data,\n", + " signal_dictionary,\n", + " 'ti ..., ti t1 -> t1 ...',\n", + ")\n", + "idx_best_match = dot_product.abs().argmax(dim=0)\n", + "t1_start = t1_dictionary[idx_best_match]" ] }, { @@ -255,25 +297,30 @@ "outputs": [], "source": [ "# The maximum absolute value observed is a good approximation for m0\n", - "m0_start = torch.amax(torch.abs(idata_multi_ti.data), 0)" + "m0_start = idata_multi_ti.data.abs().amax(dim=0)" ] }, { "cell_type": "code", "execution_count": null, "id": "19", - 
"metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [], "source": [ "# Visualize the starting values\n", - "fig, axes = plt.subplots(1, 2, figsize=(8, 2), squeeze=False)\n", - "colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n", - "im = axes[0, 0].imshow(m0_start[0, 0, ...])\n", + "fig, axes = plt.subplots(1, 2, figsize=(6, 2), squeeze=False)\n", + "\n", + "im = axes[0, 0].imshow(m0_start[0, 0])\n", "axes[0, 0].set_title('$M_0$ start values')\n", - "fig.colorbar(im, cax=colorbar_ax[0])\n", - "im = axes[0, 1].imshow(t1_start[0, 0, ...], vmin=0, vmax=2.5)\n", + "axes[0, 0].set_axis_off()\n", + "fig.colorbar(im, ax=axes[0, 0], label='a.u.')\n", + "\n", + "im = axes[0, 1].imshow(t1_start[0, 0], vmin=0, vmax=2.5, cmap='magma')\n", "axes[0, 1].set_title('$T_1$ start values')\n", - "fig.colorbar(im, cax=colorbar_ax[1], label='s')" + "axes[0, 1].set_axis_off()\n", + "fig.colorbar(im, ax=axes[0, 1], label='s')" ] }, { @@ -296,10 +343,8 @@ "lr = 1e-1\n", "\n", "# Run optimization\n", - "params_result = adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr)\n", - "m0, t1 = (p.detach() for p in params_result)\n", - "m0[torch.isnan(t1)] = 0\n", - "t1[torch.isnan(t1)] = 0" + "params_result = mrpro.algorithms.optimizers.adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr)\n", + "m0, t1 = (p.detach() for p in params_result)" ] }, { @@ -312,44 +357,40 @@ "\n", "$E_{relative} = \\sum_{TI}\\frac{|(q(M_0, T_1, TI) - x)|}{|x|}$\n", "\n", - "on a voxel-by-voxel basis" + "on a voxel-by-voxel basis\n", + "We also mask out the background by thresholding on $M_0$." ] }, { "cell_type": "code", "execution_count": null, "id": "23", - "metadata": { - "lines_to_next_cell": 2 - }, + "metadata": {}, "outputs": [], "source": [ - "img_mult_te_abs_sum = torch.sum(torch.abs(idata_multi_ti.data), dim=0)\n", - "relative_absolute_error = torch.sum(torch.abs(model(m0, t1)[0] - idata_multi_ti.data), dim=0) / (\n", - " img_mult_te_abs_sum + 1e-9\n", - ")\n", + "error = model(m0, t1)[0] - idata_multi_ti.data\n", + "relative_absolute_error = error.abs().sum(dim=0) / (idata_multi_ti.data.abs().sum(dim=0) + 1e-9)\n", "fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False)\n", - "colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n", - "im = axes[0, 0].imshow(m0[0, 0, ...])\n", + "\n", + "mask = torch.isnan(t1) | (m0 < 500)\n", + "m0[mask] = 0\n", + "t1[mask] = 0\n", + "relative_absolute_error[mask] = 0\n", + "\n", + "im = axes[0, 0].imshow(m0[0, 0])\n", "axes[0, 0].set_title('$M_0$')\n", - "fig.colorbar(im, cax=colorbar_ax[0])\n", - "im = axes[0, 1].imshow(t1[0, 0, ...], vmin=0, vmax=2.5)\n", + "axes[0, 0].set_axis_off()\n", + "fig.colorbar(im, ax=axes[0, 0], label='a.u.')\n", + "\n", + "im = axes[0, 1].imshow(t1[0, 0], vmin=0, vmax=2.5)\n", "axes[0, 1].set_title('$T_1$')\n", - "fig.colorbar(im, cax=colorbar_ax[1], label='s')\n", - "im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...], vmin=0, vmax=1.0)\n", + "axes[0, 1].set_axis_off()\n", + "fig.colorbar(im, ax=axes[0, 1], label='s')\n", + "\n", + "im = axes[0, 2].imshow(relative_absolute_error[0, 0], vmin=0, vmax=1.0)\n", "axes[0, 2].set_title('Relative error')\n", - "fig.colorbar(im, cax=colorbar_ax[2])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24", - "metadata": {}, - "outputs": [], - "source": [ - "# Clean-up by removing temporary directory\n", - "shutil.rmtree(data_folder)" + "axes[0, 
2].set_axis_off()\n",
+    "fig.colorbar(im, ax=axes[0, 2])"
   ]
  }
 ],
@@ -360,7 +401,7 @@
   "provenance": []
  },
  "jupytext": {
-   "cell_metadata_filter": "-all"
+   "cell_metadata_filter": "tags,-all"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
diff --git a/examples/scripts/comparison_trajectory_calculators.py b/examples/scripts/comparison_trajectory_calculators.py
index 1f4abc2c..8dfd6df0 100644
--- a/examples/scripts/comparison_trajectory_calculators.py
+++ b/examples/scripts/comparison_trajectory_calculators.py
@@ -1,6 +1,6 @@
 # %% [markdown]
 # # Different ways to obtain the Trajectory
-# This example builds upon the example and demonstrates three ways
+# This example builds upon the example and demonstrates three ways
 # to obtain the trajectory information required for image reconstruction:
 # - using the trajectory that is stored in the ISMRMRD file
 # - calculating the trajectory using the radial 2D trajectory calculator
@@ -45,7 +45,7 @@
 img_using_ismrmrd_traj = reconstruction(kdata)
 
 # %% [markdown]
-# ### Using KTrajectoryRadial2D - Trajectory
+# ### Using KTrajectoryRadial2D - Specific trajectory calculator
 # For some common trajectories, we provide specific trajectory calculators.
 # These calculators often require only a few parameters to be specified,
 # such as the angle between spokes in the radial trajectory. Other parameters
@@ -69,7 +69,7 @@
 img_using_rad2d_traj = reconstruction(kdata)
 
 # %% [markdown]
-# ### Using KTrajectoryPulseq
+# ### Using KTrajectoryPulseq - Trajectory from pulseq sequence file
 # This will calculate the trajectory from the pulseq sequence file
 # using the PyPulseq trajectory calculator. This method
 # requires the pulseq sequence file that was used to acquire the data.
@@ -115,7 +115,7 @@ def show_images(*images: torch.Tensor, titles: list[str] | None = None) -> None:
     titles=['KTrajectoryIsmrmrd', 'KTrajectoryRadial2D', 'KTrajectoryPulseq'],
 )
 
-# %%
+# %% [markdown]
 # Tada! We have successfully reconstructed images using three different trajectory calculators.
 # ```{note}
 # Which of these three methods is the best depends on the specific use case:
 # If a trajectory is already stored in the ISMRMRD file, it is the most convenient to use.
 # If a pulseq sequence file is available, the trajectory can be calculated using the PyPulseq trajectory calculator.
 # Otherwise, a trajectory calculator needs to be implemented for the specific trajectory used.
 # ```
+
+# %%
diff --git a/examples/scripts/direct_reconstruction.py b/examples/scripts/direct_reconstruction.py
index 5e45e5ee..a7624111 100644
--- a/examples/scripts/direct_reconstruction.py
+++ b/examples/scripts/direct_reconstruction.py
@@ -6,7 +6,6 @@
 import tempfile
 from pathlib import Path
 
-import mrpro.algorithms.dcf
 import zenodo_get
 
 dataset = '14617082'
@@ -39,12 +38,13 @@
 # %% [markdown]
-### Setup the DirectReconstruction instance
+# ## Setup the DirectReconstruction instance
 # We create a `~mrpro.algorithms.reconstruction.DirectReconstruction` and supply ``kdata``.
-# `~mrpro.algorithms.reconstruction.DirectReconstruction` uses the information in `kdata` to
-# setup a Fourier transfrm, density compensation factors, and estimate coil sensitivity maps.
+# `~mrpro.algorithms.reconstruction.DirectReconstruction` uses the information in ``kdata`` to
+# set up a Fourier transform, density compensation factors, and estimate coil sensitivity maps.
 # (See the *Behind the scenes* section for more details.)
 #
 # ```{note}
-# You can also directly set the Fourier operator, coil sensitivity maps, dcf, etc. of the reconstruction instance.
+# You can also directly set the Fourier operator, coil sensitivity maps, density compensation factors, etc.
+# of the reconstruction instance.
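+# For example, assuming ``DirectReconstruction`` accepts a ``csm`` argument like the SENSE reconstructions in the
+# other examples do, a sketch with a hypothetical precomputed map ``my_csm`` would be
+# ``mrpro.algorithms.reconstruction.DirectReconstruction(kdata, csm=my_csm)``.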
 # ```
 
 # %%
diff --git a/examples/scripts/iterative_sense_reconstruction_with_regularization.py b/examples/scripts/iterative_sense_reconstruction_with_regularization.py
index 09cb90d5..feec70aa 100644
--- a/examples/scripts/iterative_sense_reconstruction_with_regularization.py
+++ b/examples/scripts/iterative_sense_reconstruction_with_regularization.py
@@ -52,31 +52,17 @@
 # only parts of it ($x$)). This of course is an unrealistic case but it will allow us to study the effect of the
 # regularization.
 
-# %%
-# %% [markdown]
-# ##### Read-in the raw data
-# %%
-# %% tags=["hide-cell"]
-# Download raw data from Zenodo
-import tempfile
-from pathlib import Path
-
-import mrpro
-import torch
-import zenodo_get
-
-dataset = '14617082'
-
-tmp = tempfile.TemporaryDirectory()  # RAII, automatically cleaned up
-data_folder = Path(tmp.name)
-zenodo_get.zenodo_get([dataset, '-r', 5, '-o', data_folder])  # r: retries
 
 # %% [markdown]
 # ### Reading of both fully sampled and undersampled data
-# This will use the trajectory that is stored in the ISMRMRD file.
+# We read the raw data and the trajectory from the ISMRMRD file.
+# We load both the fully sampled and the undersampled data.
+# The fully sampled data will be used to estimate the coil sensitivity maps and as a regularization image.
+# The undersampled data will be used to reconstruct the image.
 
 # %%
 # Read the raw data and the trajectory from ISMRMRD file
+import mrpro
 
 kdata_fullysampled = mrpro.data.KData.from_file(
     data_folder / 'radial2D_402spokes_golden_angle_with_traj.h5',
@@ -108,10 +94,11 @@
 
 # %% [markdown]
 # ##### Image $x$ from undersampled data
-
+# We now reconstruct the undersampled image, first without regularization
+# and then with the fully sampled image as regularization.
 
 # %%
-# unregularized iterative SENSE reconstruction of the undersampled data
+# Unregularized iterative SENSE reconstruction of the undersampled data
 iterative_sense_reconstruction = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(
     kdata_undersampled, csm=csm, n_iterations=6
 )
@@ -131,8 +118,11 @@
 
 # %% [markdown]
 # ##### Display the results
+# Besides the fully sampled image, we display two undersampled images:
+# the first one is obtained by unregularized iterative SENSE, the second one using regularization.
 # %% tags=["hide-cell"]
 import matplotlib.pyplot as plt
+import torch
 
 
 def show_images(*images: torch.Tensor, titles: list[str] | None = None) -> None:
@@ -157,11 +147,14 @@
 
 # %% [markdown]
 # ### Behind the scenes
+# We now investigate the steps that are done in the regularized iterative SENSE reconstruction and
+# perform them manually. This also demonstrates how to use the `mrpro` operators and algorithms
+# to build your own reconstruction pipeline.
 
 # %% [markdown]
 # ##### Set-up the density compensation operator $W$ and acquisition model $A$
 #
-# This is very similar to .
+# This is very similar to .
 # For more details, please refer to that notebook.
 # %%
 dcf_operator = mrpro.data.DcfData.from_traj_voronoi(kdata_undersampled.traj).as_operator()
 csm_operator = csm.as_operator()
 fourier_operator = mrpro.operators.FourierOp.from_kdata(kdata_undersampled)
 acquisition_operator = fourier_operator @ csm_operator
 
 # %% [markdown]
-# ##### Calculate the right-hand-side of the linear system $b = A^H W y + l x_{reg}$
+# ##### Calculate the right-hand-side of the linear system
+# We calculate $b = A^H W y + l x_{reg}$,
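+# where $A$ is the acquisition operator, $A^H$ its adjoint, $W$ the density compensation, $y$ the acquired
+# k-space data, $l$ the regularization weight, and $x_{reg}$ the regularization image.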
+# Here, we make use of operator composition using ``@``. # %% regularization_weight = 1.0 @@ -180,23 +175,29 @@ def show_images(*images: torch.Tensor, titles: list[str] | None = None) -> None: right_hand_side = right_hand_side + regularization_weight * regularization_image # %% [markdown] -# ##### Set-up the linear self-adjoint operator $H = A^H W A + l$ +# ##### Set-up the linear self-adjoint operator $H$ +# We define $H= A^H W A + l$. We use the `~mrpro.operators.IdentityOp` and make +# use of operator composition using ``@``, addition using ``+`` and multiplication using ``*``. +# The resulting operator is a `~mrpro.operators.LinearOperator` object. # %% - operator = ( acquisition_operator.H @ dcf_operator @ acquisition_operator + mrpro.operators.IdentityOp() * regularization_weight ) # %% [markdown] # ##### Run conjugate gradient - +# We solve the linear system $Hx = b$ using the conjugate gradient method. +# Here, we use early stopping after 8 iterations. Instead, we could also use a tolerance to stop the iterations when +# the residual is small enough. # %% img_manual = mrpro.algorithms.optimizers.cg( operator, right_hand_side, initial_value=right_hand_side, max_iterations=8, tolerance=0.0 ) # %% [markdown] # ##### Display the reconstructed image +# We can now compare our 'manual' reconstruction with the regularized iterative SENSE reconstruction +# obtained using `~mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction`. # %% show_images( @@ -205,11 +206,9 @@ def show_images(*images: torch.Tensor, titles: list[str] | None = None) -> None: titles=['Regularized Iterative SENSE R=20', '"Manual" Regularized Iterative SENSE R=20'], ) # %% [markdown] -# ### Check for equal results -# The two versions should result in the same image data. - +# We can also check if the results are equal by comparing the actual image data. +# If the assert statement does not raise an exception, the results are equal. # %% -# If the assert statement did not raise an exception, the results are equal. assert torch.allclose(img_us_regularized_iterative_sense.data, img_manual) # %% [markdown] @@ -219,5 +218,3 @@ def show_images(*images: torch.Tensor, titles: list[str] | None = None) -> None: # we would not have that. One option is to apply a low-pass filter to the undersampled k-space data to try to reduce the # streaking artifacts and use that as a regularization image. Try that and see if you can also improve the image quality # compared to the unregularised images. - -# %% diff --git a/examples/scripts/qmri_sg_challenge_2024_t1.py b/examples/scripts/qmri_sg_challenge_2024_t1.py index d0259f26..77296332 100644 --- a/examples/scripts/qmri_sg_challenge_2024_t1.py +++ b/examples/scripts/qmri_sg_challenge_2024_t1.py @@ -1,26 +1,6 @@ # %% [markdown] # # QMRI Challenge ISMRM 2024 - $T_1$ mapping - -# %% -# Imports -import shutil -import tempfile -import zipfile -from pathlib import Path - -import matplotlib.pyplot as plt -import torch -import zenodo_get -from einops import rearrange -from mpl_toolkits.axes_grid1 import make_axes_locatable # type: ignore [import-untyped] -from mrpro.algorithms.optimizers import adam -from mrpro.data import IData -from mrpro.operators import MagnitudeOp -from mrpro.operators.functionals import MSE -from mrpro.operators.models import InversionRecovery - -# %% [markdown] -# ### Overview +# In the 2024 ISMRM QMRI Challenge, the goal is to estimate $T_1$ maps from a set of inversion recovery images. 
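+# In an inversion recovery experiment, the magnetization is inverted and then sampled after different inversion
+# times $TI$, so that the recovery towards equilibrium encodes $T_1$.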
# The dataset consists of images obtained at 10 different inversion times using a turbo spin echo sequence. Each
 # inversion time is saved in a separate DICOM file. In order to obtain a $T_1$ map, we are going to:
 # - download the data from Zenodo
 # - read in the DICOM files (one for each inversion time) and combine them in an IData object
 # - define a signal model and data loss (mean-squared error) function
 # - find good starting values for each pixel
 # - carry out a fit using ADAM from PyTorch
 
 # %% [markdown]
 # ### Get data from Zenodo
 
-# %%
-data_folder = Path(tempfile.mkdtemp())
+# %% tags=["hide-output"]
+import tempfile
+import zipfile
+from pathlib import Path
+
+import zenodo_get
+
 dataset = '10868350'
+
+tmp = tempfile.TemporaryDirectory()  # RAII, automatically cleaned up
+data_folder = Path(tmp.name)
 zenodo_get.zenodo_get([dataset, '-r', 5, '-o', data_folder])  # r: retries
 with zipfile.ZipFile(data_folder / Path('T1 IR.zip'), 'r') as zip_ref:
     zip_ref.extractall(data_folder)
 
 # %% [markdown]
 # ### Create image data (IData) object with different inversion times
+# We read in the DICOM files and combine them in an `mrpro.data.IData` object.
+# The inversion times are stored in the DICOM files and are available in the header of the `~mrpro.data.IData` object.
 
 # %%
+import mrpro
+
 ti_dicom_files = data_folder.glob('**/*.dcm')
-idata_multi_ti = IData.from_dicom_files(ti_dicom_files)
+idata_multi_ti = mrpro.data.IData.from_dicom_files(ti_dicom_files)
 
 if idata_multi_ti.header.ti is None:
     raise ValueError('Inversion times need to be defined in the DICOM files.')
 
+# %% tags=["hide-cell"]
+import matplotlib.pyplot as plt
+import torch
+
+
+def show_images(*images: torch.Tensor, titles: list[str] | None = None) -> None:
+    """Plot images."""
+    n_images = len(images)
+    _, axes = plt.subplots(1, n_images, squeeze=False, figsize=(n_images * 3, 3))
+    for i in range(n_images):
+        axes[0][i].imshow(images[i], cmap='gray')
+        axes[0][i].axis('off')
+        if titles:
+            axes[0][i].set_title(titles[i])
+    plt.show()
+
+
 # %%
 # Let's have a look at some of the images
-fig, axes = plt.subplots(1, 3, squeeze=False)
-for idx, ax in enumerate(axes.flatten()):
-    ax.imshow(torch.abs(idata_multi_ti.data[idx, 0, 0, :, :]))
-    ax.set_title(f'TI = {idata_multi_ti.header.ti[idx]:.3f}s')
+show_images(
+    *idata_multi_ti.data[:, 0, 0].abs(),
+    titles=[f'TI = {ti:.3f}s' for ti in idata_multi_ti.header.ti.squeeze()],
+)
 
 # %% [markdown]
 # ### Signal model and loss function
@@ -65,13 +91,13 @@
 # images only contain the magnitude of the signal. Therefore, we need $|q(TI)|$:
 
 # %%
-model = MagnitudeOp() @ InversionRecovery(ti=idata_multi_ti.header.ti)
+model = mrpro.operators.MagnitudeOp() @ mrpro.operators.models.InversionRecovery(ti=idata_multi_ti.header.ti)
 
 # %% [markdown]
 # As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal
 # model $q$.
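+# Averaged over all $N$ samples (voxels and inversion times), this reads
+#
+# $\mathrm{MSE}(q, x) = \frac{1}{N}\sum_{i=1}^{N} |q_i - x_i|^2$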
# %%
-mse = MSE(idata_multi_ti.data.abs())
+mse = mrpro.operators.functionals.MSE(idata_multi_ti.data.abs())
 
 # %% [markdown]
 # Now we can simply combine the two into a functional to solve
@@ -104,29 +130,37 @@
 # just a scaling factor and we are going to normalize the signal curves.
 (signal_dictionary,) = model(torch.ones(1), t1_dictionary)
 signal_dictionary = signal_dictionary.to(dtype=torch.complex64)
-vector_norm = torch.linalg.vector_norm(signal_dictionary, dim=0)
-signal_dictionary /= vector_norm
+signal_dictionary /= torch.linalg.vector_norm(signal_dictionary, dim=0)
 
 # Calculate the dot-product and select for each voxel the T1 values that correspond to the maximum of the dot-product
-n_y, n_x = idata_multi_ti.data.shape[-2:]
-dot_product = torch.mm(rearrange(idata_multi_ti.data, 'other 1 z y x->(z y x) other'), signal_dictionary)
-idx_best_match = torch.argmax(torch.abs(dot_product), dim=1)
-t1_start = rearrange(t1_dictionary[idx_best_match], '(y x)->1 1 y x', y=n_y, x=n_x)
+import einops
+
+dot_product = einops.einsum(
+    idata_multi_ti.data,
+    signal_dictionary,
+    'ti ..., ti t1 -> t1 ...',
+)
+idx_best_match = dot_product.abs().argmax(dim=0)
+t1_start = t1_dictionary[idx_best_match]
 
 # %%
 # The maximum absolute value observed is a good approximation for m0
-m0_start = torch.amax(torch.abs(idata_multi_ti.data), 0)
+m0_start = idata_multi_ti.data.abs().amax(dim=0)
 
 # %%
 # Visualize the starting values
-fig, axes = plt.subplots(1, 2, figsize=(8, 2), squeeze=False)
-colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]
-im = axes[0, 0].imshow(m0_start[0, 0, ...])
+fig, axes = plt.subplots(1, 2, figsize=(6, 2), squeeze=False)
+
+im = axes[0, 0].imshow(m0_start[0, 0])
 axes[0, 0].set_title('$M_0$ start values')
-fig.colorbar(im, cax=colorbar_ax[0])
-im = axes[0, 1].imshow(t1_start[0, 0, ...], vmin=0, vmax=2.5)
+axes[0, 0].set_axis_off()
+fig.colorbar(im, ax=axes[0, 0], label='a.u.')
+
+im = axes[0, 1].imshow(t1_start[0, 0], vmin=0, vmax=2.5, cmap='magma')
 axes[0, 1].set_title('$T_1$ start values')
-fig.colorbar(im, cax=colorbar_ax[1], label='s')
+axes[0, 1].set_axis_off()
+fig.colorbar(im, ax=axes[0, 1], label='s')
+
 
 # %% [markdown]
 # ### Carry out fit
@@ -137,10 +171,8 @@
 lr = 1e-1
 
 # Run optimization
-params_result = adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr)
+params_result = mrpro.algorithms.optimizers.adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr)
 m0, t1 = (p.detach() for p in params_result)
-m0[torch.isnan(t1)] = 0
-t1[torch.isnan(t1)] = 0
 
 # %% [markdown]
 # ### Visualize the final results
@@ -149,25 +181,29 @@
 # $E_{relative} = \sum_{TI}\frac{|(q(M_0, T_1, TI) - x)|}{|x|}$
 #
-# on a voxel-by-voxel basis
+# on a voxel-by-voxel basis.
+# We also mask out the background by thresholding on $M_0$.
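+# The threshold ($M_0 < 500$ in the code below) is a heuristic for this dataset: background voxels have a much
+# smaller equilibrium magnetization than tissue.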
# %% -img_mult_te_abs_sum = torch.sum(torch.abs(idata_multi_ti.data), dim=0) -relative_absolute_error = torch.sum(torch.abs(model(m0, t1)[0] - idata_multi_ti.data), dim=0) / ( - img_mult_te_abs_sum + 1e-9 -) +error = model(m0, t1)[0] - idata_multi_ti.data +relative_absolute_error = error.abs().sum(dim=0) / (idata_multi_ti.data.abs().sum(dim=0) + 1e-9) fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False) -colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]] -im = axes[0, 0].imshow(m0[0, 0, ...]) + +mask = torch.isnan(t1) | (m0 < 500) +m0[mask] = 0 +t1[mask] = 0 +relative_absolute_error[mask] = 0 + +im = axes[0, 0].imshow(m0[0, 0]) axes[0, 0].set_title('$M_0$') -fig.colorbar(im, cax=colorbar_ax[0]) -im = axes[0, 1].imshow(t1[0, 0, ...], vmin=0, vmax=2.5) -axes[0, 1].set_title('$T_1$') -fig.colorbar(im, cax=colorbar_ax[1], label='s') -im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...], vmin=0, vmax=1.0) -axes[0, 2].set_title('Relative error') -fig.colorbar(im, cax=colorbar_ax[2]) +axes[0, 0].set_axis_off() +fig.colorbar(im, ax=axes[0, 0], label='a.u.') +im = axes[0, 1].imshow(t1[0, 0], vmin=0, vmax=2.5) +axes[0, 1].set_title('$T_1$') +axes[0, 1].set_axis_off() +fig.colorbar(im, ax=axes[0, 1], label='s') -# %% -# Clean-up by removing temporary directory -shutil.rmtree(data_folder) +im = axes[0, 2].imshow(relative_absolute_error[0, 0], vmin=0, vmax=1.0) +axes[0, 2].set_title('Relative error') +axes[0, 2].set_axis_off() +fig.colorbar(im, ax=axes[0, 2]) diff --git a/examples/scripts/qmri_sg_challenge_2024_t2_star.py b/examples/scripts/qmri_sg_challenge_2024_t2_star.py deleted file mode 100644 index a80f4075..00000000 --- a/examples/scripts/qmri_sg_challenge_2024_t2_star.py +++ /dev/null @@ -1,146 +0,0 @@ -# %% [markdown] -# # QMRI Challenge ISMRM 2024 - $T_2^*$ mapping - -# %% -# Imports -import shutil -import tempfile -import time -import zipfile -from pathlib import Path - -import matplotlib.pyplot as plt -import torch -import zenodo_get -from mpl_toolkits.axes_grid1 import make_axes_locatable # type: ignore [import-untyped] -from mrpro.algorithms.optimizers import adam -from mrpro.data import IData -from mrpro.operators.functionals import MSE -from mrpro.operators.models import MonoExponentialDecay - -# %% [markdown] -# ### Overview -# The dataset consists of gradient echo images obtained at 11 different echo times, each saved in a separate DICOM file. -# In order to obtain a $T_2^*$ map, we are going to: -# - download the data from Zenodo -# - read in the DICOM files (one for each echo time) and combine them in an IData object -# - define a signal model (mono-exponential decay) and data loss (mean-squared error) function -# - carry out a fit using ADAM from PyTorch -# -# Everything is based on PyTorch, and therefore we can run the code either on the CPU or GPU. Simply set the flag below -# to True to run the parameter estimation on the GPU. 
- -# %% -flag_use_cuda = False - -# %% [markdown] -# ### Get data from Zenodo - -# %% -data_folder = Path(tempfile.mkdtemp()) -dataset = '10868361' -zenodo_get.zenodo_get([dataset, '-r', 5, '-o', data_folder]) # r: retries -with zipfile.ZipFile(data_folder / Path('T2star.zip'), 'r') as zip_ref: - zip_ref.extractall(data_folder) - -# %% [markdown] -# ### Create image data (IData) object with different echo times -# %% -te_dicom_files = data_folder.glob('**/*.dcm') -idata_multi_te = IData.from_dicom_files(te_dicom_files) -# scaling the signal down to make the optimization easier -idata_multi_te.data[...] = idata_multi_te.data / 1500 - -# Move the data to the GPU -if flag_use_cuda: - idata_multi_te = idata_multi_te.cuda() - -if idata_multi_te.header.te is None: - raise ValueError('Echo times need to be defined in the DICOM files.') - -# %% -# Let's have a look at some of the images -fig, axes = plt.subplots(1, 3, squeeze=False) -for idx, ax in enumerate(axes.flatten()): - ax.imshow(torch.abs(idata_multi_te.data[idx, 0, 0, :, :]).cpu()) - ax.set_title(f'TE = {idata_multi_te.header.te[idx]:.3f}s') - -# %% [markdown] -# ### Signal model and loss function -# We use the model $q$ -# -# $q(TE) = M_0 e^{-TE/T_2^*}$ -# -# with the equilibrium magnetization $M_0$, the echo time $TE$, and $T_2^*$ - -# %% -model = MonoExponentialDecay(decay_time=idata_multi_te.header.te) - -# %% [markdown] -# As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal -# model $q$. -# %% -mse = MSE(idata_multi_te.data) - -# %% [markdown] -# Now we can simply combine the two into a functional which will then solve -# -# $ \min_{M_0, T_2^*} ||q(M_0, T_2^*, TE) - x||_2^2$ -# %% -functional = mse @ model - -# %% [markdown] -# ### Carry out fit - -# %% -# The shortest echo time is a good approximation of the equilibrium magnetization -m0_start = torch.abs(idata_multi_te.data[torch.argmin(idata_multi_te.header.te), ...]) -# 20 ms as a starting value for T2* -t2star_start = torch.ones(m0_start.shape, dtype=torch.float32, device=m0_start.device) * 20e-3 - -# Hyperparameters for optimizer -max_iter = 20000 -lr = 1e-3 - -if flag_use_cuda: - functional.cuda() - -# Run optimization -start_time = time.time() -params_result = adam(functional, [m0_start, t2star_start], max_iter=max_iter, lr=lr) -print(f'Optimization took {time.time() - start_time}s') -m0, t2star = (p.detach() for p in params_result) -m0[torch.isnan(t2star)] = 0 -t2star[torch.isnan(t2star)] = 0 - -# %% [markdown] -# ### Visualize the final results -# To get an impression of how well the fit has worked, we are going to calculate the relative error between -# -# $E_{relative} = \sum_{TE}\frac{|(q(M_0, T_2^*, TE) - x)|}{|x|}$ -# -# on a voxel-by-voxel basis. 
-# %% -img_mult_te_abs_sum = torch.sum(torch.abs(idata_multi_te.data), dim=0) -relative_absolute_error = torch.sum(torch.abs(model(m0, t2star)[0] - idata_multi_te.data), dim=0) / ( - img_mult_te_abs_sum + 1e-9 -) -fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False) -colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]] - -im = axes[0, 0].imshow(m0[0, 0, ...].cpu()) -axes[0, 0].set_title('$M_0$') -fig.colorbar(im, cax=colorbar_ax[0]) - -im = axes[0, 1].imshow(t2star[0, 0, ...].cpu(), vmin=0, vmax=5) -axes[0, 1].set_title('$T_2^*$') -fig.colorbar(im, cax=colorbar_ax[1], label='s') - -im = axes[0, 2].imshow(relative_absolute_error[0, 0, ...].cpu(), vmin=0, vmax=0.1) -axes[0, 2].set_title('Relative error') -fig.colorbar(im, cax=colorbar_ax[2]) - - -# %% -# Clean-up by removing temporary directory -shutil.rmtree(data_folder)