Skip to content

Commit

Permalink
Refactor get_spectrum to no longer accept ROI as a string (#2416)
Browse files Browse the repository at this point in the history
  • Loading branch information
samtygier-stfc authored Nov 29, 2024
2 parents e2bbcf4 + beb7bb6 commit d674578
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 66 deletions.
31 changes: 16 additions & 15 deletions mantidimaging/gui/windows/spectrum_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,10 @@ def shuttercount_issue(self) -> str:
return "Need 2 different ShutterCount stacks"
return ""

def get_spectrum(self,
roi: str | SensibleROI,
mode: SpecType,
normalise_with_shuttercount: bool = False) -> np.ndarray:
def get_spectrum(self, roi: SensibleROI, mode: SpecType, normalise_with_shuttercount: bool = False) -> np.ndarray:
if self._stack is None:
return np.array([])

if isinstance(roi, str):
roi = self.get_roi(roi)

if mode == SpecType.SAMPLE:
return self.get_stack_spectrum(self._stack, roi)

Expand Down Expand Up @@ -330,36 +324,43 @@ def has_stack(self) -> bool:
"""
return self._stack is not None

def save_csv(self, path: Path, normalise: bool, normalise_with_shuttercount: bool = False) -> None:
def save_csv(self,
path: Path,
rois: dict[str, SensibleROI],
normalise: bool,
normalise_with_shuttercount: bool = False) -> None:
"""
Iterates over all ROIs and saves the spectrum for each one to a CSV file.
@param path: The path to save the CSV file to.
@param normalized: Whether to save the normalized spectrum.
"""
if self._stack is None:
raise ValueError("No stack selected")
if not rois:
raise ValueError("No ROIs provided")

csv_output = CSVOutput()
csv_output.add_column("ToF_index", np.arange(self._stack.data.shape[0]), "Index")

self.tof_data = self.get_stack_time_of_flight()
if self.tof_data is not None:
self.units.set_data_to_convert(self.tof_data)
csv_output.add_column("Wavelength", self.units.tof_seconds_to_wavelength_in_angstroms(), "Angstrom")
csv_output.add_column("ToF", self.units.tof_seconds_to_us(), "Microseconds")
csv_output.add_column("Energy", self.units.tof_seconds_to_energy(), "MeV")

for roi_name in self.get_list_of_roi_names():
csv_output.add_column(roi_name, self.get_spectrum(roi_name, SpecType.SAMPLE, normalise_with_shuttercount),
for roi_name, roi in rois.items():
csv_output.add_column(roi_name, self.get_spectrum(roi, SpecType.SAMPLE, normalise_with_shuttercount),
"Counts")

if normalise:
if self._normalise_stack is None:
raise RuntimeError("No normalisation stack selected")
csv_output.add_column(roi_name + "_open", self.get_spectrum(roi_name, SpecType.OPEN), "Counts")
csv_output.add_column(roi_name + "_norm",
self.get_spectrum(roi_name, SpecType.SAMPLE_NORMED, normalise_with_shuttercount),
csv_output.add_column(f"{roi_name}_open", self.get_spectrum(roi, SpecType.OPEN), "Counts")
csv_output.add_column(f"{roi_name}_norm",
self.get_spectrum(roi, SpecType.SAMPLE_NORMED, normalise_with_shuttercount),
"Counts")

with path.open("w") as outfile:
csv_output.write(outfile)
self.save_roi_coords(self.get_roi_coords_filename(path))
Expand Down
70 changes: 41 additions & 29 deletions mantidimaging/gui/windows/spectrum_viewer/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,15 @@ def handle_roi_moved(self, force_new_spectrums: bool = False) -> None:
Handle changes to any ROI position and size.
"""
for name in self.model.get_list_of_roi_names():
roi = self.view.spectrum_widget.get_roi(name)
if force_new_spectrums or roi != self.model.get_roi(name):
self.model.set_roi(name, roi)
self.view.set_spectrum(
name,
self.model.get_spectrum(name,
self.spectrum_mode,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled()))
view_roi = self.view.spectrum_widget.get_roi(name)
if force_new_spectrums or view_roi != self.model.get_roi(name):
self.model.set_roi(name, view_roi)
spectrum = self.model.get_spectrum(
view_roi,
self.spectrum_mode,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled(),
)
self.view.set_spectrum(name, spectrum)

def handle_roi_clicked(self, roi: SpectrumROI) -> None:
if not roi.name == ROI_RITS:
Expand All @@ -220,9 +221,10 @@ def redraw_spectrum(self, name: str) -> None:
"""
Redraw the spectrum with the given name
"""
roi = self.model.get_roi(name)
self.view.set_spectrum(
name,
self.model.get_spectrum(name,
self.model.get_spectrum(roi,
self.spectrum_mode,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled()))

Expand All @@ -231,12 +233,13 @@ def redraw_all_rois(self) -> None:
Redraw all ROIs and spectrum plots
"""
for name in self.model.get_list_of_roi_names():
if name == "all" or self.view.spectrum_widget.roi_dict[name].isVisible() is False:
if name == "all" or not self.view.spectrum_widget.roi_dict[name].isVisible():
continue
self.model.set_roi(name, self.view.spectrum_widget.get_roi(name))
roi = self.view.spectrum_widget.get_roi(name)
self.model.set_roi(name, roi)
self.view.set_spectrum(
name,
self.model.get_spectrum(name,
self.model.get_spectrum(roi,
self.spectrum_mode,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled()))

Expand All @@ -257,13 +260,17 @@ def handle_button_enabled(self) -> None:

def handle_export_csv(self) -> None:
path = self.view.get_csv_filename()
if path is None:
if not path:
return
path = path.with_suffix(".csv") if path.suffix != ".csv" else path
rois = {roi.name: roi.as_sensible_roi() for roi in self.view.spectrum_widget.roi_dict.values()}

if path.suffix != ".csv":
path = path.with_suffix(".csv")

self.model.save_csv(path, self.spectrum_mode == SpecType.SAMPLE_NORMED, self.view.shuttercount_norm_enabled())
self.model.save_csv(
path,
rois,
normalise=self.spectrum_mode == SpecType.SAMPLE_NORMED,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled(),
)

def handle_rits_export(self) -> None:
"""
Expand Down Expand Up @@ -331,10 +338,12 @@ def do_add_roi(self) -> None:
roi_name = self.model.roi_name_generator()
if roi_name in self.view.spectrum_widget.roi_dict:
raise ValueError(f"ROI name already exists: {roi_name}")

self.model.set_new_roi(roi_name)
self.view.spectrum_widget.add_roi(self.model.get_roi(roi_name), roi_name)
self.view.set_spectrum(
roi_name, self.model.get_spectrum(roi_name, self.spectrum_mode, self.view.shuttercount_norm_enabled()))
roi = self.model.get_roi(roi_name)
self.view.spectrum_widget.add_roi(roi, roi_name)
spectrum = self.model.get_spectrum(roi, self.spectrum_mode, self.view.shuttercount_norm_enabled())
self.view.set_spectrum(roi_name, spectrum)
self.view.auto_range_image()
self.do_add_roi_to_table(roi_name)

Expand All @@ -351,11 +360,11 @@ def change_roi_colour(self, roi_name: str, new_colour: tuple[int, int, int]) ->
self.view.on_visibility_change()

def add_rits_roi(self) -> None:
roi_name = ROI_RITS
self.model.set_new_roi(roi_name)
self.view.spectrum_widget.add_roi(self.model.get_roi(roi_name), roi_name)
self.view.set_spectrum(
roi_name, self.model.get_spectrum(roi_name, self.spectrum_mode, self.view.shuttercount_norm_enabled()))
self.model.set_new_roi(ROI_RITS)
roi = self.model.get_roi(ROI_RITS)
self.view.spectrum_widget.add_roi(roi, ROI_RITS)
self.view.set_spectrum(ROI_RITS,
self.model.get_spectrum(roi, self.spectrum_mode, self.view.shuttercount_norm_enabled()))
self.view.set_roi_visibility_flags(ROI_RITS, visible=False)

def do_add_roi_to_table(self, roi_name: str) -> None:
Expand Down Expand Up @@ -386,13 +395,16 @@ def do_remove_roi(self, roi_name: str | None = None) -> None:
"""
if roi_name is None:
self.view.clear_all_rois()
for roi in self.get_roi_names():
self.view.spectrum_widget.remove_roi(roi)
for name in self.get_roi_names():
self.view.spectrum_widget.remove_roi(name)
self.model.remove_all_roi()
else:
roi = self.model.get_roi(roi_name)
self.view.spectrum_widget.remove_roi(roi_name)
self.view.set_spectrum(
roi_name, self.model.get_spectrum(roi_name, self.spectrum_mode, self.view.shuttercount_norm_enabled()))
spectrum = self.model.get_spectrum(roi,
self.spectrum_mode,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled())
self.view.set_spectrum(roi_name, spectrum)
self.model.remove_roi(roi_name)

def handle_export_tab_change(self, index: int) -> None:
Expand Down
12 changes: 12 additions & 0 deletions mantidimaging/gui/windows/spectrum_viewer/spectrum_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,18 @@ def adjust_spec_roi(self, roi: SensibleROI) -> None:
def rename_roi(self, new_name: str) -> None:
self._name = new_name

def as_sensible_roi(self) -> SensibleROI:
"""
Converts the SpectrumROI to a SensibleROI object.
"""
pos = self.pos()
size = self.size()
left, top = pos
width, height = size
right = left + width
bottom = top + height
return SensibleROI.from_list([left, top, right, bottom])


class SpectrumWidget(QWidget):
"""
Expand Down
51 changes: 32 additions & 19 deletions mantidimaging/gui/windows/spectrum_viewer/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,20 @@ def test_get_averaged_image_range(self):

def test_get_spectrum(self):
stack, spectrum = self._set_sample_stack()

model_spec = self.model.get_spectrum("roi", SpecType.SAMPLE)
roi = SensibleROI(left=0, top=0, right=12, bottom=11)
model_spec = self.model.get_spectrum(roi, SpecType.SAMPLE)
self.assertEqual(model_spec.shape, (10, ))
npt.assert_array_equal(model_spec, spectrum)

def test_get_normalised_spectrum(self):
stack, spectrum = self._set_sample_stack()

normalise_stack = ImageStack(np.ones([10, 11, 12]) * 2)
self.model.set_normalise_stack(normalise_stack)

model_open_spec = self.model.get_spectrum("roi", SpecType.OPEN)
roi = SensibleROI(left=0, top=0, right=12, bottom=11)
model_open_spec = self.model.get_spectrum(roi, SpecType.OPEN)
self.assertEqual(model_open_spec.shape, (10, ))
self.assertTrue(np.all(model_open_spec == 2))

model_norm_spec = self.model.get_spectrum("roi", SpecType.SAMPLE_NORMED)
model_norm_spec = self.model.get_spectrum(roi, SpecType.SAMPLE_NORMED)
self.assertEqual(model_norm_spec.shape, (10, ))
npt.assert_array_equal(model_norm_spec, spectrum / 2)

Expand All @@ -122,8 +120,8 @@ def test_get_normalised_spectrum_zeros(self):
normalise_stack = ImageStack(np.ones([10, 11, 12]) * 2)
normalise_stack.data[5] = 0
self.model.set_normalise_stack(normalise_stack)

model_norm_spec = self.model.get_spectrum("roi", SpecType.SAMPLE_NORMED)
roi = SensibleROI(left=0, top=0, right=12, bottom=11)
model_norm_spec = self.model.get_spectrum(roi, SpecType.SAMPLE_NORMED)
expected_spec = spectrum / 2
expected_spec[5] = 0
self.assertEqual(model_norm_spec.shape, (10, ))
Expand All @@ -134,8 +132,8 @@ def test_get_normalised_spectrum_different_size(self):

normalise_stack = ImageStack(np.ones([10, 11, 13]))
self.model.set_normalise_stack(normalise_stack)

error_spectrum = self.model.get_spectrum("all", SpecType.SAMPLE_NORMED)
roi = SensibleROI(left=0, top=0, right=13, bottom=11)
error_spectrum = self.model.get_spectrum(roi, SpecType.SAMPLE_NORMED)
np.testing.assert_array_equal(error_spectrum, np.array([]))

def test_normalise_issue(self):
Expand Down Expand Up @@ -168,12 +166,12 @@ def test_get_spectrum_roi(self):
stack, spectrum = self._set_sample_stack()
stack.data[:, :, 6:] *= 2

self.model.set_roi('roi', SensibleROI.from_list([0, 0, 3, 3]))
model_spec = self.model.get_spectrum("roi", SpecType.SAMPLE)
roi = SensibleROI.from_list([0, 0, 3, 3])
model_spec = self.model.get_spectrum(roi, SpecType.SAMPLE)
npt.assert_array_equal(model_spec, spectrum)

self.model.set_roi('roi', SensibleROI.from_list([6, 0, 6 + 3, 3]))
model_spec = self.model.get_spectrum("roi", SpecType.SAMPLE)
roi = SensibleROI.from_list([6, 0, 6 + 3, 3])
model_spec = self.model.get_spectrum(roi, SpecType.SAMPLE)
npt.assert_array_equal(model_spec, spectrum * 2)

def test_get_stack_spectrum(self):
Expand All @@ -191,9 +189,13 @@ def test_save_csv(self):
stack.data *= 2
self.model.set_normalise_stack(None)

roi_all = SensibleROI.from_list([0, 0, 12, 11])
roi_specific = SensibleROI.from_list([0, 0, 3, 3])
rois = {"all": roi_all, "roi": roi_specific}

mock_stream, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_csv(mock_path, False)
self.model.save_csv(mock_path, rois=rois, normalise=False)

mock_path.open.assert_called_once_with("w")
self.assertIn("# ToF_index,all,roi", mock_stream.captured[0])
Expand Down Expand Up @@ -320,18 +322,26 @@ def test_save_csv_norm_missing_stack(self):
stack, _ = self._set_sample_stack()
stack.data *= 2
self.model.set_normalise_stack(None)

roi_all = SensibleROI.from_list([0, 0, 12, 11])
rois = {"all": roi_all}

with self.assertRaises(RuntimeError):
self.model.save_csv(mock.Mock(), True)
self.model.save_csv(mock.Mock(), rois=rois, normalise=True)

def test_save_csv_norm(self):
self._set_sample_stack()

open_stack = ImageStack(np.ones([10, 11, 12]) * 2)
self.model.set_normalise_stack(open_stack)

roi_all = SensibleROI.from_list([0, 0, 12, 11])
roi_specific = SensibleROI.from_list([0, 0, 3, 3])
rois = {"all": roi_all, "roi": roi_specific}

mock_stream, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_csv(mock_path, True)
self.model.save_csv(path=mock_path, rois=rois, normalise=True, normalise_with_shuttercount=False)

mock_path.open.assert_called_once_with("w")
self.assertIn("# ToF_index,all,all_open,all_norm,roi,roi_open,roi_norm", mock_stream.captured[0])
Expand All @@ -346,9 +356,12 @@ def test_save_csv_norm_with_tof_loaded(self):
stack.data[:, :, :5] *= 2
self.model.set_normalise_stack(norm)

roi_all = SensibleROI.from_list([0, 0, 12, 11])
rois = {"all": roi_all, "roi": roi_all}

mock_stream, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_csv(mock_path, True)
self.model.save_csv(mock_path, rois=rois, normalise=True, normalise_with_shuttercount=False)

mock_path.open.assert_called_once_with("w")
self.assertIn("# ToF_index,Wavelength,ToF,Energy,all,all_open,all_norm,roi,roi_open,roi_norm",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,13 @@ def test_handle_export_csv(self, path_name: str, mock_save_csv: mock.Mock, mock_
self.view.get_csv_filename = mock.Mock(return_value=Path(path_name))
self.view.shuttercount_norm_enabled.return_value = False
mock_shuttercount_issue.return_value = "Something wrong"

self.presenter.model.set_stack(generate_images())

self.presenter.handle_export_csv()

self.view.get_csv_filename.assert_called_once()
mock_save_csv.assert_called_once_with(Path("/fake/path.csv"), False, False)
mock_save_csv.assert_called_once_with(Path("/fake/path.csv"), {},
normalise=False,
normalise_with_shuttercount=False)

@parameterized.expand(["/fake/path", "/fake/path.dat"])
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_rits_roi")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def test_WHEN_colour_is_not_valid_THEN_roi_colour_is_unchanged(self):
self.spectrum_roi.onChangeColor()
self.assertEqual(self.spectrum_roi.colour, (0, 0, 0, 255))

def test_WHEN_as_sensible_roi_called_THEN_correct_sensible_roi_returned(self):
sensible_roi = self.spectrum_roi.as_sensible_roi()
self.assertEqual((sensible_roi.left, sensible_roi.top, sensible_roi.right, sensible_roi.bottom),
(10, 20, 30, 40))


@mock_versions
@start_qapplication
Expand Down

0 comments on commit d674578

Please sign in to comment.