Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicit tgv backend #1997

Merged
merged 11 commits into from
May 1, 2024
1 change: 1 addition & 0 deletions docs/release_notes/next/dev-1985-tgv-backend
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#1985 : Backend work for TGV reconstruction
92 changes: 88 additions & 4 deletions mantidimaging/core/reconstruct/cil_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@

import numpy as np

from cil.framework import AcquisitionData, AcquisitionGeometry, DataOrder, ImageGeometry, BlockGeometry
from cil.framework import (AcquisitionData, AcquisitionGeometry, DataOrder, ImageGeometry, BlockGeometry,
BlockDataContainer)
from cil.optimisation.algorithms import PDHG, SPDHG
from cil.optimisation.operators import GradientOperator, BlockOperator
from cil.optimisation.operators import SymmetrisedGradientOperator, ZeroOperator, IdentityOperator

from cil.optimisation.functions import MixedL21Norm, L2NormSquared, BlockFunction, ZeroFunction, IndicatorBox, Function
from cil.plugins.astra.operators import ProjectionOperator

Expand Down Expand Up @@ -82,6 +85,67 @@ def set_up_TV_regularisation(

return (K, F, G)

@staticmethod
def set_up_TGV_regularisation(
image_geometry: ImageGeometry, acquisition_data: AcquisitionData,
recon_params: ReconstructionParameters) -> tuple[BlockOperator, BlockFunction, Function]:

# Forward operator
A2d = ProjectionOperator(image_geometry, acquisition_data.geometry, 'gpu')

if recon_params.stochastic:
for partition_geometry, partition_operator in zip(acquisition_data.geometry, A2d, strict=True):
CILRecon.set_approx_norm(partition_operator, partition_geometry, image_geometry)
else:
CILRecon.set_approx_norm(A2d, acquisition_data.geometry, image_geometry)

# Define Gradient Operator and BlockOperator
alpha = recon_params.alpha
gamma = recon_params.gamma
beta = alpha * gamma

f2 = MixedL21Norm()
f3 = MixedL21Norm()

if recon_params.stochastic:
raise ValueError("TGV reconstruction does not yet support stochastic mode")
# now, A2d is a BlockOperator as acquisition_data is a BlockDataContainer
fs = []
for i, _ in enumerate(acquisition_data.geometry):
fs.append(L2NormSquared(b=acquisition_data.get_item(i)))

F = BlockFunction(*fs, f2, f3)

else:
# Define BlockFunction F using the MixedL21Norm() and the L2NormSquared()
# mathematicians like to multiply 1/2 in front of L2NormSquared. This is not necessary
# it will mean that the regularisation parameter alpha is doubled
f1 = L2NormSquared(b=acquisition_data)

F = BlockFunction(f1, f2, f3)

# Define BlockOperator K

# Set up the 3 operator A, Grad and Symmetrised Gradient
K11 = A2d
K21 = alpha * GradientOperator(K11.domain)
K32 = beta * SymmetrisedGradientOperator(K21.range)
# these define the domain and range of the other operators
K12 = ZeroOperator(K32.domain, K11.range)
K22 = -alpha * IdentityOperator(domain_geometry=K21.range, range_geometry=K32.range)
K31 = ZeroOperator(K11.domain, K32.range)

K = BlockOperator(K11, K12, K21, K22, K31, K32, shape=(3, 2))

if recon_params.non_negative:
G = BlockFunction(IndicatorBox(lower=0, upper=None), ZeroFunction())

else:
# Define Function G simply as zero
G = ZeroFunction()

return (K, F, G)

@staticmethod
def set_approx_norm(A2d: BlockOperator, acquisition_data: AcquisitionGeometry,
image_geometry: ImageGeometry) -> None:
Expand Down Expand Up @@ -169,7 +233,12 @@ def single_sino(sino: np.ndarray,

ig = ag.get_ImageGeometry()

K, F, G = CILRecon.set_up_TV_regularisation(ig, data, recon_params)
if recon_params.regulariser == 'TV':
K, F, G = CILRecon.set_up_TV_regularisation(ig, data, recon_params)
elif recon_params.regulariser == 'TGV':
K, F, G = CILRecon.set_up_TGV_regularisation(ig, data, recon_params)
else:
raise ValueError(f"Regulariser must be one of 'TV', 'TGV'. Received '{recon_params.regulariser}'")

max_iteration = 100000
# this should set to a sensible number as evaluating the objective is costly
Expand Down Expand Up @@ -210,6 +279,11 @@ def single_sino(sino: np.ndarray,
progress.mark_complete()
t1 = time.perf_counter()
LOG.info(f"single_sino time: {t1-t0}s for shape {sino.shape}")

if isinstance(algo.solution, BlockDataContainer):
# TGV case
return algo.solution[0].as_array()

return algo.solution.as_array()

@staticmethod
Expand Down Expand Up @@ -288,7 +362,13 @@ def full(images: ImageStack,
num_subsets)

ig = ag.get_ImageGeometry()
K, F, G = CILRecon.set_up_TV_regularisation(ig, data, recon_params)

if recon_params.regulariser == 'TV':
K, F, G = CILRecon.set_up_TV_regularisation(ig, data, recon_params)
elif recon_params.regulariser == 'TGV':
K, F, G = CILRecon.set_up_TGV_regularisation(ig, data, recon_params)
else:
raise ValueError(f"Regulariser must be one of 'TV', 'TGV'. Received '{recon_params.regulariser}'")

max_iteration = 100000
# this should set to a sensible number as evaluating the objective is costly
Expand Down Expand Up @@ -325,7 +405,11 @@ def full(images: ImageStack,
force_continue=False)
algo.next()

volume = algo.solution.as_array()
if isinstance(algo.solution, BlockDataContainer):
# TGV case
volume = algo.solution[0].as_array()
else:
volume = algo.solution.as_array()
LOG.info(f'Reconstructed 3D volume with shape: {volume.shape}')
t1 = time.perf_counter()
LOG.info(f"full reconstruction time: {t1-t0}s for shape {images.data.shape}")
Expand Down
4 changes: 4 additions & 0 deletions mantidimaging/core/utility/data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,14 @@ class ReconstructionParameters:
tilt: Degrees | None = None
pixel_size: float = 0.0
alpha: float = 0.0
gamma: float = 1.0
non_negative: bool = False
max_projection_angle: float = 360.0
beam_hardening_coefs: list[float] | None = None
stochastic: bool = False
projections_per_subset: int = 50
regularisation_percent: int = 30
regulariser: str = ""

def to_dict(self) -> dict:
return {
Expand All @@ -119,9 +121,11 @@ def to_dict(self) -> dict:
'tilt': str(self.tilt),
'pixel_size': self.pixel_size,
'alpha': self.alpha,
'gamma': self.gamma,
'stochastic': self.stochastic,
'projections_per_subset': self.projections_per_subset,
'regularisation_percent': self.regularisation_percent,
'regulariser': self.regulariser,
}


Expand Down
10 changes: 10 additions & 0 deletions mantidimaging/gui/windows/recon/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,10 @@ def pixel_size(self, value: int) -> None:
def alpha(self) -> float:
return self.alphaSpinBox.value()

@property
def gamma(self) -> float:
return 1

@property
def non_negative(self) -> bool:
return self.nonNegativeCheckBox.isChecked()
Expand All @@ -427,6 +431,10 @@ def projections_per_subset(self) -> int:
def regularisation_percent(self) -> int:
return self.regPercentSpinBox.value()

@property
def regulariser(self) -> str:
return "TV"

@property
def beam_hardening_coefs(self) -> list[float] | None:
if not self.lbhc_enabled.isChecked():
Expand All @@ -447,11 +455,13 @@ def recon_params(self) -> ReconstructionParameters:
tilt=Degrees(self.tilt),
pixel_size=self.pixel_size,
alpha=self.alpha,
gamma=self.gamma,
non_negative=self.non_negative,
stochastic=self.stochastic,
projections_per_subset=self.projections_per_subset,
max_projection_angle=self.max_proj_angle,
regularisation_percent=self.regularisation_percent,
regulariser=self.regulariser,
beam_hardening_coefs=self.beam_hardening_coefs)

def set_table_point(self, idx, slice_idx, cor):
Expand Down
Loading