Skip to content

Commit

Permalink
Rits transmission error model (#1987)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackEAllen authored Dec 13, 2023
2 parents c335583 + 24848d0 commit 90c2b6a
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 87 deletions.
2 changes: 1 addition & 1 deletion mantidimaging/core/io/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,6 @@ def export_to_dat_rits_format(rits_formatted_data: str, path: Path) -> None:
:return: None
"""
with open(path, 'w', encoding='utf-8') as f:
with path.open('w') as f:
f.write(rits_formatted_data)
LOG.info('RITS formatted data saved to: {}'.format(path))
62 changes: 52 additions & 10 deletions mantidimaging/gui/windows/spectrum_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class SpecType(Enum):
SAMPLE_NORMED = 3


class ErrorMode(Enum):
STANDARD_DEVIATION = 1
PROPAGATED = 2


class SpectrumViewerWindowModel:
"""
The model for the spectrum viewer window.
Expand Down Expand Up @@ -124,6 +129,12 @@ def get_stack_spectrum(stack: ImageStack, roi: SensibleROI):
roi_data = stack.data[:, top:bottom, left:right]
return roi_data.mean(axis=(1, 2))

@staticmethod
def get_stack_spectrum_summed(stack: ImageStack, roi: SensibleROI):
left, top, right, bottom = roi
roi_data = stack.data[:, top:bottom, left:right]
return roi_data.sum(axis=(1, 2))

def normalise_issue(self) -> str:
if self._stack is None or self._normalise_stack is None:
return "Need 2 selected stacks"
Expand Down Expand Up @@ -153,6 +164,34 @@ def get_spectrum(self, roi_name: str, mode: SpecType) -> 'np.ndarray':
roi_norm_spectrum = self.get_stack_spectrum(self._normalise_stack, roi)
return np.divide(roi_spectrum, roi_norm_spectrum, out=np.zeros_like(roi_spectrum), where=roi_norm_spectrum != 0)

def get_transmission_error_standard_dev(self, roi_name: str) -> np.ndarray:
"""
Get the transmission error standard deviation for a given roi
@param: roi_name The roi name
@return: a numpy array representing the standard deviation of the transmission
"""
if self._stack is None or self._normalise_stack is None:
raise RuntimeError("Sample and open beam must be selected")
left, top, right, bottom = self.get_roi(roi_name)
sample = self._stack.data[:, top:bottom, left:right]
open_beam = self._normalise_stack.data[:, top:bottom, left:right]
safe_divide = np.divide(sample, open_beam, out=np.zeros_like(sample), where=open_beam != 0)
return np.std(safe_divide, axis=(1, 2))

def get_transmission_error_propagated(self, roi_name: str) -> np.ndarray:
"""
Get the transmission error using propagation of sqrt(n) error for a given roi
@param: roi_name The roi name
@return: a numpy array representing the error of the transmission
"""
if self._stack is None or self._normalise_stack is None:
raise RuntimeError("Sample and open beam must be selected")
roi = self.get_roi(roi_name)
sample = self.get_stack_spectrum_summed(self._stack, roi)
open_beam = self.get_stack_spectrum_summed(self._normalise_stack, roi)
error = np.sqrt(sample / open_beam**2 + sample**2 / open_beam**3)
return error

def get_image_shape(self) -> tuple[int, int]:
if self._stack is not None:
return self._stack.data.shape[1:]
Expand Down Expand Up @@ -191,31 +230,34 @@ def save_csv(self, path: Path, normalized: bool) -> None:
csv_output.write(outfile)
self.save_roi_coords(self.get_roi_coords_filename(path))

def save_rits(self, path: Path, normalized: bool) -> None:
def save_rits(self, path: Path, normalized: bool, error_mode: ErrorMode) -> None:
"""
Saves the spectrum for one ROI to a RITS file.
@param path: The path to save the CSV file to.
@param normalized: Whether to save the normalized spectrum.
@param error_mode: Which version (standard deviation or propagated) of the error to use in the RITS export
"""
if self._stack is None:
raise ValueError("No stack selected")

if not normalized or self._normalise_stack is None:
raise ValueError("Normalisation must be enabled, and a normalise stack must be selected")
tof = self.get_stack_time_of_flight()
if tof is None:
raise ValueError("No Time of Flights for sample. Make sure spectra log has been loaded")

# RITS expects ToF in μs
tof *= 1e6
tof *= 1e6 # RITS expects ToF in μs
transmission = self.get_spectrum(ROI_RITS, SpecType.SAMPLE_NORMED)

transmission_error = np.full_like(tof, 0.1)
if normalized:
if self._normalise_stack is None:
raise RuntimeError("No normalisation stack selected")
transmission = self.get_spectrum(ROI_RITS, SpecType.SAMPLE_NORMED)
self.export_spectrum_to_rits(path, tof, transmission, transmission_error)
if error_mode == ErrorMode.STANDARD_DEVIATION:
transmission_error = self.get_transmission_error_standard_dev(ROI_RITS)
elif error_mode == ErrorMode.PROPAGATED:
transmission_error = self.get_transmission_error_propagated(ROI_RITS)
else:
LOG.error("Data is not normalised to open beam. This will not export to a valid RITS format")
raise ValueError("Invalid error_mode given")

self.export_spectrum_to_rits(path, tof, transmission, transmission_error)

def get_stack_time_of_flight(self) -> np.array | None:
if self._stack is None or self._stack.log_file is None:
Expand Down
4 changes: 2 additions & 2 deletions mantidimaging/gui/windows/spectrum_viewer/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from logging import getLogger
from mantidimaging.core.data.dataset import StrictDataset
from mantidimaging.gui.mvp_base import BasePresenter
from mantidimaging.gui.windows.spectrum_viewer.model import SpectrumViewerWindowModel, SpecType, ROI_RITS
from mantidimaging.gui.windows.spectrum_viewer.model import SpectrumViewerWindowModel, SpecType, ROI_RITS, ErrorMode

if TYPE_CHECKING:
from mantidimaging.gui.windows.spectrum_viewer.view import SpectrumViewerWindowView # pragma: no cover
Expand Down Expand Up @@ -177,7 +177,7 @@ def handle_rits_export(self) -> None:
return
if path.suffix != ".dat":
path = path.with_suffix(".dat")
self.model.save_rits(path, self.spectrum_mode == SpecType.SAMPLE_NORMED)
self.model.save_rits(path, self.spectrum_mode == SpecType.SAMPLE_NORMED, ErrorMode.STANDARD_DEVIATION)

def handle_enable_normalised(self, enabled: bool) -> None:
if enabled:
Expand Down
Loading

0 comments on commit 90c2b6a

Please sign in to comment.