Skip to content

Commit

Permalink
Added type annobnations to core
Browse files Browse the repository at this point in the history
Signed-off-by: ashmeigh <[email protected]>
  • Loading branch information
ashmeigh committed Jul 25, 2024
1 parent 45416e7 commit b2857db
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 57 deletions.
28 changes: 14 additions & 14 deletions mantidimaging/core/gpu/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@
}


def _cupy_on_system():
def _cupy_on_system() -> bool:
"""
:return: True if cupy is installed on the system, False otherwise.
"""
return not CUPY_NOT_IMPORTED


def _cupy_installed_correctly():
def _cupy_installed_correctly() -> bool:
"""
:return: True if cupy is able to run on the system, False otherwise.
"""
Expand All @@ -64,14 +64,14 @@ def _cupy_installed_correctly():
return False


def gpu_available():
def gpu_available() -> bool:
"""
:return: True if cupy is installed AND working, False otherwise.
"""
return _cupy_on_system() and _cupy_installed_correctly()


def _load_cuda_kernel(dtype):
def _load_cuda_kernel(dtype) -> str:
"""
Loads the CUDA kernel so that cupy can act as a mediator. Replaces instances of 'float' with 'double' if the dtype
is float64.
Expand All @@ -86,7 +86,7 @@ def _load_cuda_kernel(dtype):
return cuda_kernel


def _free_memory_pool(arrays=None):
def _free_memory_pool(arrays=None) -> None:
"""
Delete any given GPU arrays and instruct the memory pool to free unused blocks.
"""
Expand All @@ -95,7 +95,7 @@ def _free_memory_pool(arrays=None):
mempool.free_all_blocks()


def _create_pinned_memory(cpu_array):
def _create_pinned_memory(cpu_array: np.ndarray) -> np.ndarray:
"""
Use pinned memory in order to store a numpy array on the GPU.
:param cpu_array: The numpy array to be transferred to the GPU.
Expand All @@ -107,7 +107,7 @@ def _create_pinned_memory(cpu_array):
return src


def _send_single_array_to_gpu(cpu_array, stream):
def _send_single_array_to_gpu(cpu_array: np.ndarray, stream: cp.cuda.Stream) -> cp.ndarray:
"""
Sends a single array to the GPU using pinned memory and a stream.
:param cpu_array: The numpy array to be transferred to the GPU.
Expand All @@ -120,7 +120,7 @@ def _send_single_array_to_gpu(cpu_array, stream):
return gpu_array


def _send_arrays_to_gpu_with_pinned_memory(cpu_arrays, streams):
def _send_arrays_to_gpu_with_pinned_memory(cpu_arrays, streams) -> list[cp.ndarray]:
"""
Transfer the arrays to the GPU using pinned memory. Raises an error if the GPU runs out of memory.
:param cpu_arrays: A list of numpy arrays to be transferred to the GPU.
Expand All @@ -145,7 +145,7 @@ def _send_arrays_to_gpu_with_pinned_memory(cpu_arrays, streams):
return []


def _create_block_and_grid_args(data):
def _create_block_and_grid_args(data: cp.ndarray):
"""
Create the block and grid arguments that are passed to the cupy. These determine how the array
is broken up.
Expand All @@ -158,7 +158,7 @@ def _create_block_and_grid_args(data):
return block_size, grid_size


def _create_padded_array(data, filter_size, scipy_mode):
def _create_padded_array(data: np.ndarray, filter_size: int, scipy_mode: str) -> np.ndarray:
"""
Creates the padded array on the CPU for the median filter.
:param data: The data array to be padded.
Expand All @@ -171,7 +171,7 @@ def _create_padded_array(data, filter_size, scipy_mode):
return np.pad(data, pad_width=((pad_size, pad_size), (pad_size, pad_size)), mode=EQUIVALENT_PAD_MODE[scipy_mode])


def _replace_gpu_array_contents(gpu_array, cpu_array, stream):
def _replace_gpu_array_contents(gpu_array: cp.ndarray, cpu_array: np.ndarray, stream: cp.cuda.Stream) -> None:
"""
Overwrites the contents of an existing GPU array with a given CPU array.
:param gpu_array: The GPU array to be overwritten.
Expand All @@ -181,7 +181,7 @@ def _replace_gpu_array_contents(gpu_array, cpu_array, stream):
gpu_array.set(cpu_array, stream)


def _get_padding_value(filter_size):
def _get_padding_value(filter_size: int) -> int:
"""
Determine the padding value by using the filter size.
:param filter_size: The filter size.
Expand All @@ -202,7 +202,7 @@ def __init__(self, dtype):
# Warm up the CUDA functions
self._warm_up(dtype)

def _warm_up(self, dtype):
def _warm_up(self, dtype) -> None:
"""
Runs the median filter on a small test array in order to allow it to compile then deleted the GPU arrays.
:param dtype: The data type of the input array.
Expand All @@ -219,7 +219,7 @@ def _warm_up(self, dtype):
# Clear the test arrays
_free_memory_pool([test_data, test_padding])

def _cuda_single_image_median_filter(self, input_data, padded_data, filter_size, grid_size, block_size):
def _cuda_single_image_median_filter(self, input_data, padded_data, filter_size, grid_size, block_size) -> None:
"""
Run the median filter on a single 2D image using CUDA.
:param input_data: A 2D GPU data array.
Expand Down
4 changes: 2 additions & 2 deletions mantidimaging/core/io/instrument_log_implmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ def read_imat_date(time_stamp: str) -> datetime:
locale.setlocale(locale.LC_TIME, lc)

@staticmethod
def _has_imat_header(line: str):
def _has_imat_header(line: str) -> bool:
HEADERS = [
"TIME STAMP,IMAGE TYPE,IMAGE COUNTER,COUNTS BM3 before image,COUNTS BM3 after image",
"TIME STAMP IMAGE TYPE IMAGE COUNTER COUNTS BM3 before image COUNTS BM3 after image",
]
return line.strip() in HEADERS

@classmethod
def _has_imat_data_line(cls, line: str):
def _has_imat_data_line(cls, line: str) -> bool:
try:
_ = cls.read_imat_date(line[:24])
except ValueError:
Expand Down
21 changes: 12 additions & 9 deletions mantidimaging/core/io/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,20 @@
package_version = CheckVersion().get_version()


def write_fits(data: np.ndarray, filename: str, overwrite: bool = False, description: str | None = ""):
def write_fits(data: np.ndarray, filename: str, overwrite: bool = False, description: str | None = "") -> None:
hdu = fits.PrimaryHDU(data)
hdulist = fits.HDUList([hdu])
hdulist.writeto(filename, overwrite=overwrite)


def write_img(data: np.ndarray, filename: str, overwrite: bool = False, description: str | None = ""):
def write_img(data: np.ndarray, filename: str, overwrite: bool = False, description: str | None = "") -> None:
tifffile.imwrite(filename, data, description=description, metadata=None, software="Mantid Imaging")


def write_nxs(data: np.ndarray, filename: str, projection_angles: np.ndarray | None = None, overwrite: bool = False):
def write_nxs(data: np.ndarray,
filename: str,
projection_angles: np.ndarray | None = None,
overwrite: bool = False) -> None:
import h5py
nxs = h5py.File(filename, 'w')

Expand Down Expand Up @@ -177,7 +180,7 @@ def image_save(images: ImageStack,
return names


def nexus_save(dataset: StrictDataset, path: str, sample_name: str, save_as_float: bool):
def nexus_save(dataset: StrictDataset, path: str, sample_name: str, save_as_float: bool) -> None:
"""
Uses information from a StrictDataset to create a NeXus file.
:param dataset: The dataset to save as a NeXus file.
Expand All @@ -199,7 +202,7 @@ def nexus_save(dataset: StrictDataset, path: str, sample_name: str, save_as_floa
nexus_file.close()


def _nexus_save(nexus_file: h5py.File, dataset: StrictDataset, sample_name: str, save_as_float: bool):
def _nexus_save(nexus_file: h5py.File, dataset: StrictDataset, sample_name: str, save_as_float: bool) -> None:
"""
Takes a NeXus file and writes the StrictDataset information to it.
:param nexus_file: The NeXus file.
Expand Down Expand Up @@ -252,7 +255,7 @@ def _nexus_save(nexus_file: h5py.File, dataset: StrictDataset, sample_name: str,


def _save_processed_data_to_nexus(nexus_file: h5py.File, dataset: StrictDataset, rotation_angle: h5py.Dataset,
image_key: h5py.Dataset, save_as_float: bool):
image_key: h5py.Dataset, save_as_float: bool) -> None:
data = nexus_file.create_group(NEXUS_PROCESSED_DATA_PATH)
data["rotation_angle"] = rotation_angle
data["image_key"] = image_key
Expand All @@ -266,7 +269,7 @@ def _save_processed_data_to_nexus(nexus_file: h5py.File, dataset: StrictDataset,
process.create_dataset("version", data=np.bytes_(package_version))


def _save_image_stacks_to_nexus(dataset: StrictDataset, data_group: h5py.Group, save_as_float: bool):
def _save_image_stacks_to_nexus(dataset: StrictDataset, data_group: h5py.Group, save_as_float: bool) -> None:
combined_data_shape = (sum([len(arr) for arr in dataset.nexus_arrays]), ) + dataset.nexus_arrays[0].shape[1:]

index = 0
Expand Down Expand Up @@ -305,7 +308,7 @@ def scale_row(row):
return converted, factors


def _save_recon_to_nexus(nexus_file: h5py.File, recon: ImageStack, sample_path: str):
def _save_recon_to_nexus(nexus_file: h5py.File, recon: ImageStack, sample_path: str) -> None:
"""
Saves a recon to a NeXus file.
:param nexus_file: The NeXus file.
Expand Down Expand Up @@ -369,7 +372,7 @@ def _create_pixel_size_arrays(recon: ImageStack) -> tuple[np.ndarray, np.ndarray
return x_arr, y_arr, z_arr


def _set_nx_class(group: h5py.Group, class_name: str):
def _set_nx_class(group: h5py.Group, class_name: str) -> None:
"""
Sets the NX_class attribute of data in a NeXus file.
:param group: The h5py group.
Expand Down
4 changes: 2 additions & 2 deletions mantidimaging/core/net/help_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
SECTION_USER_GUIDE = f"{DOCS_BASE}/user_guide/"


def open_user_operation_docs(operation_name: str):
def open_user_operation_docs(operation_name: str) -> None:
page_url = "operations/index"
section = operation_name.lower().replace(" ", "-")
open_help_webpage(SECTION_USER_GUIDE, page_url, section)


def open_help_webpage(section_url: str, page_url: str, section: str | None = None):
def open_help_webpage(section_url: str, page_url: str, section: str | None = None) -> None:
if section is not None:
url = f"{section_url}{page_url}.html#{section}"
else:
Expand Down
2 changes: 1 addition & 1 deletion mantidimaging/core/parallel/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def run_compute_func(func: ComputeFuncType,
num_operations: int,
arrays: list[pu.SharedArray] | pu.SharedArray,
params: dict[str, Any],
progress=None):
progress=None) -> None:
if isinstance(arrays, pu.SharedArray):
arrays = [arrays]
all_data_in_shared_memory, data = _check_shared_mem_and_get_data(arrays)
Expand Down
6 changes: 3 additions & 3 deletions mantidimaging/core/parallel/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
LOG = getLogger(__name__)


def enough_memory(shape, dtype):
def enough_memory(shape, dtype) -> bool:
return full_size_KB(shape=shape, dtype=dtype) < system_free_memory().kb()


Expand Down Expand Up @@ -103,7 +103,7 @@ def multiprocessing_necessary(shape: int, is_shared_data: bool) -> bool:
return True


def execute_impl(img_num: int, partial_func: partial, is_shared_data: bool, progress: Progress, msg: str):
def execute_impl(img_num: int, partial_func: partial, is_shared_data: bool, progress: Progress, msg: str) -> None:
task_name = f"{msg}"
progress = Progress.ensure_instance(progress, num_steps=img_num, task_name=task_name)
indices_list = range(img_num)
Expand All @@ -128,7 +128,7 @@ def run_compute_func_impl(worker_func: Callable[[int], None],
num_operations: int,
is_shared_data: bool,
progress=None,
msg: str = ""):
msg: str = "") -> None:
task_name = f"{msg}"
progress = Progress.ensure_instance(progress, num_steps=num_operations, task_name=task_name)
indices_list = range(num_operations)
Expand Down
4 changes: 2 additions & 2 deletions mantidimaging/core/reconstruct/astra_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
# Full credit for following code to Daniil Kazantzev
# Source:
# https://github.com/dkazanc/ToMoBAR/blob/5990aaa264e2f08bd9b0069c8847e5021fbf2ee2/src/Python/tomobar/supp/astraOP.py#L20-L70
def rotation_matrix2d(theta: float):
def rotation_matrix2d(theta: float) -> np.ndarray:
return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])


def vec_geom_init2d(angles_rad: ProjectionAngles, detector_spacing_x: float, center_rot_offset: float):
def vec_geom_init2d(angles_rad: ProjectionAngles, detector_spacing_x: float, center_rot_offset: float) -> np.ndarray:
angles_value = angles_rad.value
s0 = [0.0, -1.0] # source
u0 = [detector_spacing_x, 0.0] # detector coordinates
Expand Down
2 changes: 1 addition & 1 deletion mantidimaging/core/reconstruct/base_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def find_cor(images: ImageStack, slice_idx: int, start_cor: float, recon_params:
raise NotImplementedError("Base class call")

@staticmethod
def prepare_sinogram(data: np.ndarray, recon_params: ReconstructionParameters):
def prepare_sinogram(data: np.ndarray, recon_params: ReconstructionParameters) -> np.ndarray:
logged_data = BaseRecon.negative_log(data)
if recon_params.beam_hardening_coefs is not None:
coefs = np.array([0.0, 1.0] + recon_params.beam_hardening_coefs)
Expand Down
4 changes: 2 additions & 2 deletions mantidimaging/core/reconstruct/tomopy_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def single_sino(sino: np.ndarray,
cor: ScalarCoR,
proj_angles: ProjectionAngles,
recon_params: ReconstructionParameters,
progress: Progress | None = None):
progress: Progress | None = None) -> np.ndarray:
sino = BaseRecon.prepare_sinogram(sino, recon_params)
volume = tomopy.recon(tomo=[sino],
sinogram_order=True,
Expand All @@ -51,7 +51,7 @@ def single_sino(sino: np.ndarray,
def full(images: ImageStack,
cors: list[ScalarCoR],
recon_params: ReconstructionParameters,
progress: Progress | None = None):
progress: Progress | None = None) -> np.ndarray:
"""
Performs a volume reconstruction using sample data provided as sinograms.
Expand Down
Loading

0 comments on commit b2857db

Please sign in to comment.