From 50fd5b6baeb83ecf09189ef86dc9b3f73dcbf2ea Mon Sep 17 00:00:00 2001 From: Mike Sullivan Date: Thu, 8 Aug 2024 16:45:39 +0100 Subject: [PATCH] added ability to read fits files via delayed dask functions --- mantidimaging/gui/windows/live_viewer/model.py | 18 +++++++++++++++--- .../gui/windows/live_viewer/presenter.py | 7 ++----- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/mantidimaging/gui/windows/live_viewer/model.py b/mantidimaging/gui/windows/live_viewer/model.py index 7ce41c8cc1e..2f6c09be921 100644 --- a/mantidimaging/gui/windows/live_viewer/model.py +++ b/mantidimaging/gui/windows/live_viewer/model.py @@ -11,6 +11,7 @@ from PyQt5.QtCore import QFileSystemWatcher, QObject, pyqtSignal, QTimer import dask_image.imread +from astropy.io import fits if TYPE_CHECKING: from os import stat_result @@ -25,10 +26,18 @@ class DaskImageDataStack: """ delayed_stack: dask.array.Array | None = None - def __init__(self, image_list: list[Image_Data]): + def __init__(self, image_list: list[Image_Data] | None): if image_list: if image_list[0].create_delayed_array: - self.delayed_stack = dask.array.concatenate([image_data.delayed_array for image_data in image_list]) + if image_list[0].image_path.suffix.lower() in [".tif", ".tiff"]: + arrays = [image_data.delayed_array for image_data in image_list] + self.delayed_stack = dask.array.stack(dask.array.array(arrays)) + elif image_list[0].image_path.suffix.lower() in [".fits"]: + with fits.open(image_list[0].image_path.__str__()) as fit: + sample = fit[0].data + arrays = [image_data.delayed_array for image_data in image_list] + lazy_arrays =[dask.array.from_delayed(x, shape=sample.shape, dtype=sample.dtype) for x in arrays] + self.delayed_stack = dask.array.stack(lazy_arrays) @property def shape(self): @@ -88,7 +97,10 @@ def image_modified_time_stamp(self) -> str: return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.image_modified_time)) def set_delayed_array(self) -> None: - self.delayed_array = dask_image.imread.imread(self.image_path) + if self.image_path.suffix.lower() in [".tif", ".tiff"]: + self.delayed_array = dask_image.imread.imread(self.image_path)[0] + elif self.image_path.suffix.lower() == ".fits": + self.delayed_array = dask.delayed(fits.open)(self.image_path)[0].data class SubDirectory: diff --git a/mantidimaging/gui/windows/live_viewer/presenter.py b/mantidimaging/gui/windows/live_viewer/presenter.py index 09c675a83e1..3e5899539bc 100644 --- a/mantidimaging/gui/windows/live_viewer/presenter.py +++ b/mantidimaging/gui/windows/live_viewer/presenter.py @@ -110,11 +110,8 @@ def load_image(image_data_obj: Image_Data) -> np.ndarray: Load a .Tif, .Tiff or .Fits file only if it exists and returns as an ndarray """ - if image_data_obj.image_path.suffix.lower() in [".tif", ".tiff"]: - image_data = image_data_obj.delayed_array.compute()[0] - elif image_data_obj.image_path.suffix.lower() == ".fits": - with fits.open(image_data_obj.image_path.__str__()) as fit: - image_data = fit[0].data + if image_data_obj.image_path.suffix.lower() in [".tif", ".tiff", ".fits"]: + image_data = image_data_obj.delayed_array.compute() return image_data def update_image_modified(self, image_path: Path) -> None: