diff --git a/docs/release_notes/next/dev-1985-tgv-backend b/docs/release_notes/next/dev-1985-tgv-backend new file mode 100644 index 00000000000..2edb8213492 --- /dev/null +++ b/docs/release_notes/next/dev-1985-tgv-backend @@ -0,0 +1 @@ +#1985 : Backend work for TGV reconstruction diff --git a/mantidimaging/core/reconstruct/cil_recon.py b/mantidimaging/core/reconstruct/cil_recon.py index 1ebf8c9886e..f8622f2e637 100644 --- a/mantidimaging/core/reconstruct/cil_recon.py +++ b/mantidimaging/core/reconstruct/cil_recon.py @@ -10,9 +10,12 @@ import numpy as np -from cil.framework import AcquisitionData, AcquisitionGeometry, DataOrder, ImageGeometry, BlockGeometry +from cil.framework import (AcquisitionData, AcquisitionGeometry, DataOrder, ImageGeometry, BlockGeometry, + BlockDataContainer) from cil.optimisation.algorithms import PDHG, SPDHG from cil.optimisation.operators import GradientOperator, BlockOperator +from cil.optimisation.operators import SymmetrisedGradientOperator, ZeroOperator, IdentityOperator + from cil.optimisation.functions import MixedL21Norm, L2NormSquared, BlockFunction, ZeroFunction, IndicatorBox, Function from cil.plugins.astra.operators import ProjectionOperator @@ -82,6 +85,67 @@ def set_up_TV_regularisation( return (K, F, G) + @staticmethod + def set_up_TGV_regularisation( + image_geometry: ImageGeometry, acquisition_data: AcquisitionData, + recon_params: ReconstructionParameters) -> tuple[BlockOperator, BlockFunction, Function]: + + # Forward operator + A2d = ProjectionOperator(image_geometry, acquisition_data.geometry, 'gpu') + + if recon_params.stochastic: + for partition_geometry, partition_operator in zip(acquisition_data.geometry, A2d, strict=True): + CILRecon.set_approx_norm(partition_operator, partition_geometry, image_geometry) + else: + CILRecon.set_approx_norm(A2d, acquisition_data.geometry, image_geometry) + + # Define Gradient Operator and BlockOperator + alpha = recon_params.alpha + gamma = recon_params.gamma + beta = alpha * gamma + + f2 = MixedL21Norm() + f3 = MixedL21Norm() + + if recon_params.stochastic: + raise ValueError("TGV reconstruction does not yet support stochastic mode") + # now, A2d is a BlockOperator as acquisition_data is a BlockDataContainer + fs = [] + for i, _ in enumerate(acquisition_data.geometry): + fs.append(L2NormSquared(b=acquisition_data.get_item(i))) + + F = BlockFunction(*fs, f2, f3) + + else: + # Define BlockFunction F using the MixedL21Norm() and the L2NormSquared() + # mathematicians like to multiply 1/2 in front of L2NormSquared. This is not necessary + # it will mean that the regularisation parameter alpha is doubled + f1 = L2NormSquared(b=acquisition_data) + + F = BlockFunction(f1, f2, f3) + + # Define BlockOperator K + + # Set up the 3 operator A, Grad and Symmetrised Gradient + K11 = A2d + K21 = alpha * GradientOperator(K11.domain) + K32 = beta * SymmetrisedGradientOperator(K21.range) + # these define the domain and range of the other operators + K12 = ZeroOperator(K32.domain, K11.range) + K22 = -alpha * IdentityOperator(domain_geometry=K21.range, range_geometry=K32.range) + K31 = ZeroOperator(K11.domain, K32.range) + + K = BlockOperator(K11, K12, K21, K22, K31, K32, shape=(3, 2)) + + if recon_params.non_negative: + G = BlockFunction(IndicatorBox(lower=0, upper=None), ZeroFunction()) + + else: + # Define Function G simply as zero + G = ZeroFunction() + + return (K, F, G) + @staticmethod def set_approx_norm(A2d: BlockOperator, acquisition_data: AcquisitionGeometry, image_geometry: ImageGeometry) -> None: @@ -169,7 +233,12 @@ def single_sino(sino: np.ndarray, ig = ag.get_ImageGeometry() - K, F, G = CILRecon.set_up_TV_regularisation(ig, data, recon_params) + if recon_params.regulariser == 'TV': + K, F, G = CILRecon.set_up_TV_regularisation(ig, data, recon_params) + elif recon_params.regulariser == 'TGV': + K, F, G = CILRecon.set_up_TGV_regularisation(ig, data, recon_params) + else: + raise ValueError(f"Regulariser must be one of 'TV', 'TGV'. Received '{recon_params.regulariser}'") max_iteration = 100000 # this should set to a sensible number as evaluating the objective is costly @@ -210,6 +279,11 @@ def single_sino(sino: np.ndarray, progress.mark_complete() t1 = time.perf_counter() LOG.info(f"single_sino time: {t1-t0}s for shape {sino.shape}") + + if isinstance(algo.solution, BlockDataContainer): + # TGV case + return algo.solution[0].as_array() + return algo.solution.as_array() @staticmethod @@ -288,7 +362,13 @@ def full(images: ImageStack, num_subsets) ig = ag.get_ImageGeometry() - K, F, G = CILRecon.set_up_TV_regularisation(ig, data, recon_params) + + if recon_params.regulariser == 'TV': + K, F, G = CILRecon.set_up_TV_regularisation(ig, data, recon_params) + elif recon_params.regulariser == 'TGV': + K, F, G = CILRecon.set_up_TGV_regularisation(ig, data, recon_params) + else: + raise ValueError(f"Regulariser must be one of 'TV', 'TGV'. Received '{recon_params.regulariser}'") max_iteration = 100000 # this should set to a sensible number as evaluating the objective is costly @@ -325,7 +405,11 @@ def full(images: ImageStack, force_continue=False) algo.next() - volume = algo.solution.as_array() + if isinstance(algo.solution, BlockDataContainer): + # TGV case + volume = algo.solution[0].as_array() + else: + volume = algo.solution.as_array() LOG.info(f'Reconstructed 3D volume with shape: {volume.shape}') t1 = time.perf_counter() LOG.info(f"full reconstruction time: {t1-t0}s for shape {images.data.shape}") diff --git a/mantidimaging/core/utility/data_containers.py b/mantidimaging/core/utility/data_containers.py index 95dd73ce633..64ef1566bea 100644 --- a/mantidimaging/core/utility/data_containers.py +++ b/mantidimaging/core/utility/data_containers.py @@ -103,12 +103,14 @@ class ReconstructionParameters: tilt: Degrees | None = None pixel_size: float = 0.0 alpha: float = 0.0 + gamma: float = 1.0 non_negative: bool = False max_projection_angle: float = 360.0 beam_hardening_coefs: list[float] | None = None stochastic: bool = False projections_per_subset: int = 50 regularisation_percent: int = 30 + regulariser: str = "" def to_dict(self) -> dict: return { @@ -119,9 +121,11 @@ def to_dict(self) -> dict: 'tilt': str(self.tilt), 'pixel_size': self.pixel_size, 'alpha': self.alpha, + 'gamma': self.gamma, 'stochastic': self.stochastic, 'projections_per_subset': self.projections_per_subset, 'regularisation_percent': self.regularisation_percent, + 'regulariser': self.regulariser, } diff --git a/mantidimaging/gui/windows/recon/view.py b/mantidimaging/gui/windows/recon/view.py index 7ac9195c8d9..5bacb18ce7b 100644 --- a/mantidimaging/gui/windows/recon/view.py +++ b/mantidimaging/gui/windows/recon/view.py @@ -411,6 +411,10 @@ def pixel_size(self, value: int) -> None: def alpha(self) -> float: return self.alphaSpinBox.value() + @property + def gamma(self) -> float: + return 1 + @property def non_negative(self) -> bool: return self.nonNegativeCheckBox.isChecked() @@ -427,6 +431,10 @@ def projections_per_subset(self) -> int: def regularisation_percent(self) -> int: return self.regPercentSpinBox.value() + @property + def regulariser(self) -> str: + return "TV" + @property def beam_hardening_coefs(self) -> list[float] | None: if not self.lbhc_enabled.isChecked(): @@ -447,11 +455,13 @@ def recon_params(self) -> ReconstructionParameters: tilt=Degrees(self.tilt), pixel_size=self.pixel_size, alpha=self.alpha, + gamma=self.gamma, non_negative=self.non_negative, stochastic=self.stochastic, projections_per_subset=self.projections_per_subset, max_projection_angle=self.max_proj_angle, regularisation_percent=self.regularisation_percent, + regulariser=self.regulariser, beam_hardening_coefs=self.beam_hardening_coefs) def set_table_point(self, idx, slice_idx, cor):