Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize Operation Execution with compute_function Approach #2074

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mantidimaging/core/operations/arithmetic/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ class ArithmeticFilter(BaseFilter):
"""
filter_name = "Arithmetic"

@classmethod
def filter_func(cls,
images: ImageStack,
@staticmethod
def filter_func(images: ImageStack,
div_val: float = 1.0,
mult_val: float = 1.0,
add_val: float = 0.0,
Expand All @@ -47,7 +46,8 @@ def filter_func(cls,
raise ValueError("Unable to proceed with operation because division/multiplication value is zero.")

params = {'div': div_val, 'mult': mult_val, 'add': add_val, 'sub': sub_val}
ps.run_compute_func(cls.compute_function, images.data.shape[0], images.shared_array, params, progress)
ps.run_compute_func(ArithmeticFilter.compute_function, images.data.shape[0], images.shared_array, params,
progress)

return images

Expand Down
175 changes: 48 additions & 127 deletions mantidimaging/core/operations/flat_fielding/flat_fielding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@

import numpy as np

from mantidimaging import helper as h
from mantidimaging.core.operations.base_filter import BaseFilter, FilterGroup
from mantidimaging.core.parallel import utility as pu, shared as ps
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.core.parallel import shared as ps
from mantidimaging.gui.utility.qt_helpers import Type
from mantidimaging.gui.widgets.dataset_selector import DatasetSelectorWidgetView

Expand Down Expand Up @@ -48,21 +46,6 @@ def enable_correct_fields_only(selected_flat_fielding_widget, flat_before_widget


class FlatFieldFilter(BaseFilter):
"""Uses the flat (open beam) and dark images to normalise a stack of images (radiograms, projections),
and to correct for a beam profile, scintillator imperfections and/or detector inhomogeneities. This
operation produces images of transmission values.

In practice, several open beam and dark images are averaged in the flat-fielding process.

Intended to be used on: Projections

When: As one of the first pre-processing steps

Caution: Make sure the correct stacks are selected for flat and dark.

Caution: Check that the flat and dark images don't have any very bright pixels,
or this will introduce additional noise in the sample. Remove outliers before flat-fielding.
"""
filter_name = 'Flat-fielding'

@staticmethod
Expand All @@ -74,57 +57,56 @@ def filter_func(images: ImageStack,
selected_flat_fielding: str | None = None,
use_dark: bool = True,
progress=None) -> ImageStack:
"""Do background correction with flat and dark images.

:param images: Sample data which is to be processed. Expected in radiograms
:param flat_before: Flat (open beam) image to use in normalization, collected before the sample was imaged
:param flat_after: Flat (open beam) image to use in normalization, collected after the sample was imaged
:param dark_before: Dark image to use in normalization, collected before the sample was imaged
:param dark_after: Dark image to use in normalization, collected after the sample was imaged
:param selected_flat_fielding: Select which of the flat fielding methods to use, just Before stacks, just After
stacks or combined.
:param use_dark: Whether to use dark frame subtraction
:return: Filtered data (stack of images)
"""
h.check_data_stack(images)

if selected_flat_fielding == "Both, concatenated" and flat_after is not None and flat_before is not None \
and dark_after is not None and dark_before is not None:
flat_avg = (flat_before.data.mean(axis=0) + flat_after.data.mean(axis=0)) / 2.0
if use_dark:
dark_avg = (dark_before.data.mean(axis=0) + dark_after.data.mean(axis=0)) / 2.0
elif selected_flat_fielding == "Only Before" and flat_before is not None and dark_before is not None:
flat_avg = flat_before.data.mean(axis=0)
if use_dark:
dark_avg = dark_before.data.mean(axis=0)
elif selected_flat_fielding == "Only After" and flat_after is not None and dark_after is not None:
flat_avg = flat_after.data.mean(axis=0)
if use_dark:
dark_avg = dark_after.data.mean(axis=0)
else:
raise ValueError("selected_flat_fielding not in:", valid_methods)

if not use_dark:
dark_avg = np.zeros_like(flat_avg)
if images.num_projections < 2:
return images

if flat_avg is not None and dark_avg is not None:
if 2 != flat_avg.ndim or 2 != dark_avg.ndim:
raise ValueError(
f"Incorrect shape of the flat image ({flat_avg.shape}) or dark image ({dark_avg.shape}) \
which should match the shape of the sample images ({images.data.shape})")

if not images.data.shape[1:] == flat_avg.shape == dark_avg.shape:
raise ValueError(f"Not all images are the expected shape: {images.data.shape[1:]}, instead "
f"flat had shape: {flat_avg.shape}, and dark had shape: {dark_avg.shape}")
params = {
"flat_before": flat_before,
"flat_after": flat_after,
"dark_before": dark_before,
"dark_after": dark_after,
"selected_flat_fielding": selected_flat_fielding,
"use_dark": use_dark
}

progress = Progress.ensure_instance(progress,
num_steps=images.data.shape[0],
task_name='Background Correction')
_execute(images, flat_avg, dark_avg, progress)
ps.run_compute_func(FlatFieldFilter.compute_function, images.num_sinograms, images.shared_array, params,
progress)

h.check_data_stack(images)
return images

@staticmethod
def compute_function(index: int, array: np.ndarray, params: Dict[str, Any]):
flat_before = params["flat_before"]
flat_after = params["flat_after"]
dark_before = params["dark_before"]
dark_after = params["dark_after"]
selected_flat_fielding = params["selected_flat_fielding"]
use_dark = params["use_dark"]

if selected_flat_fielding == "Only Before":
flat_image = flat_before.data[index] if flat_before is not None else None
dark_image = dark_before.data[index] if dark_before is not None else None
elif selected_flat_fielding == "Only After":
flat_image = flat_after.data[index] if flat_after is not None else None
dark_image = dark_after.data[index] if dark_after is not None else None
elif selected_flat_fielding == "Both, concatenated":
flat_image = (flat_before.data[index] + flat_after.data[index]) / 2.0 \
if flat_before is not None and flat_after is not None else None
dark_image = (dark_before.data[index] + dark_after.data[index]) / 2.0 \
if dark_before is not None and dark_after is not None else None
else:
raise ValueError("Unknown flat fielding method")

if flat_image is not None and dark_image is not None:
corrected_flat = flat_image - dark_image
corrected_flat[corrected_flat <= 0] = MINIMUM_PIXEL_VALUE
corrected_data = np.divide(array, corrected_flat, out=np.zeros_like(array), where=(corrected_flat != 0))
if use_dark:
corrected_data -= dark_image
np.clip(corrected_data, MINIMUM_PIXEL_VALUE, MAXIMUM_PIXEL_VALUE, out=corrected_data)
array[:] = corrected_data

@staticmethod
def register_gui(form, on_change, view) -> Dict[str, Any]:
from mantidimaging.gui.utility import add_property_to_form
Expand Down Expand Up @@ -242,76 +224,15 @@ def execute_wrapper( # type: ignore

@staticmethod
def validate_execute_kwargs(kwargs):
# Validate something is in both path text inputs
if 'selected_flat_fielding_widget' not in kwargs:
return False

if 'flat_before_widget' not in kwargs and 'dark_before_widget' not in kwargs or\
'flat_after_widget' not in kwargs and 'dark_after_widget' not in kwargs:
if ('flat_before_widget' not in kwargs or 'dark_before_widget' not in kwargs) \
or ('flat_after_widget' not in kwargs or 'dark_after_widget' not in kwargs):
return False
assert isinstance(kwargs["flat_before_widget"], DatasetSelectorWidgetView)
assert isinstance(kwargs["flat_after_widget"], DatasetSelectorWidgetView)
assert isinstance(kwargs["dark_before_widget"], DatasetSelectorWidgetView)
assert isinstance(kwargs["dark_after_widget"], DatasetSelectorWidgetView)

return True

@staticmethod
def group_name() -> FilterGroup:
return FilterGroup.Basic


def _divide(data, norm_divide):
np.true_divide(data, norm_divide, out=data)


def _subtract(data, dark=None):
# specify out to do in place, otherwise the data is copied
np.subtract(data, dark, out=data)


def _norm_divide(flat: np.ndarray, dark: np.ndarray) -> np.ndarray:
# subtract dark from flat
return np.subtract(flat, dark)


def _execute(images: ImageStack, flat=None, dark=None, progress=None):
"""A benchmark justifying the current implementation, performed on
500x2048x2048 images.

#1 Separate runs
Subtract (sequential with np.subtract(data, dark, out=data)) - 13s
Divide (par) - 1.15s

#2 Separate parallel runs
Subtract (par) - 5.5s
Divide (par) - 1.15s

#3 Added subtract into _divide so that it is:
np.true_divide(
np.subtract(data, dark, out=data), norm_divide, out=data)
Subtract then divide (par) - 55s
"""
with progress:
progress.update(msg="Applying background correction")

if images.uses_shared_memory:
shared_dark = pu.copy_into_shared_memory(dark)
norm_divide = pu.copy_into_shared_memory(_norm_divide(flat, dark))
else:
shared_dark = pu.SharedArray(dark, None)
norm_divide = pu.SharedArray(_norm_divide(flat, dark), None)

# prevent divide-by-zero issues, and negative pixels make no sense
norm_divide.array[norm_divide.array == 0] = MINIMUM_PIXEL_VALUE

# subtract the dark from all images
do_subtract = ps.create_partial(_subtract, fwd_function=ps.inplace_second_2d)
arrays = [images.shared_array, shared_dark]
ps.execute(do_subtract, arrays, images.data.shape[0], progress)

# divide the data by (flat - dark)
do_divide = ps.create_partial(_divide, fwd_function=ps.inplace_second_2d)
arrays = [images.shared_array, norm_divide]
ps.execute(do_divide, arrays, images.data.shape[0], progress)

return images
59 changes: 22 additions & 37 deletions mantidimaging/core/operations/median_filter/median_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
from __future__ import annotations

from functools import partial
from logging import getLogger
from typing import Callable, Dict, Any, TYPE_CHECKING, Tuple

import numpy as np
import scipy.ndimage as scipy_ndimage
from PyQt5.QtGui import QValidator
from PyQt5.QtWidgets import QSpinBox, QLabel, QSizePolicy

import scipy.ndimage as scipy_ndimage

from mantidimaging import helper as h
from mantidimaging.core.gpu import utility as gpu
from mantidimaging.core.operations.base_filter import BaseFilter
from mantidimaging.core.parallel import shared as ps
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.gui.utility import add_property_to_form
from mantidimaging.gui.utility.qt_helpers import Type, on_change_and_disable

Expand Down Expand Up @@ -83,19 +82,32 @@ def filter_func(data: ImageStack, size=None, mode="reflect", progress=None, forc
:return: Returns the processed data

"""
# Validation
h.check_data_stack(data)

if not size or not size > 1:
if size is None or size <= 1:
raise ValueError(f'Size parameter must be greater than 1, but value provided was {size}')

if not force_cpu:
_execute_gpu(data.data, size, mode, progress)
else:
_execute(data, size, mode, progress)
params = {'mode': mode, 'force_cpu': force_cpu}
if force_cpu:
params['size'] = size # Pass size only if using CPU
ps.run_compute_func(MedianFilter.compute_function, data.data.shape[0], data.shared_array, params)

h.check_data_stack(data)
return data

@staticmethod
def compute_function(i: int, array: np.ndarray, params: Dict[str, Any]):
mode = params['mode']
force_cpu = params['force_cpu']
size = params.get('size')
progress = params.get('progress')

if not force_cpu:
cuda = gpu.CudaExecuter(array.dtype)
cuda.median_filter(i, array, mode=mode, progress=progress) # Call without size if it is None
else:
array[i] = _median_filter(array[i], size=size if size is not None else 3, mode=mode)

@staticmethod
def register_gui(form: 'QFormLayout', on_change: Callable, view) -> Dict[str, Any]:

Expand Down Expand Up @@ -133,36 +145,9 @@ def modes():
return ['reflect', 'constant', 'nearest', 'mirror', 'wrap']


def _median_filter(data: np.ndarray, size: int, mode: str):
# Replaces NaNs with negative infinity before median filter
# so they do not effect neighbouring pixels
def _median_filter(data: np.ndarray, size: int, mode: str) -> np.ndarray:
nans = np.isnan(data)
data = np.where(nans, -np.inf, data)
data = scipy_ndimage.median_filter(data, size=size, mode=mode)
# Put the original NaNs back
data = np.where(nans, np.nan, data)
return data


def _execute(images: ImageStack, size, mode, progress=None):
log = getLogger(__name__)
progress = Progress.ensure_instance(progress, task_name='Median filter')

# create the partial function to forward the parameters
f = ps.create_partial(_median_filter, ps.return_to_self, size=size, mode=mode)

with progress:
log.info(f"PARALLEL median filter, with pixel data type: {images.dtype}, filter size/width: {size}.")

ps.execute(f, [images.shared_array], images.data.shape[0], progress, msg="Median filter")


def _execute_gpu(data, size, mode, progress=None):
log = getLogger(__name__)
progress = Progress.ensure_instance(progress, num_steps=data.shape[0], task_name="Median filter GPU")
cuda = gpu.CudaExecuter(data.dtype)

with progress:
log.info(f"GPU median filter, with pixel data type: {data.dtype}, filter size/width: {size}.")

cuda.median_filter(data, size, mode, progress)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from mantidimaging.core.operations.base_filter import BaseFilter
from mantidimaging.core.parallel import shared as ps
from mantidimaging.core.parallel import utility as pu

if TYPE_CHECKING:
from mantidimaging.core.data import ImageStack
Expand Down Expand Up @@ -38,21 +37,24 @@ def filter_func(images: ImageStack, progress=None) -> ImageStack:
:return: The ImageStack object which has been normalised.
"""
if images.num_projections == 1:
# we can't really compute the preview as the image stack copy
# passed in doesn't have the logfile in it
raise RuntimeError("No logfile available for this stack.")

counts = images.counts()

if counts is None:
raise RuntimeError("No loaded log values for this stack.")

counts_val = pu.copy_into_shared_memory(counts.value / counts.value[0])
do_division = ps.create_partial(_divide_by_counts, fwd_function=ps.inplace2)
arrays = [images.shared_array, counts_val]
ps.execute(do_division, arrays, images.num_projections, progress)
normalization_factor = counts.value / counts.value[0]
params = {'normalization_factor': normalization_factor}
ps.run_compute_func(MonitorNormalisation.compute_function, images.data.shape[0], images.shared_array, params,
progress)

return images

@staticmethod
def compute_function(i: int, array: np.ndarray, params: Dict[str, np.ndarray]):
normalization_factor = params['normalization_factor']
array[i] /= normalization_factor

@staticmethod
def register_gui(form: 'QFormLayout', on_change: Callable, view: 'BaseMainWindowView') -> Dict[str, 'QWidget']:
return {}
Expand Down
Loading
Loading