diff --git a/deepcave/custom_queue.py b/deepcave/custom_queue.py index 1a1b9903..332fd302 100644 --- a/deepcave/custom_queue.py +++ b/deepcave/custom_queue.py @@ -10,7 +10,7 @@ - Queue: This class defines all components for a job queue. """ -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional import redis from rq import Queue as _Queue @@ -256,7 +256,7 @@ def get_finished_jobs(self) -> List[Job]: """Get the finished jobs in the registry.""" return self.get_jobs(registry="finished") - def delete_job(self, job_id: str = None) -> None: + def delete_job(self, job_id: Optional[str] = None) -> None: """ Delete a job from the queue. If no job_id is given, delete all jobs. @@ -266,7 +266,7 @@ def delete_job(self, job_id: str = None) -> None: Id of the job, which should be removed. By default None. """ - def remove_jobs(registry: BaseRegistry, job_id: str = None) -> None: + def remove_jobs(registry: BaseRegistry, job_id: Optional[str] = None) -> None: """ Remove a job from the registry. If no job_id is given, remove all. @@ -280,17 +280,18 @@ def remove_jobs(registry: BaseRegistry, job_id: str = None) -> None: """ if job_id is not None: try: - registry.remove(job_id, delete_job=True) + registry.remove(self.get_job(job_id), delete_job=True) except Exception: pass else: # Remove all for job_id in registry.get_job_ids(): try: - registry.remove(job_id, delete_job=True) + registry.remove(self.get_job(job_id), delete_job=True) except Exception: - registry.remove(job_id) + registry.remove(self.get_job(job_id)) + # Issue opened remove_jobs(self._queue, job_id) remove_jobs(self._queue.finished_job_registry, job_id) remove_jobs(self._queue.canceled_job_registry, job_id) @@ -362,7 +363,7 @@ def enqueue( result_ttl=-1, # Make sure it's not automatically deleted. ) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """If function is not found, make sure we access self._queue directly.""" try: return self.__getattribute__(name) diff --git a/deepcave/evaluators/epm/test.py b/deepcave/evaluators/epm/test.py index 30e6e899..cf7d5a00 100644 --- a/deepcave/evaluators/epm/test.py +++ b/deepcave/evaluators/epm/test.py @@ -1,5 +1,6 @@ # noqa -"""from typing import Any, Dict, List +"""from typing import Any, Dict, List, Optional + import numpy as np from ConfigSpace import ConfigurationSpace @@ -104,7 +105,14 @@ styled_plot = styled_plot.StyledPlot() plt = styled_plot.plt -print(type(plt.render()))""" +print(type(plt.render())) list = [1, 2, 3, 4, 5, 6, 7, 8, 9] list_comp = [1.5 for i in list] + + +def test(hallo: Optional[int] = 2): + print(type(hallo)) + +test(5) +test()""" diff --git a/deepcave/runs/__init__.py b/deepcave/runs/__init__.py index 5fd0df5f..c3d32f47 100644 --- a/deepcave/runs/__init__.py +++ b/deepcave/runs/__init__.py @@ -36,6 +36,7 @@ from deepcave.utils.logs import get_logger + class AbstractRun(ABC): """ Can create, handle and get information of an abstract run. @@ -96,8 +97,8 @@ def reset(self) -> None: """ self.meta: Dict[str, Any] = {} self.configspace: ConfigSpace.ConfigurationSpace - self.configs: Dict[int, Configuration] = {} - self.origins: Dict[int, str] = {} + self.configs: Dict[int, Union[Configuration, Dict[Any, Any]]] = {} + self.origins: Dict[int, Optional[str]] = {} self.models: Dict[int, Optional[Union[str, "torch.nn.Module"]]] = {} # noqa: F821 self.history: List[Trial] = [] @@ -246,7 +247,7 @@ def empty(self) -> bool: """ return len(self.history) == 0 - def get_origin(self, config_id: int) -> str: + def get_origin(self, config_id: int) -> Optional[str]: """ Get the origin, given a config ID. @@ -372,7 +373,7 @@ def get_objective_names(self) -> List[str]: """ return [obj.name for obj in self.get_objectives()] - def get_configs(self, budget: Union[int, float] = None) -> Dict[int, Configuration]: + def get_configs(self, budget: Optional[Union[int, float]] = None) -> Dict[int, Configuration]: """ Get configurations of the run. @@ -449,7 +450,7 @@ def get_config_id(self, config: Union[Configuration, Dict]) -> Optional[int]: return None - def get_num_configs(self, budget: Union[int, float] = None) -> int: + def get_num_configs(self, budget: Optional[Union[int, float]] = None) -> int: """ Count the number of configurations stored in this object with a specific budget. @@ -1196,7 +1197,7 @@ def check_equality( If the budgets of the runs are not equal. If the objective of the runs are not equal. """ - result = {} + result: Dict[str, Any] = {} if len(runs) == 0: return result @@ -1259,6 +1260,7 @@ def check_equality( for o1_, o2_ in zip(o1, o2): o1_.merge(o2_) + assert o1 is not None serialized_objectives = [o.to_json() for o in o1] result["objectives"] = serialized_objectives if meta: diff --git a/deepcave/runs/converters/deepcave.py b/deepcave/runs/converters/deepcave.py index fa5ec15b..e60b4d19 100644 --- a/deepcave/runs/converters/deepcave.py +++ b/deepcave/runs/converters/deepcave.py @@ -55,6 +55,6 @@ def latest_change(self) -> Union[float, int]: return Path(self.path / "history.jsonl").stat().st_mtime @classmethod - def from_path(cls, path): + def from_path(cls, path: Path) -> "DeepCAVERun": """Get a DeepCAVE run from a given path.""" return DeepCAVERun(path.stem, path=Path(path)) diff --git a/deepcave/runs/group.py b/deepcave/runs/group.py index 4b870cfe..6ab4cdd1 100644 --- a/deepcave/runs/group.py +++ b/deepcave/runs/group.py @@ -9,7 +9,7 @@ - Group: Can group and manage a group of abstract runs. """ -from typing import Dict, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from copy import deepcopy @@ -133,7 +133,7 @@ def __init__(self, name: str, runs: List[AbstractRun]): except Exception as e: raise NotMergeableError(f"Runs can not be merged: {e}") - def __iter__(self): + def __iter__(self: "Group") -> Iterator[str]: """Allow to iterate over the object.""" for run in self.runs: yield run.name @@ -156,9 +156,9 @@ def id(self) -> str: return string_to_hash(f"{self.prefix}:{self.name}") @property - def latest_change(self) -> int: + def latest_change(self) -> float: """Get the latest change made to the grouped runs.""" - latest_change = 0 + latest_change = 0.0 for run in self.runs: if run.latest_change > latest_change: latest_change = run.latest_change @@ -183,7 +183,7 @@ def get_new_config_id(self, run_id: int, original_config_id: int) -> int: """Get a new identificator for a configuration.""" return self._new_config_mapping[(run_id, original_config_id)] - def get_original_config_id(self, config_id: int) -> id: + def get_original_config_id(self, config_id: int) -> int: """Get the original identificator of a configuration.""" return self._original_config_mapping[config_id][1] @@ -192,11 +192,12 @@ def get_original_run(self, config_id: int) -> AbstractRun: run_id = self._original_config_mapping[config_id][0] return self.runs[run_id] - def get_model(self, config_id: int): + def get_model(self, config_id: int) -> Optional[Any]: """Get the model of the runs.""" run_id, config_id = self._original_config_mapping[config_id] return self.runs[run_id].get_model(config_id) + # wait until meeting def get_trajectory(self, *args, **kwargs): """ Get the trajectory of the group. @@ -251,10 +252,10 @@ def get_trajectory(self, *args, **kwargs): all_costs.append(y) # Make numpy arrays - all_costs = np.array(all_costs) + all_costs_array = np.array(all_costs) times = all_times - costs_mean = np.mean(all_costs, axis=1) - costs_std = np.std(all_costs, axis=1) + costs_mean = np.mean(all_costs_array, axis=1) + costs_std = np.std(all_costs_array, axis=1) return times, list(costs_mean), list(costs_std), [], [] diff --git a/deepcave/runs/handler.py b/deepcave/runs/handler.py index 1c357534..c4eba8cc 100644 --- a/deepcave/runs/handler.py +++ b/deepcave/runs/handler.py @@ -90,8 +90,17 @@ def get_working_directory(self) -> Path: ------- Path Path of the working directory. + + Raises + ------ + AssertionError + If the working directory is not a string or a path like, an error is thrown. """ - return Path(self.c.get("working_dir")) + working_dir = self.c.get("working_dir") + assert isinstance( + working_dir, (str, Path) + ), "Working directory of cache must be a string or a Path like." + return Path(working_dir) def get_available_run_paths(self) -> Dict[str, str]: """ @@ -135,8 +144,17 @@ def get_selected_run_paths(self) -> List[str]: ------- List[str] Run paths as a list. + + Raises + ------ + AssertionError. + If the selected run paths are not a list, an error is thrown. """ - return self.c.get("selected_run_paths") + selected_run_paths = self.c.get("selected_run_paths") + assert isinstance( + selected_run_paths, list + ), "The selected run paths of the cache must be a list." + return selected_run_paths def get_selected_run_names(self) -> List[str]: """ @@ -166,8 +184,19 @@ def get_run_name(self, run_path: Union[Path, str]) -> str: return Path(run_path).stem def get_selected_groups(self) -> Dict[str, List[str]]: - """Get the selected groups.""" - return self.c.get("groups") + """ + Get the selected groups. + + Raises + ------ + AssertionError + If groups in cache is not a dict, an error is thrown. + """ + selected_groups = self.c.get("groups") + assert isinstance( + selected_groups, dict + ), "The groups aquired from the cache must be a dictionary." + return selected_groups def add_run(self, run_path: str) -> bool: """ @@ -201,9 +230,17 @@ def remove_run(self, run_path: str) -> None: ---------- run_path : str Path of a run. + + Raises + ------ + TypeError + If `selected_run_paths` or `groups` is None, an error is thrown. """ selected_run_paths = self.c.get("selected_run_paths") + if selected_run_paths is None: + raise TypeError("Selected run paths can not be None.") + if run_path in selected_run_paths: selected_run_paths.remove(run_path) self.c.set("selected_run_paths", value=selected_run_paths) @@ -211,7 +248,10 @@ def remove_run(self, run_path: str) -> None: # We have to check the groups here because the removed run_path may # still be included groups = {} - for group_name, run_paths in self.c.get("groups").items(): + group_it = self.c.get("groups") + if group_it is None: + raise TypeError("Groups can not be None.") + for group_name, run_paths in group_it.items(): if run_path in run_paths: run_paths.remove(run_path) groups[group_name] = run_paths @@ -365,7 +405,8 @@ def update_groups(self, groups: Optional[Dict[str, List[str]]] = None) -> None: ------ NotMergeableError If runs can not be merged, an error is thrown. - + TypeError + If `groups` is None, an error is thrown. """ instantiated_groups = {} if groups is None: diff --git a/deepcave/runs/recorder.py b/deepcave/runs/recorder.py index 4b8bc1db..672434b1 100644 --- a/deepcave/runs/recorder.py +++ b/deepcave/runs/recorder.py @@ -8,7 +8,7 @@ - Recorder: Define a Recorder for recording trial information. """ -from typing import Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import time from pathlib import Path @@ -16,9 +16,11 @@ import ConfigSpace import numpy as np from ConfigSpace import Configuration +from typing_extensions import Self from deepcave.runs import Status from deepcave.runs.converters.deepcave import DeepCAVERun +from deepcave.runs.objective import Objective class Recorder: @@ -48,11 +50,11 @@ class Recorder: def __init__( self, configspace: ConfigSpace.ConfigurationSpace, - objectives=None, - meta=None, - save_path="logs", - prefix="run", - overwrite=False, + objectives: Optional[List[Objective]] = None, + meta: Optional[Dict[str, Any]] = None, + save_path: str = "logs", + prefix: str = "run", + overwrite: bool = False, ): """ All objectives follow the scheme the lower the better. @@ -72,29 +74,41 @@ def __init__( if meta is None: meta = {} - self.path: Path = None + self.path: Path self._set_path(save_path, prefix, overwrite) # Set variables - self.last_trial_id = None + self.last_trial_id: Optional[ + Tuple[Union[Dict[Any, Any], Configuration], Optional[float]] + ] = None self.start_time = time.time() - self.start_times = {} - self.models = {} - self.origins = {} - self.additionals = {} + self.start_times: Dict[ + Tuple[Union[Dict[Any, Any], Configuration], Optional[float]], float + ] = {} + self.models: Dict[ + Tuple[Union[Dict[Any, Any], Configuration], Optional[float]], Optional[Union[str, Any]] + ] = {} + self.origins: Dict[ + Tuple[Union[Dict[Any, Any], Configuration], Optional[float]], Optional[str] + ] = {} + self.additionals: Dict[ + Tuple[Union[Dict[Any, Any], Configuration], Optional[float]], Dict[Any, Any] + ] = {} # Define trials container self.run = DeepCAVERun( self.path.stem, configspace=configspace, objectives=objectives, meta=meta ) - def __enter__(self): # noqa: D102, D105 + def __enter__(self) -> Self: return self - def __exit__(self, type, value, traceback): # noqa: D102, D105 + def __exit__(self, type, value, traceback) -> None: # type: ignore pass - def _set_path(self, path: Union[str, Path], prefix="run", overwrite=False) -> None: + def _set_path( + self, path: Union[str, Path], prefix: str = "run", overwrite: bool = False + ) -> None: """ Identify the latest run and sets the path with increased id. @@ -120,9 +134,9 @@ def _set_path(self, path: Union[str, Path], prefix="run", overwrite=False) -> No continue idx = file.name.split("_")[-1] if idx.isnumeric(): - idx = int(idx) - if idx > new_idx: - new_idx = idx + idx_int = int(idx) + if idx_int > new_idx: + new_idx = idx_int # And increase the id new_idx += 1 @@ -134,11 +148,11 @@ def start( self, config: Union[dict, Configuration], budget: Optional[float] = None, - model=None, - origin=None, + model: Optional[Union[str, Any]] = None, + origin: Optional[str] = None, additional: Optional[dict] = None, start_time: Optional[float] = None, - ): + ) -> None: """ Record the trial information. @@ -165,7 +179,7 @@ def start( if additional is None: additional = {} - id = (config, budget) + id: Tuple[Union[Dict[Any, Any], Configuration], Optional[float]] = (config, budget) if start_time is None: start_time = time.time() - self.start_time @@ -182,11 +196,11 @@ def end( self, costs: float = np.inf, status: Status = Status.SUCCESS, - config: Union[dict, Configuration] = None, - budget: float = np.inf, + config: Optional[Union[dict, Configuration]] = None, + budget: Optional[float] = np.inf, additional: Optional[dict] = None, end_time: Optional[float] = None, - ): + ) -> None: """ End the recording of the trial and add it to trial history. @@ -214,6 +228,11 @@ def end( end_time : Optional[float], optional The end time of the trial. Default is None. + + Raises + ------ + AssertionError + If no trial was started yet. """ if additional is None: additional = {} @@ -221,6 +240,7 @@ def end( if config is not None: id = (config, budget) else: + assert self.last_trial_id is not None, "No trial started yet." id = self.last_trial_id config, budget = id[0], id[1] @@ -232,6 +252,8 @@ def end( if end_time is None: end_time = time.time() - self.start_time + assert budget is not None + # Add to trial history self.run.add( costs=costs, diff --git a/deepcave/runs/run.py b/deepcave/runs/run.py index fcbb0bc1..c6b4abae 100644 --- a/deepcave/runs/run.py +++ b/deepcave/runs/run.py @@ -69,10 +69,10 @@ class Run(AbstractRun, ABC): def __init__( self, name: str, - configspace: ConfigSpace = None, - objectives: Union[Objective, List[Objective]] = None, - meta: Dict[str, Any] = None, - path: Optional[Union[str, Path]] = None, + configspace: Optional[ConfigSpace.ConfigurationSpace] = None, + objectives: Optional[Union[Objective, List[Objective]]] = None, + meta: Optional[Dict[str, Any]] = None, + path: Optional[Path] = None, ) -> None: super(Run, self).__init__(name) @@ -83,7 +83,8 @@ def __init__( # Reset and load configspace/path self.reset() - self.configspace = configspace + if configspace is not None: + self.configspace = configspace self.path = path if self.path is not None: self.load() @@ -213,6 +214,8 @@ def add( ------ RuntimeError If number of costs does not match number of objectives. + ValueError + If config id is None. """ if additional is None: additional = {} @@ -252,11 +255,14 @@ def add( config = config.get_dictionary() if config not in self.configs.values(): - config_id = len(self.configs) - self.configs[config_id] = config - self.origins[config_id] = origin + config_id_len = len(self.configs) + self.configs[config_id_len] = config + self.origins[config_id_len] = origin config_id = self.get_config_id(config) + if config_id is None: + raise ValueError("Config id is None.") + trial = Trial( config_id=config_id, budget=budget, diff --git a/deepcave/utils/run_caches.py b/deepcave/utils/run_caches.py index 76f9f7d2..34bf6ef1 100644 --- a/deepcave/utils/run_caches.py +++ b/deepcave/utils/run_caches.py @@ -10,7 +10,7 @@ - RunCaches: Hold the caches for the selected runs. """ -from typing import Any, Dict +from typing import Any, Dict, Optional import shutil @@ -108,7 +108,7 @@ def _reset(self, run: AbstractRun, cache: Cache) -> None: cache.write() - def get(self, run: AbstractRun, plugin_id: str, inputs_key: str) -> Dict[str, Any]: + def get(self, run: AbstractRun, plugin_id: str, inputs_key: str) -> Optional[Dict[str, Any]]: """ Return the raw outputs for the given run, plugin and inputs key. @@ -125,6 +125,11 @@ def get(self, run: AbstractRun, plugin_id: str, inputs_key: str) -> Dict[str, An ------- Dict[str, Any] Raw outputs for the given run, plugin and inputs key. + + Raises + ------ + AssertionError + If the outputs of the cache are not a dict. """ filename = self.cache_dir / run.id / plugin_id / f"{inputs_key}.json" @@ -132,7 +137,9 @@ def get(self, run: AbstractRun, plugin_id: str, inputs_key: str) -> Dict[str, An return None cache = Cache(filename, debug=self._debug, write_file=False) - return cache.get("outputs") + outputs = cache.get("outputs") + assert isinstance(outputs, dict), "Outputs of cache must be a dict." + return outputs def set(self, run: AbstractRun, plugin_id: str, inputs_key: str, value: Any) -> None: """