Skip to content

Commit

Permalink
fixed a few type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
KrissiHub committed Oct 30, 2023
1 parent e053e92 commit 0f00bb7
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 67 deletions.
15 changes: 8 additions & 7 deletions deepcave/custom_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
- Queue: This class defines all components for a job queue.
"""

from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Optional

import redis
from rq import Queue as _Queue
Expand Down Expand Up @@ -256,7 +256,7 @@ def get_finished_jobs(self) -> List[Job]:
"""Get the finished jobs in the registry."""
return self.get_jobs(registry="finished")

def delete_job(self, job_id: str = None) -> None:
def delete_job(self, job_id: Optional[str] = None) -> None:
"""
Delete a job from the queue. If no job_id is given, delete all jobs.
Expand All @@ -266,7 +266,7 @@ def delete_job(self, job_id: str = None) -> None:
Id of the job, which should be removed. By default None.
"""

def remove_jobs(registry: BaseRegistry, job_id: str = None) -> None:
def remove_jobs(registry: BaseRegistry, job_id: Optional[str] = None) -> None:
"""
Remove a job from the registry. If no job_id is given, remove all.
Expand All @@ -280,17 +280,18 @@ def remove_jobs(registry: BaseRegistry, job_id: str = None) -> None:
"""
if job_id is not None:
try:
registry.remove(job_id, delete_job=True)
registry.remove(self.get_job(job_id), delete_job=True)
except Exception:
pass
else:
# Remove all
for job_id in registry.get_job_ids():
try:
registry.remove(job_id, delete_job=True)
registry.remove(self.get_job(job_id), delete_job=True)
except Exception:
registry.remove(job_id)
registry.remove(self.get_job(job_id))

# Issue opened
remove_jobs(self._queue, job_id)
remove_jobs(self._queue.finished_job_registry, job_id)
remove_jobs(self._queue.canceled_job_registry, job_id)
Expand Down Expand Up @@ -362,7 +363,7 @@ def enqueue(
result_ttl=-1, # Make sure it's not automatically deleted.
)

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
"""If function is not found, make sure we access self._queue directly."""
try:
return self.__getattribute__(name)
Expand Down
12 changes: 10 additions & 2 deletions deepcave/evaluators/epm/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# noqa
"""from typing import Any, Dict, List
"""from typing import Any, Dict, List, Optional
import numpy as np
from ConfigSpace import ConfigurationSpace
Expand Down Expand Up @@ -104,7 +105,14 @@
styled_plot = styled_plot.StyledPlot()
plt = styled_plot.plt
print(type(plt.render()))"""
print(type(plt.render()))
list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
list_comp = [1.5 for i in list]
def test(hallo: Optional[int] = 2):
print(type(hallo))
test(5)
test()"""
14 changes: 8 additions & 6 deletions deepcave/runs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from deepcave.utils.logs import get_logger



class AbstractRun(ABC):
"""
Can create, handle and get information of an abstract run.
Expand Down Expand Up @@ -96,8 +97,8 @@ def reset(self) -> None:
"""
self.meta: Dict[str, Any] = {}
self.configspace: ConfigSpace.ConfigurationSpace
self.configs: Dict[int, Configuration] = {}
self.origins: Dict[int, str] = {}
self.configs: Dict[int, Union[Configuration, Dict[Any, Any]]] = {}
self.origins: Dict[int, Optional[str]] = {}
self.models: Dict[int, Optional[Union[str, "torch.nn.Module"]]] = {} # noqa: F821

self.history: List[Trial] = []
Expand Down Expand Up @@ -246,7 +247,7 @@ def empty(self) -> bool:
"""
return len(self.history) == 0

def get_origin(self, config_id: int) -> str:
def get_origin(self, config_id: int) -> Optional[str]:
"""
Get the origin, given a config ID.
Expand Down Expand Up @@ -372,7 +373,7 @@ def get_objective_names(self) -> List[str]:
"""
return [obj.name for obj in self.get_objectives()]

def get_configs(self, budget: Union[int, float] = None) -> Dict[int, Configuration]:
def get_configs(self, budget: Optional[Union[int, float]] = None) -> Dict[int, Configuration]:
"""
Get configurations of the run.
Expand Down Expand Up @@ -449,7 +450,7 @@ def get_config_id(self, config: Union[Configuration, Dict]) -> Optional[int]:

return None

def get_num_configs(self, budget: Union[int, float] = None) -> int:
def get_num_configs(self, budget: Optional[Union[int, float]] = None) -> int:
"""
Count the number of configurations stored in this object with a specific budget.
Expand Down Expand Up @@ -1196,7 +1197,7 @@ def check_equality(
If the budgets of the runs are not equal.
If the objective of the runs are not equal.
"""
result = {}
result: Dict[str, Any] = {}

if len(runs) == 0:
return result
Expand Down Expand Up @@ -1259,6 +1260,7 @@ def check_equality(
for o1_, o2_ in zip(o1, o2):
o1_.merge(o2_)

assert o1 is not None
serialized_objectives = [o.to_json() for o in o1]
result["objectives"] = serialized_objectives
if meta:
Expand Down
2 changes: 1 addition & 1 deletion deepcave/runs/converters/deepcave.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,6 @@ def latest_change(self) -> Union[float, int]:
return Path(self.path / "history.jsonl").stat().st_mtime

@classmethod
def from_path(cls, path):
def from_path(cls, path: Path) -> "DeepCAVERun":
"""Get a DeepCAVE run from a given path."""
return DeepCAVERun(path.stem, path=Path(path))
19 changes: 10 additions & 9 deletions deepcave/runs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
- Group: Can group and manage a group of abstract runs.
"""

from typing import Dict, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from copy import deepcopy

Expand Down Expand Up @@ -133,7 +133,7 @@ def __init__(self, name: str, runs: List[AbstractRun]):
except Exception as e:
raise NotMergeableError(f"Runs can not be merged: {e}")

def __iter__(self):
def __iter__(self: "Group") -> Iterator[str]:
"""Allow to iterate over the object."""
for run in self.runs:
yield run.name
Expand All @@ -156,9 +156,9 @@ def id(self) -> str:
return string_to_hash(f"{self.prefix}:{self.name}")

@property
def latest_change(self) -> int:
def latest_change(self) -> float:
"""Get the latest change made to the grouped runs."""
latest_change = 0
latest_change = 0.0
for run in self.runs:
if run.latest_change > latest_change:
latest_change = run.latest_change
Expand All @@ -183,7 +183,7 @@ def get_new_config_id(self, run_id: int, original_config_id: int) -> int:
"""Get a new identificator for a configuration."""
return self._new_config_mapping[(run_id, original_config_id)]

def get_original_config_id(self, config_id: int) -> id:
def get_original_config_id(self, config_id: int) -> int:
"""Get the original identificator of a configuration."""
return self._original_config_mapping[config_id][1]

Expand All @@ -192,11 +192,12 @@ def get_original_run(self, config_id: int) -> AbstractRun:
run_id = self._original_config_mapping[config_id][0]
return self.runs[run_id]

def get_model(self, config_id: int):
def get_model(self, config_id: int) -> Optional[Any]:
"""Get the model of the runs."""
run_id, config_id = self._original_config_mapping[config_id]
return self.runs[run_id].get_model(config_id)

# wait until meeting
def get_trajectory(self, *args, **kwargs):
"""
Get the trajectory of the group.
Expand Down Expand Up @@ -251,10 +252,10 @@ def get_trajectory(self, *args, **kwargs):
all_costs.append(y)

# Make numpy arrays
all_costs = np.array(all_costs)
all_costs_array = np.array(all_costs)

times = all_times
costs_mean = np.mean(all_costs, axis=1)
costs_std = np.std(all_costs, axis=1)
costs_mean = np.mean(all_costs_array, axis=1)
costs_std = np.std(all_costs_array, axis=1)

return times, list(costs_mean), list(costs_std), [], []
53 changes: 47 additions & 6 deletions deepcave/runs/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,17 @@ def get_working_directory(self) -> Path:
-------
Path
Path of the working directory.
Raises
------
AssertionError
If the working directory is not a string or a path like, an error is thrown.
"""
return Path(self.c.get("working_dir"))
working_dir = self.c.get("working_dir")
assert isinstance(
working_dir, (str, Path)
), "Working directory of cache must be a string or a Path like."
return Path(working_dir)

def get_available_run_paths(self) -> Dict[str, str]:
"""
Expand Down Expand Up @@ -135,8 +144,17 @@ def get_selected_run_paths(self) -> List[str]:
-------
List[str]
Run paths as a list.
Raises
------
AssertionError.
If the selected run paths are not a list, an error is thrown.
"""
return self.c.get("selected_run_paths")
selected_run_paths = self.c.get("selected_run_paths")
assert isinstance(
selected_run_paths, list
), "The selected run paths of the cache must be a list."
return selected_run_paths

def get_selected_run_names(self) -> List[str]:
"""
Expand Down Expand Up @@ -166,8 +184,19 @@ def get_run_name(self, run_path: Union[Path, str]) -> str:
return Path(run_path).stem

def get_selected_groups(self) -> Dict[str, List[str]]:
"""Get the selected groups."""
return self.c.get("groups")
"""
Get the selected groups.
Raises
------
AssertionError
If groups in cache is not a dict, an error is thrown.
"""
selected_groups = self.c.get("groups")
assert isinstance(
selected_groups, dict
), "The groups aquired from the cache must be a dictionary."
return selected_groups

def add_run(self, run_path: str) -> bool:
"""
Expand Down Expand Up @@ -201,17 +230,28 @@ def remove_run(self, run_path: str) -> None:
----------
run_path : str
Path of a run.
Raises
------
TypeError
If `selected_run_paths` or `groups` is None, an error is thrown.
"""
selected_run_paths = self.c.get("selected_run_paths")

if selected_run_paths is None:
raise TypeError("Selected run paths can not be None.")

if run_path in selected_run_paths:
selected_run_paths.remove(run_path)
self.c.set("selected_run_paths", value=selected_run_paths)

# We have to check the groups here because the removed run_path may
# still be included
groups = {}
for group_name, run_paths in self.c.get("groups").items():
group_it = self.c.get("groups")
if group_it is None:
raise TypeError("Groups can not be None.")
for group_name, run_paths in group_it.items():
if run_path in run_paths:
run_paths.remove(run_path)
groups[group_name] = run_paths
Expand Down Expand Up @@ -365,7 +405,8 @@ def update_groups(self, groups: Optional[Dict[str, List[str]]] = None) -> None:
------
NotMergeableError
If runs can not be merged, an error is thrown.
TypeError
If `groups` is None, an error is thrown.
"""
instantiated_groups = {}
if groups is None:
Expand Down
Loading

0 comments on commit 0f00bb7

Please sign in to comment.