Skip to content

Commit

Permalink
fixed a few type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
KrissiHub committed Nov 1, 2023
1 parent 0f00bb7 commit 1327166
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 49 deletions.
18 changes: 12 additions & 6 deletions deepcave/evaluators/epm/fanova_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
- FanovaForest: A fanova forest wrapper for pyrfr.
"""

from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import itertools as it

Expand Down Expand Up @@ -126,7 +126,7 @@ def _train(self, X: np.ndarray, Y: np.ndarray) -> None:

# compute midpoints and interval sizes for variables in each tree
for tree_split_values in forest_split_values:
sizes = []
sizes: List = []
midpoints = []
for i, split_vals in enumerate(tree_split_values):
if np.isnan(self.bounds[i][1]): # categorical parameter
Expand Down Expand Up @@ -156,16 +156,21 @@ def _train(self, X: np.ndarray, Y: np.ndarray) -> None:
# and the value list contains \hat{f}_U for the individual trees
# reset all the variance fractions computed
self.trees_variance_fractions: dict = {}
self.V_U_total = {}
self.V_U_individual = {}
self.V_U_total: Dict[Tuple[int, ...], List[Union[Any, float]]] = {}
self.V_U_individual: Dict[Tuple[int, ...], List[Union[Any, float]]] = {}

# Set cut-off
self._model.set_cutoffs(self.cutoffs[0], self.cutoffs[1])

# recompute the trees' total variance
self.trees_total_variance = self._model.get_trees_total_variances()

def compute_marginals(self, hp_ids: List[int], depth: int = 1):
def compute_marginals(
self, hp_ids: Union[List[int], Tuple[int, ...]], depth: int = 1
) -> Tuple[
Dict[Tuple[int, ...], List[Union[Any, float]]],
Dict[Tuple[int, ...], List[Union[Any, float]]],
]:
"""
Return the marginal of selected parameters.
Expand All @@ -178,7 +183,8 @@ def compute_marginals(self, hp_ids: List[int], depth: int = 1):
-------
The marginal of selected parameters.
"""
hp_ids = tuple(hp_ids)
if not isinstance(hp_ids, tuple):
hp_ids = tuple(hp_ids)

# check if values have been previously computed
if hp_ids in self.V_U_individual:
Expand Down
17 changes: 10 additions & 7 deletions deepcave/evaluators/epm/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def __init__(
self.seed = seed

# Set types and bounds automatically
self.types, self.bounds = get_types(configspace, instance_features)
types, self.bounds = get_types(configspace, instance_features)
self.types = np.array(types)

# Prepare everything for PCA
self.n_params = len(configspace.get_hyperparameters())
Expand Down Expand Up @@ -176,13 +177,14 @@ def _get_model_options(self, **kwargs: Dict[str, Any]) -> regression.forest_opts
# Now we set the options
options = regression.forest_opts()

def rgetattr(obj, attr, *args):
def _getattr(obj, attr):
def rgetattr(obj: object, attr: str, *args: Any) -> Any:
def _getattr(obj: object, attr: object) -> Any:
attr = str(attr)
return getattr(obj, attr, *args)

return functools.reduce(_getattr, [obj] + attr.split("."))

def rsetattr(obj, attr, val):
def rsetattr(obj: object, attr: str, val: Any) -> None:
pre, _, post = attr.rpartition(".")
return setattr(rgetattr(obj, pre) if pre else obj, post, val)

Expand Down Expand Up @@ -365,7 +367,7 @@ def _train(self, X: np.ndarray, Y: np.ndarray) -> None:
self._model.options.num_data_points_per_tree = X.shape[0]
self._model.fit(data, rng=rng)

def predict(self, X: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
def predict(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Predict means and variances for a given X.
Expand All @@ -378,7 +380,7 @@ def predict(self, X: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
-------
means : np.ndarray [n_samples, n_objectives]
Predictive mean.
vars : Optional[np.ndarray] [n_samples, n_objectives] or [n_samples, n_samples]
vars : np.ndarray [n_samples, n_objectives] or [n_samples, n_samples]
Predictive variance or standard deviation.
"""
self._check_dimensions(X)
Expand Down Expand Up @@ -529,6 +531,7 @@ def predict_marginalized(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:

return mean_, var

def get_leaf_values(self, x: np.ndarray):
# Wait until meeting
def get_leaf_values(self, x: np.ndarray) -> regression.binary_rss_forest:
"""Get the leaf values of the model."""
return self._model.all_leaf_values(x) # type: ignore[np-untyped-def]
4 changes: 2 additions & 2 deletions deepcave/evaluators/epm/random_forest_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
- RandomForestSurrogate: Random forest surrogate for the pyPDP package.
"""

from typing import Tuple
from typing import Optional, Tuple

import ConfigSpace as CS
import numpy as np
Expand All @@ -29,7 +29,7 @@ class RandomForestSurrogate(SurrogateModel):
def __init__(
self,
configspace: CS.ConfigurationSpace,
seed: int = None,
seed: Optional[int] = None,
):
super().__init__(configspace, seed=seed)
self._model = RandomForest(configspace=configspace, seed=seed)
Expand Down
2 changes: 1 addition & 1 deletion deepcave/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def toggle_raw_data_modal(n: Optional[int], is_open: bool) -> Tuple[bool, str]:
Input(self.get_internal_id("show_help"), "n_clicks"),
State(self.get_internal_id("help"), "is_open"),
)
def toggle_help_modal(n: Optional[int], is_open: bool) -> Tuple[bool, str]:
def toggle_help_modal(n: Optional[int], is_open: bool) -> bool:
"""
Toggle the help modal.
Expand Down
49 changes: 31 additions & 18 deletions deepcave/runs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from deepcave.utils.logs import get_logger



class AbstractRun(ABC):
"""
Can create, handle and get information of an abstract run.
Expand Down Expand Up @@ -99,10 +98,13 @@ def reset(self) -> None:
self.configspace: ConfigSpace.ConfigurationSpace
self.configs: Dict[int, Union[Configuration, Dict[Any, Any]]] = {}
self.origins: Dict[int, Optional[str]] = {}
# Wait until meeting
self.models: Dict[int, Optional[Union[str, "torch.nn.Module"]]] = {} # noqa: F821

self.history: List[Trial] = []
self.trial_keys: Dict[Tuple[int, int], int] = {} # (config_id, budget) -> trial_id
self.trial_keys: Dict[
Tuple[int, Union[int, float, None]], int
] = {} # (config_id, budget) -> trial_id

# Cached data
self._highest_budget: Dict[int, Union[int, float]] = {} # config_id -> budget
Expand Down Expand Up @@ -177,7 +179,7 @@ def latest_change(self) -> float:
@staticmethod
def get_trial_key(
config_id: int, budget: Union[int, float, None]
) -> Tuple[int, Union[int, float]]:
) -> Tuple[int, Union[int, float, None]]:
"""
Get the trial key for the configuration and the budget.
Expand All @@ -195,7 +197,7 @@ def get_trial_key(
"""
return (config_id, budget)

def get_trial(self, trial_key: Tuple[int, Union[int, float]]) -> Optional[Trial]:
def get_trial(self, trial_key: Tuple[int, int]) -> Optional[Trial]:
"""
Get the trial with the responding key if existing.
Expand Down Expand Up @@ -468,7 +470,7 @@ def get_num_configs(self, budget: Optional[Union[int, float]] = None) -> int:
"""
return len(self.get_configs(budget=budget))

def get_budget(self, id: Union[int, str], human: bool = False) -> float:
def get_budget(self, id: Union[int, str], human: bool = False) -> Union[int, float]:
"""
Get the budget given an id.
Expand All @@ -479,11 +481,19 @@ def get_budget(self, id: Union[int, str], human: bool = False) -> float:
Returns
-------
float
float, int
Budget.
Raises
------
TypeError
If the budget with this id is invalid.
"""
budgets = self.get_budgets(human=human)
return budgets[int(id)]
budget = budgets[int(id)]
if isinstance(budget, str):
raise TypeError("The budget with this id is invalid.")
return budget

def get_budget_ids(self, include_combined: bool = True) -> List[int]:
"""
Expand All @@ -508,7 +518,7 @@ def get_budget_ids(self, include_combined: bool = True) -> List[int]:

def get_budgets(
self, human: bool = False, include_combined: bool = True
) -> List[Union[int, float]]:
) -> List[Union[int, float, str]]:
"""
Return the budgets from the meta data.
Expand All @@ -519,15 +529,15 @@ def get_budgets(
Returns
-------
List[Union[int, float]]
List of budgets.
List[Union[int, float, str]]
List of budgets, in human-readable form if human is True.
"""
budgets = self.meta["budgets"].copy()
if include_combined and len(budgets) > 1 and COMBINED_BUDGET not in budgets:
budgets += [COMBINED_BUDGET]

if human:
readable_budgets = []
readable_budgets: List[Union[str, float]] = []
for b in budgets:
if b == COMBINED_BUDGET:
readable_budgets += ["Combined"]
Expand Down Expand Up @@ -853,6 +863,7 @@ def merge_costs(

return cost

# Wait until meeting
def get_model(self, config_id: int) -> Optional["torch.nn.Module"]: # noqa: F821
"""
Get a torch model associated with a configuration ID.
Expand All @@ -870,6 +881,7 @@ def get_model(self, config_id: int) -> Optional["torch.nn.Module"]: # noqa: F82
"""
import torch

# An issue has been opened for this
filename = self.models_dir / f"{str(config_id)}.pth"
if not filename.exists():
return None
Expand Down Expand Up @@ -951,7 +963,7 @@ def get_trajectory(
current_cost = cost

costs_mean.append(cost)
costs_std.append(0)
costs_std.append(0.0)
times.append(trial.end_time)
ids.append(id)
config_ids.append(trial.config_id)
Expand Down Expand Up @@ -1102,8 +1114,8 @@ def get_encoded_data(
y_set.append(y)
config_ids.append(config_id)

x_set = np.array(x_set) # type: ignore
y_set = np.array(y_set) # type: ignore
x_set_array = np.array(x_set)
y_set_array = np.array(y_set)
config_ids = np.array(config_ids).reshape(-1, 1) # type: ignore

# Imputation: Easiest case is to replace all nans with -1
Expand Down Expand Up @@ -1135,8 +1147,8 @@ def get_encoded_data(
raise ValueError("Hyperparameter not supported.")

if conditional[idx] is True:
non_finite_mask = ~np.isfinite(x_set[:, idx])
x_set[non_finite_mask, idx] = impute_values[idx]
non_finite_mask = ~np.isfinite(x_set_array[:, idx])
x_set_array[non_finite_mask, idx] = impute_values[idx]

# Now we create dataframes for both values and labels
# [CONFIG_ID, HP1, HP2, ..., HPn, OBJ1, OBJ2, ..., OBJm, COMBINED_COST]
Expand All @@ -1152,9 +1164,10 @@ def get_encoded_data(
columns += [COMBINED_COST_NAME]

if include_config_ids:
data: np.ndarray = np.concatenate((config_ids, x_set, y_set), axis=1)
# wait till meeting
data: np.ndarray = np.concatenate((config_ids, x_set_array, y_set_array), axis=1)
else:
data = np.concatenate((x_set, y_set), axis=1)
data = np.concatenate((x_set_array, y_set_array), axis=1)

data = pd.DataFrame(data=data, columns=columns)

Expand Down
2 changes: 1 addition & 1 deletion deepcave/runs/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __post_init__(self) -> None:

assert isinstance(self.status, Status)

def get_key(self) -> Tuple[int, int]:
def get_key(self) -> Tuple[int, Union[int, float, None]]:
"""
Generate a key based on the configuration ID and the budget.
Expand Down
13 changes: 7 additions & 6 deletions deepcave/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Cache:
def __init__(
self,
filename: Optional[Path] = None,
defaults: Dict = None,
defaults: Optional[Dict] = None,
debug: bool = False,
write_file: bool = True,
) -> None:
Expand Down Expand Up @@ -88,7 +88,7 @@ def write(self) -> None:
else:
json.dump(self._data, f, separators=JSON_DENSE_SEPARATORS)

def set(self, *keys, value: Any, write_file: bool = True) -> None:
def set(self, *keys: str, value: Any, write_file: bool = True) -> None:
"""
Set a value from a chain of keys.
Expand Down Expand Up @@ -138,7 +138,7 @@ def set_dict(self, d: Dict, write_file: bool = True) -> None:
if write_file:
self.write()

def get(self, *keys) -> Optional[Any]:
def get(self, *keys: str) -> Optional[Any]:
"""Retrieve value for a specific key."""
d = deepcopy(self._data)
for key in keys:
Expand All @@ -149,7 +149,7 @@ def get(self, *keys) -> Optional[Any]:

return d

def has(self, *keys) -> bool:
def has(self, *keys: str) -> bool:
"""Check whether cache has specific key."""
d = self._data
for key in keys:
Expand All @@ -163,7 +163,8 @@ def clear(self, write_file: bool = True) -> None:
"""Clear all cache and reset to defaults."""
filename = self._filename

if filename is not None and filename.exists():
self._filename.unlink()
if self._filename is not None:
if self._filename.exists():
self._filename.unlink()

self._setup(filename, write_file=write_file)
16 changes: 8 additions & 8 deletions deepcave/utils/configspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def sample_border_config(configspace: ConfigurationSpace) -> Iterator[Configurat
config[hp_name] = value

try:
config = deactivate_inactive_hyperparameters(config, configspace)
config.is_valid_configuration() # type: ignore[attr-defined]
configuration = deactivate_inactive_hyperparameters(config, configspace)
configuration.is_valid_configuration() # type: ignore[attr-defined]
except Exception:
continue

yield config
yield configuration


def sample_random_config(
Expand Down Expand Up @@ -91,7 +91,7 @@ def sample_random_config(
rng = np.random.RandomState(0)

while True:
config = {}
config_dict = {}

# Iterates over the hyperparameters to get considered values
for hp_name, hp in zip(
Expand Down Expand Up @@ -121,12 +121,12 @@ def sample_random_config(

# Get a random choice
value = rng.choice(values)
config[hp_name] = value
config_dict[hp_name] = value

try:
config = deactivate_inactive_hyperparameters(config, configspace)
config.is_valid_configuration()
configuration = deactivate_inactive_hyperparameters(config_dict, configspace)
configuration.is_valid_configuration()
except Exception:
continue

yield config
yield configuration

0 comments on commit 1327166

Please sign in to comment.