Notebooks updated
github-actions[bot] committed Jan 6, 2025
1 parent 295cb6f commit 0f3f410
Showing 1 changed file with 83 additions and 74 deletions.
157 changes: 83 additions & 74 deletions examples/example_DictionaryMatchOp.ipynb
@@ -16,21 +16,21 @@
"outputs": [],
"source": [
"# Imports\n",
"import shutil\n",
"import tempfile\n",
"import zipfile\n",
"from collections.abc import Callable\n",
"from pathlib import Path\n",
"\n",
"import einops\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import zenodo_get\n",
"from einops import rearrange\n",
"from mpl_toolkits.axes_grid1 import make_axes_locatable # type: ignore [import-untyped]\n",
"from mrpro.algorithms.optimizers import adam\n",
"from mrpro.data import IData\n",
"from mrpro.operators import MagnitudeOp, Operator\n",
"from mrpro.operators import MagnitudeOp, DictionaryMatchOp\n",
"from mrpro.operators.functionals import MSEDataDiscrepancy\n",
"from mrpro.operators.models import InversionRecovery\n",
"from typing_extensions import Self, TypeVarTuple, Unpack"
"from mrpro.operators.models import InversionRecovery"
]
},
{
@@ -206,108 +206,117 @@
"outputs": [],
"source": [
"# Define 100 T1 values between 100 and 3000 ms\n",
"t1_dictionary = torch.linspace(0.1, 3, 100).double()\n",
"\n",
"\n",
"# Calculate the signal corresponding to each of these T1 values. We set M0 to 1, but this is arbitrary because M0 is\n",
"# just a scaling factor and we are going to normalize the signal curves.\n",
"(signal_dictionary,) = model(torch.ones(1), t1_dictionary)\n",
"signal_dictionary = signal_dictionary + 0j\n",
"# signal_dictionary = signal_dictionary.to(dtype=torch.complex128)\n",
"vector_norm = torch.linalg.vector_norm(signal_dictionary, dim=0)\n",
"signal_dictionary /= vector_norm\n",
"\n",
"# Calculate the dot-product and select for each voxel the T1 values that correspond to the maximum of the dot-product\n",
"n_y, n_x = idata_multi_ti.data.shape[-2:]\n",
"data = idata_multi_ti.data.to(torch.complex128)\n",
"dot_product = torch.mm(rearrange(data, 'other 1 z y x->(z y x) other'), signal_dictionary)\n",
"# print(signal_dictionary)\n",
"idx_best_match = torch.argmax(torch.abs(dot_product), dim=1)\n",
"# print(torch.abs(dot_product))\n",
"t1_start = rearrange(t1_dictionary[idx_best_match], '(y x)->1 1 y x', y=n_y, x=n_x)\n",
"\n",
"\n",
"Tin = TypeVarTuple('Tin')\n",
"\n",
"\n",
"class DictionaryMatchOp(Operator[torch.Tensor, tuple[*Tin]]):\n",
" def __init__(self, generating_function: Callable[[Unpack[Tin]], tuple[torch.Tensor,]]):\n",
" super().__init__()\n",
" self._f = generating_function\n",
" self.x: list[torch.Tensor] = []\n",
" self.y = torch.tensor([])\n",
"\n",
" def append(self, *x: Unpack[Tin]) -> Self:\n",
" (newy,) = self._f(*x)\n",
" newy = newy / torch.linalg.norm(newy, dim=0, keepdim=True)\n",
" newy = newy.flatten(start_dim=1)\n",
" newx = [x.flatten() for x in torch.broadcast_tensors(*x)]\n",
" if not self.x:\n",
" self.x = newx\n",
" self.y = newy\n",
" return self\n",
" self.x = [torch.cat(old, new) for old, new in zip(self.x, newx, strict=True)]\n",
" self.y = torch.cat((self.y, newy))\n",
" return\n",
"\n",
" def forward(self, input_signal: torch.Tensor) -> tuple[Unpack[Tin]]:\n",
" similar = einops.einsum(input_signal, self.y, 't ..., t idx -> idx ...')\n",
" idx = torch.argmax(similar, dim=0)\n",
" match = [x[idx] for x in self.x]\n",
" return match\n",
"\n",
"\n",
"t1_dictionary = torch.linspace(100, 3000, 100)\n",
"# Dictionary Matching\n",
"dict_match_op = DictionaryMatchOp(model)\n",
"dictionary = dict_match_op.append(torch.ones(1), t1_dictionary)\n",
"t1_start_new = dict_match_op.forward(idata_multi_ti.rss().double())[1]\n",
"(t1_start == t1_start_new).all()"
"t1_start = dict_match_op.forward(idata_multi_ti.data)[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b84827e",
"metadata": {
"lines_to_next_cell": 2
},
"metadata": {},
"outputs": [],
"source": []
"source": [
"# The image with the longest inversion time is a good approximation of the equilibrium magnetization\n",
"m0_start = torch.abs(idata_multi_ti.data[torch.argmax(idata_multi_ti.header.ti), ...])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efbd2edc",
"metadata": {
"lines_to_next_cell": 0
},
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.matshow(t1_start.real.squeeze())"
"# Visualize the starting values\n",
"fig, axes = plt.subplots(1, 2, figsize=(8, 2), squeeze=False)\n",
"colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n",
"im = axes[0, 0].imshow(m0_start[0, 0, ...])\n",
"axes[0, 0].set_title('M0 start values')\n",
"fig.colorbar(im, cax=colorbar_ax[0])\n",
"im = axes[0, 1].imshow(t1_start[0, 0, ...], vmin=0, vmax=2500)\n",
"axes[0, 1].set_title('T1 start values')\n",
"fig.colorbar(im, cax=colorbar_ax[1])"
]
},
{
"cell_type": "markdown",
"id": "df7e62da",
"metadata": {},
"source": [
"### Carry out fit"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6e44d8b",
"metadata": {
"lines_to_next_cell": 0
},
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"# Hyperparameters for optimizer\n",
"max_iter = 2000\n",
"lr = 1e0\n",
"\n",
"plt.matshow(t1_start_new.real.squeeze())"
"# Run optimization\n",
"params_result = adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr)\n",
"m0, t1 = (p.detach() for p in params_result)\n",
"m0[torch.isnan(t1)] = 0\n",
"t1[torch.isnan(t1)] = 0"
]
},
{
"cell_type": "markdown",
"id": "b2a713f6",
"metadata": {},
"source": [
"### Visualize the final results\n",
"To get an impression of how well the fit has worked, we are going to calculate the relative error between\n",
"\n",
"$E_{relative} = \\sum_{TI}\\frac{|(q(M_0, T1, TI) - x)|}{|x|}$\n",
"\n",
"on a voxel-by-voxel basis"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de2af924",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"img_mult_te_abs_sum = torch.sum(torch.abs(idata_multi_ti.data), dim=0)\n",
"relative_absolute_error_old = torch.sum(torch.abs(model(m0, t1)[0] - idata_multi_ti.data), dim=0) / (\n",
" img_mult_te_abs_sum + 1e-9\n",
")\n",
"fig, axes = plt.subplots(1, 3, figsize=(10, 2), squeeze=False)\n",
"colorbar_ax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axes[0, :]]\n",
"im = axes[0, 0].imshow(m0[0, 0, ...])\n",
"axes[0, 0].set_title('M0')\n",
"fig.colorbar(im, cax=colorbar_ax[0])\n",
"im = axes[0, 1].imshow(t1[0, 0, ...], vmin=0, vmax=2500)\n",
"axes[0, 1].set_title('T1')\n",
"fig.colorbar(im, cax=colorbar_ax[1])\n",
"im = axes[0, 2].imshow(relative_absolute_error_old[0, 0, ...], vmin=0, vmax=1.0)\n",
"axes[0, 2].set_title('Relative error')\n",
"fig.colorbar(im, cax=colorbar_ax[2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fd3685e",
"metadata": {},
"outputs": [],
"source": []
"source": [
"# Clean-up by removing temporary directory\n",
"shutil.rmtree(data_folder)"
]
}
],
"metadata": {