From beb7bb633544f92073023b5bf071e788d0596f21 Mon Sep 17 00:00:00 2001 From: ashmeigh Date: Mon, 25 Nov 2024 15:49:27 +0000 Subject: [PATCH] Refactor get_spectrum to no longer accept ROI as a string --- .../gui/windows/spectrum_viewer/model.py | 31 ++++---- .../gui/windows/spectrum_viewer/presenter.py | 70 +++++++++++-------- .../spectrum_viewer/spectrum_widget.py | 12 ++++ .../spectrum_viewer/test/model_test.py | 51 +++++++++----- .../spectrum_viewer/test/presenter_test.py | 6 +- .../spectrum_viewer/test/spectrum_test.py | 5 ++ 6 files changed, 109 insertions(+), 66 deletions(-) diff --git a/mantidimaging/gui/windows/spectrum_viewer/model.py b/mantidimaging/gui/windows/spectrum_viewer/model.py index 2f6b897b255..952b26df072 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/model.py +++ b/mantidimaging/gui/windows/spectrum_viewer/model.py @@ -218,16 +218,10 @@ def shuttercount_issue(self) -> str: return "Need 2 different ShutterCount stacks" return "" - def get_spectrum(self, - roi: str | SensibleROI, - mode: SpecType, - normalise_with_shuttercount: bool = False) -> np.ndarray: + def get_spectrum(self, roi: SensibleROI, mode: SpecType, normalise_with_shuttercount: bool = False) -> np.ndarray: if self._stack is None: return np.array([]) - if isinstance(roi, str): - roi = self.get_roi(roi) - if mode == SpecType.SAMPLE: return self.get_stack_spectrum(self._stack, roi) @@ -330,18 +324,25 @@ def has_stack(self) -> bool: """ return self._stack is not None - def save_csv(self, path: Path, normalise: bool, normalise_with_shuttercount: bool = False) -> None: + def save_csv(self, + path: Path, + rois: dict[str, SensibleROI], + normalise: bool, + normalise_with_shuttercount: bool = False) -> None: """ Iterates over all ROIs and saves the spectrum for each one to a CSV file. - @param path: The path to save the CSV file to. @param normalized: Whether to save the normalized spectrum. + """ if self._stack is None: raise ValueError("No stack selected") + if not rois: + raise ValueError("No ROIs provided") csv_output = CSVOutput() csv_output.add_column("ToF_index", np.arange(self._stack.data.shape[0]), "Index") + self.tof_data = self.get_stack_time_of_flight() if self.tof_data is not None: self.units.set_data_to_convert(self.tof_data) @@ -349,17 +350,17 @@ def save_csv(self, path: Path, normalise: bool, normalise_with_shuttercount: boo csv_output.add_column("ToF", self.units.tof_seconds_to_us(), "Microseconds") csv_output.add_column("Energy", self.units.tof_seconds_to_energy(), "MeV") - for roi_name in self.get_list_of_roi_names(): - csv_output.add_column(roi_name, self.get_spectrum(roi_name, SpecType.SAMPLE, normalise_with_shuttercount), + for roi_name, roi in rois.items(): + csv_output.add_column(roi_name, self.get_spectrum(roi, SpecType.SAMPLE, normalise_with_shuttercount), "Counts") + if normalise: if self._normalise_stack is None: raise RuntimeError("No normalisation stack selected") - csv_output.add_column(roi_name + "_open", self.get_spectrum(roi_name, SpecType.OPEN), "Counts") - csv_output.add_column(roi_name + "_norm", - self.get_spectrum(roi_name, SpecType.SAMPLE_NORMED, normalise_with_shuttercount), + csv_output.add_column(f"{roi_name}_open", self.get_spectrum(roi, SpecType.OPEN), "Counts") + csv_output.add_column(f"{roi_name}_norm", + self.get_spectrum(roi, SpecType.SAMPLE_NORMED, normalise_with_shuttercount), "Counts") - with path.open("w") as outfile: csv_output.write(outfile) self.save_roi_coords(self.get_roi_coords_filename(path)) diff --git a/mantidimaging/gui/windows/spectrum_viewer/presenter.py b/mantidimaging/gui/windows/spectrum_viewer/presenter.py index cdee740def6..99ef0c83027 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/presenter.py +++ b/mantidimaging/gui/windows/spectrum_viewer/presenter.py @@ -201,14 +201,15 @@ def handle_roi_moved(self, force_new_spectrums: bool = False) -> None: Handle changes to any ROI position and size. """ for name in self.model.get_list_of_roi_names(): - roi = self.view.spectrum_widget.get_roi(name) - if force_new_spectrums or roi != self.model.get_roi(name): - self.model.set_roi(name, roi) - self.view.set_spectrum( - name, - self.model.get_spectrum(name, - self.spectrum_mode, - normalise_with_shuttercount=self.view.shuttercount_norm_enabled())) + view_roi = self.view.spectrum_widget.get_roi(name) + if force_new_spectrums or view_roi != self.model.get_roi(name): + self.model.set_roi(name, view_roi) + spectrum = self.model.get_spectrum( + view_roi, + self.spectrum_mode, + normalise_with_shuttercount=self.view.shuttercount_norm_enabled(), + ) + self.view.set_spectrum(name, spectrum) def handle_roi_clicked(self, roi: SpectrumROI) -> None: if not roi.name == ROI_RITS: @@ -220,9 +221,10 @@ def redraw_spectrum(self, name: str) -> None: """ Redraw the spectrum with the given name """ + roi = self.model.get_roi(name) self.view.set_spectrum( name, - self.model.get_spectrum(name, + self.model.get_spectrum(roi, self.spectrum_mode, normalise_with_shuttercount=self.view.shuttercount_norm_enabled())) @@ -231,12 +233,13 @@ def redraw_all_rois(self) -> None: Redraw all ROIs and spectrum plots """ for name in self.model.get_list_of_roi_names(): - if name == "all" or self.view.spectrum_widget.roi_dict[name].isVisible() is False: + if name == "all" or not self.view.spectrum_widget.roi_dict[name].isVisible(): continue - self.model.set_roi(name, self.view.spectrum_widget.get_roi(name)) + roi = self.view.spectrum_widget.get_roi(name) + self.model.set_roi(name, roi) self.view.set_spectrum( name, - self.model.get_spectrum(name, + self.model.get_spectrum(roi, self.spectrum_mode, normalise_with_shuttercount=self.view.shuttercount_norm_enabled())) @@ -257,13 +260,17 @@ def handle_button_enabled(self) -> None: def handle_export_csv(self) -> None: path = self.view.get_csv_filename() - if path is None: + if not path: return + path = path.with_suffix(".csv") if path.suffix != ".csv" else path + rois = {roi.name: roi.as_sensible_roi() for roi in self.view.spectrum_widget.roi_dict.values()} - if path.suffix != ".csv": - path = path.with_suffix(".csv") - - self.model.save_csv(path, self.spectrum_mode == SpecType.SAMPLE_NORMED, self.view.shuttercount_norm_enabled()) + self.model.save_csv( + path, + rois, + normalise=self.spectrum_mode == SpecType.SAMPLE_NORMED, + normalise_with_shuttercount=self.view.shuttercount_norm_enabled(), + ) def handle_rits_export(self) -> None: """ @@ -331,10 +338,12 @@ def do_add_roi(self) -> None: roi_name = self.model.roi_name_generator() if roi_name in self.view.spectrum_widget.roi_dict: raise ValueError(f"ROI name already exists: {roi_name}") + self.model.set_new_roi(roi_name) - self.view.spectrum_widget.add_roi(self.model.get_roi(roi_name), roi_name) - self.view.set_spectrum( - roi_name, self.model.get_spectrum(roi_name, self.spectrum_mode, self.view.shuttercount_norm_enabled())) + roi = self.model.get_roi(roi_name) + self.view.spectrum_widget.add_roi(roi, roi_name) + spectrum = self.model.get_spectrum(roi, self.spectrum_mode, self.view.shuttercount_norm_enabled()) + self.view.set_spectrum(roi_name, spectrum) self.view.auto_range_image() self.do_add_roi_to_table(roi_name) @@ -351,11 +360,11 @@ def change_roi_colour(self, roi_name: str, new_colour: tuple[int, int, int]) -> self.view.on_visibility_change() def add_rits_roi(self) -> None: - roi_name = ROI_RITS - self.model.set_new_roi(roi_name) - self.view.spectrum_widget.add_roi(self.model.get_roi(roi_name), roi_name) - self.view.set_spectrum( - roi_name, self.model.get_spectrum(roi_name, self.spectrum_mode, self.view.shuttercount_norm_enabled())) + self.model.set_new_roi(ROI_RITS) + roi = self.model.get_roi(ROI_RITS) + self.view.spectrum_widget.add_roi(roi, ROI_RITS) + self.view.set_spectrum(ROI_RITS, + self.model.get_spectrum(roi, self.spectrum_mode, self.view.shuttercount_norm_enabled())) self.view.set_roi_visibility_flags(ROI_RITS, visible=False) def do_add_roi_to_table(self, roi_name: str) -> None: @@ -386,13 +395,16 @@ def do_remove_roi(self, roi_name: str | None = None) -> None: """ if roi_name is None: self.view.clear_all_rois() - for roi in self.get_roi_names(): - self.view.spectrum_widget.remove_roi(roi) + for name in self.get_roi_names(): + self.view.spectrum_widget.remove_roi(name) self.model.remove_all_roi() else: + roi = self.model.get_roi(roi_name) self.view.spectrum_widget.remove_roi(roi_name) - self.view.set_spectrum( - roi_name, self.model.get_spectrum(roi_name, self.spectrum_mode, self.view.shuttercount_norm_enabled())) + spectrum = self.model.get_spectrum(roi, + self.spectrum_mode, + normalise_with_shuttercount=self.view.shuttercount_norm_enabled()) + self.view.set_spectrum(roi_name, spectrum) self.model.remove_roi(roi_name) def handle_export_tab_change(self, index: int) -> None: diff --git a/mantidimaging/gui/windows/spectrum_viewer/spectrum_widget.py b/mantidimaging/gui/windows/spectrum_viewer/spectrum_widget.py index 227bfac29bd..72a30206038 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/spectrum_widget.py +++ b/mantidimaging/gui/windows/spectrum_viewer/spectrum_widget.py @@ -110,6 +110,18 @@ def adjust_spec_roi(self, roi: SensibleROI) -> None: def rename_roi(self, new_name: str) -> None: self._name = new_name + def as_sensible_roi(self) -> SensibleROI: + """ + Converts the SpectrumROI to a SensibleROI object. + """ + pos = self.pos() + size = self.size() + left, top = pos + width, height = size + right = left + width + bottom = top + height + return SensibleROI.from_list([left, top, right, bottom]) + class SpectrumWidget(QWidget): """ diff --git a/mantidimaging/gui/windows/spectrum_viewer/test/model_test.py b/mantidimaging/gui/windows/spectrum_viewer/test/model_test.py index e98f050a498..75ec40cdfc6 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/test/model_test.py +++ b/mantidimaging/gui/windows/spectrum_viewer/test/model_test.py @@ -97,22 +97,20 @@ def test_get_averaged_image_range(self): def test_get_spectrum(self): stack, spectrum = self._set_sample_stack() - - model_spec = self.model.get_spectrum("roi", SpecType.SAMPLE) + roi = SensibleROI(left=0, top=0, right=12, bottom=11) + model_spec = self.model.get_spectrum(roi, SpecType.SAMPLE) self.assertEqual(model_spec.shape, (10, )) npt.assert_array_equal(model_spec, spectrum) def test_get_normalised_spectrum(self): stack, spectrum = self._set_sample_stack() - normalise_stack = ImageStack(np.ones([10, 11, 12]) * 2) self.model.set_normalise_stack(normalise_stack) - - model_open_spec = self.model.get_spectrum("roi", SpecType.OPEN) + roi = SensibleROI(left=0, top=0, right=12, bottom=11) + model_open_spec = self.model.get_spectrum(roi, SpecType.OPEN) self.assertEqual(model_open_spec.shape, (10, )) self.assertTrue(np.all(model_open_spec == 2)) - - model_norm_spec = self.model.get_spectrum("roi", SpecType.SAMPLE_NORMED) + model_norm_spec = self.model.get_spectrum(roi, SpecType.SAMPLE_NORMED) self.assertEqual(model_norm_spec.shape, (10, )) npt.assert_array_equal(model_norm_spec, spectrum / 2) @@ -122,8 +120,8 @@ def test_get_normalised_spectrum_zeros(self): normalise_stack = ImageStack(np.ones([10, 11, 12]) * 2) normalise_stack.data[5] = 0 self.model.set_normalise_stack(normalise_stack) - - model_norm_spec = self.model.get_spectrum("roi", SpecType.SAMPLE_NORMED) + roi = SensibleROI(left=0, top=0, right=12, bottom=11) + model_norm_spec = self.model.get_spectrum(roi, SpecType.SAMPLE_NORMED) expected_spec = spectrum / 2 expected_spec[5] = 0 self.assertEqual(model_norm_spec.shape, (10, )) @@ -134,8 +132,8 @@ def test_get_normalised_spectrum_different_size(self): normalise_stack = ImageStack(np.ones([10, 11, 13])) self.model.set_normalise_stack(normalise_stack) - - error_spectrum = self.model.get_spectrum("all", SpecType.SAMPLE_NORMED) + roi = SensibleROI(left=0, top=0, right=13, bottom=11) + error_spectrum = self.model.get_spectrum(roi, SpecType.SAMPLE_NORMED) np.testing.assert_array_equal(error_spectrum, np.array([])) def test_normalise_issue(self): @@ -168,12 +166,12 @@ def test_get_spectrum_roi(self): stack, spectrum = self._set_sample_stack() stack.data[:, :, 6:] *= 2 - self.model.set_roi('roi', SensibleROI.from_list([0, 0, 3, 3])) - model_spec = self.model.get_spectrum("roi", SpecType.SAMPLE) + roi = SensibleROI.from_list([0, 0, 3, 3]) + model_spec = self.model.get_spectrum(roi, SpecType.SAMPLE) npt.assert_array_equal(model_spec, spectrum) - self.model.set_roi('roi', SensibleROI.from_list([6, 0, 6 + 3, 3])) - model_spec = self.model.get_spectrum("roi", SpecType.SAMPLE) + roi = SensibleROI.from_list([6, 0, 6 + 3, 3]) + model_spec = self.model.get_spectrum(roi, SpecType.SAMPLE) npt.assert_array_equal(model_spec, spectrum * 2) def test_get_stack_spectrum(self): @@ -191,9 +189,13 @@ def test_save_csv(self): stack.data *= 2 self.model.set_normalise_stack(None) + roi_all = SensibleROI.from_list([0, 0, 12, 11]) + roi_specific = SensibleROI.from_list([0, 0, 3, 3]) + rois = {"all": roi_all, "roi": roi_specific} + mock_stream, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): - self.model.save_csv(mock_path, False) + self.model.save_csv(mock_path, rois=rois, normalise=False) mock_path.open.assert_called_once_with("w") self.assertIn("# ToF_index,all,roi", mock_stream.captured[0]) @@ -320,8 +322,12 @@ def test_save_csv_norm_missing_stack(self): stack, _ = self._set_sample_stack() stack.data *= 2 self.model.set_normalise_stack(None) + + roi_all = SensibleROI.from_list([0, 0, 12, 11]) + rois = {"all": roi_all} + with self.assertRaises(RuntimeError): - self.model.save_csv(mock.Mock(), True) + self.model.save_csv(mock.Mock(), rois=rois, normalise=True) def test_save_csv_norm(self): self._set_sample_stack() @@ -329,9 +335,13 @@ def test_save_csv_norm(self): open_stack = ImageStack(np.ones([10, 11, 12]) * 2) self.model.set_normalise_stack(open_stack) + roi_all = SensibleROI.from_list([0, 0, 12, 11]) + roi_specific = SensibleROI.from_list([0, 0, 3, 3]) + rois = {"all": roi_all, "roi": roi_specific} + mock_stream, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): - self.model.save_csv(mock_path, True) + self.model.save_csv(path=mock_path, rois=rois, normalise=True, normalise_with_shuttercount=False) mock_path.open.assert_called_once_with("w") self.assertIn("# ToF_index,all,all_open,all_norm,roi,roi_open,roi_norm", mock_stream.captured[0]) @@ -346,9 +356,12 @@ def test_save_csv_norm_with_tof_loaded(self): stack.data[:, :, :5] *= 2 self.model.set_normalise_stack(norm) + roi_all = SensibleROI.from_list([0, 0, 12, 11]) + rois = {"all": roi_all, "roi": roi_all} + mock_stream, mock_path = self._make_mock_path_stream() with mock.patch.object(self.model, "save_roi_coords"): - self.model.save_csv(mock_path, True) + self.model.save_csv(mock_path, rois=rois, normalise=True, normalise_with_shuttercount=False) mock_path.open.assert_called_once_with("w") self.assertIn("# ToF_index,Wavelength,ToF,Energy,all,all_open,all_norm,roi,roi_open,roi_norm", diff --git a/mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py b/mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py index d6cd7383119..a1257406bdb 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py +++ b/mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py @@ -205,13 +205,13 @@ def test_handle_export_csv(self, path_name: str, mock_save_csv: mock.Mock, mock_ self.view.get_csv_filename = mock.Mock(return_value=Path(path_name)) self.view.shuttercount_norm_enabled.return_value = False mock_shuttercount_issue.return_value = "Something wrong" - self.presenter.model.set_stack(generate_images()) - self.presenter.handle_export_csv() self.view.get_csv_filename.assert_called_once() - mock_save_csv.assert_called_once_with(Path("/fake/path.csv"), False, False) + mock_save_csv.assert_called_once_with(Path("/fake/path.csv"), {}, + normalise=False, + normalise_with_shuttercount=False) @parameterized.expand(["/fake/path", "/fake/path.dat"]) @mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_rits_roi") diff --git a/mantidimaging/gui/windows/spectrum_viewer/test/spectrum_test.py b/mantidimaging/gui/windows/spectrum_viewer/test/spectrum_test.py index 65b3add07a9..fee63dbf2a7 100644 --- a/mantidimaging/gui/windows/spectrum_viewer/test/spectrum_test.py +++ b/mantidimaging/gui/windows/spectrum_viewer/test/spectrum_test.py @@ -42,6 +42,11 @@ def test_WHEN_colour_is_not_valid_THEN_roi_colour_is_unchanged(self): self.spectrum_roi.onChangeColor() self.assertEqual(self.spectrum_roi.colour, (0, 0, 0, 255)) + def test_WHEN_as_sensible_roi_called_THEN_correct_sensible_roi_returned(self): + sensible_roi = self.spectrum_roi.as_sensible_roi() + self.assertEqual((sensible_roi.left, sensible_roi.top, sensible_roi.right, sensible_roi.bottom), + (10, 20, 30, 40)) + @mock_versions @start_qapplication