Skip to content

Commit

Permalink
refactor: OptimInstanceRush
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Oct 23, 2023
1 parent 9d566bc commit c44d72f
Show file tree
Hide file tree
Showing 46 changed files with 1,589 additions and 677 deletions.
3 changes: 3 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ Collate:
'ObjectiveRFunMany.R'
'OptimInstance.R'
'OptimInstanceMultiCrit.R'
'OptimInstanceRush.R'
'OptimInstanceRushMultiCrit.R'
'OptimInstanceRushSingleCrit.R'
'OptimInstanceSingleCrit.R'
'mlr_optimizers.R'
'Optimizer.R'
Expand Down
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ S3method(as.data.table,Archive)
S3method(as.data.table,ArchiveRush)
S3method(as.data.table,DictionaryOptimizer)
S3method(as.data.table,DictionaryTerminator)
S3method(assign_result_default,OptimInstance)
S3method(assign_result_default,OptimInstanceRush)
S3method(bb_optimize,"function")
S3method(bb_optimize,Objective)
S3method(optimize_default,OptimInstance)
S3method(optimize_default,OptimInstanceRush)
export(Archive)
export(ArchiveBest)
export(ArchiveRush)
Expand All @@ -18,6 +22,9 @@ export(ObjectiveRFunDt)
export(ObjectiveRFunMany)
export(OptimInstance)
export(OptimInstanceMultiCrit)
export(OptimInstanceRush)
export(OptimInstanceRushMultiCrit)
export(OptimInstanceRushSingleCrit)
export(OptimInstanceSingleCrit)
export(Optimizer)
export(OptimizerCmaes)
Expand Down Expand Up @@ -56,6 +63,7 @@ export(mlr_callbacks)
export(mlr_optimizers)
export(mlr_terminators)
export(nds_selection)
export(oi)
export(opt)
export(optimize_default)
export(opts)
Expand Down
12 changes: 5 additions & 7 deletions R/Archive.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,19 @@
#' @template param_ydt
#' @template param_n_select
#' @template param_ref_point
#'
#' @template field_search_space
#' @template field_codomain
#' @template field_start_time
#'
#' @export
Archive = R6Class("Archive",
public = list(

#' @field search_space ([paradox::ParamSet])\cr
#' Search space of objective.
search_space = NULL,

#' @field codomain ([Codomain])\cr
#' Codomain of objective function.
codomain = NULL,

#' @field start_time ([POSIXct])\cr
#' Time stamp of when the optimization started. The time is set by the
#' [Optimizer].
start_time = NULL,

#' @field check_values (`logical(1)`)\cr
Expand Down
43 changes: 14 additions & 29 deletions R/ArchiveRush.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,49 +10,32 @@
#' Returns a tabular view of all performed function calls of the
#' Objective. The `x_domain` column is unnested to separate columns.
#'
#' @template param_codomain
#' @template param_search_space
#' @template param_xdt
#' @template param_ydt
#' @template param_n_select
#' @template param_ref_point
#' @template param_codomain
#' @template param_rush
#'
#' @template field_search_space
#' @template field_codomain
#' @template field_start_time
#' @template field_rush
#'
#' @export
ArchiveRush = R6Class("ArchiveRush",
public = list(

#' @field search_space ([paradox::ParamSet])\cr
#' Search space of objective.
search_space = NULL,

#' @field codomain ([Codomain])\cr
#' Codomain of objective function.
codomain = NULL,

#' @field start_time ([POSIXct])\cr
#' Time stamp of when the optimization started.
#' The time is set by the [Optimizer].
start_time = NULL,

#' @field check_values (`logical(1)`)\cr
#' Determines if points and results are checked for validity.
check_values = NULL,

#' @field rush ([rush::Rush])\cr
#' Rush.
rush = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param check_values (`logical(1)`)\cr
#' Should x-values that are added to the archive be checked for validity?
#' Search space that is logged into archive.
#' @param rush ([rush::Rush])\cr
#' Rush.
initialize = function(search_space, codomain, check_values = TRUE, rush) {
initialize = function(search_space, codomain, rush) {
self$search_space = assert_param_set(search_space)
self$codomain = Codomain$new(assert_param_set(codomain)$params)
self$check_values = assert_flag(check_values)
self$rush = assert_class(rush, "Rush")
private$.data = data.table()
},
Expand All @@ -62,9 +45,7 @@ ArchiveRush = R6Class("ArchiveRush",
#' For single-crit optimization, the solution that minimizes / maximizes the objective function.
#' For multi-crit optimization, the Pareto set / front.
#'
#' @param n_select (`integer(1L)`)\cr
#' Amount of points to select.
#' Ignored for multi-crit optimization.
#' @template param_n_select
#'
#' @return [data.table::data.table()]
best = function(n_select = 1) {
Expand All @@ -87,6 +68,9 @@ ArchiveRush = R6Class("ArchiveRush",
#' @description
#' Calculate best points w.r.t. non dominated sorting with hypervolume contribution.
#'
#' @template param_n_select
#' @template param_ref_point
#'
#' @return [data.table::data.table()]
nds_selection = function(n_select = 1, ref_point = NULL) {
tab = self$data
Expand Down Expand Up @@ -119,6 +103,7 @@ ArchiveRush = R6Class("ArchiveRush",
clear = function() {
# Call reset() on the attached rush instance.
# NOTE(review): presumably this removes the tasks stored by rush - confirm against the rush docs.
self$rush$reset()
# Forget the recorded start time of the optimization.
self$start_time = NULL
# Replace the locally held table with an empty data.table.
private$.data = data.table()
},

#' @description
Expand Down
2 changes: 1 addition & 1 deletion R/ContextOptimization.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ContextOptimization = R6Class("ContextOptimization",
#' @param instance ([OptimInstance]).
#' @param optimizer ([Optimizer]).
initialize = function(instance, optimizer) {
self$instance = assert_class(instance, "OptimInstance")
self$instance = assert_multi_class(instance, c("OptimInstance", "OptimInstanceRush"))
self$optimizer = optimizer
}
),
Expand Down
178 changes: 28 additions & 150 deletions R/OptimInstance.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#' @title Optimization Instance with budget and archive
#' @title Optimization Instance
#'
#' @description
#' Abstract base class.
Expand All @@ -14,12 +14,7 @@
#' @template param_search_space
#' @template param_keep_evals
#' @template param_callbacks
#' @template param_rush
#' @template param_start_workers
#'
#' @template field_rush
#' @template field_freeze_archive
#' @template field_detect_lost_tasks
#'
#' @export
OptimInstance = R6Class("OptimInstance",
Expand Down Expand Up @@ -47,12 +42,6 @@ OptimInstance = R6Class("OptimInstance",
#' @field callbacks (List of [CallbackOptimization]s).
callbacks = NULL,

rush = NULL,

freeze_archive = NULL,

detect_lost_tasks = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
Expand All @@ -67,57 +56,25 @@ OptimInstance = R6Class("OptimInstance",
terminator,
keep_evals = "all",
check_values = TRUE,
callbacks = list(),
rush = NULL,
start_workers = FALSE) {

callbacks = list()
) {
self$objective = assert_r6(objective, "Objective")
self$search_space = choose_search_space(self$objective, search_space)
self$terminator = assert_terminator(terminator, self)
assert_choice(keep_evals, c("all", "best"))
assert_flag(check_values)
self$callbacks = assert_callbacks(as_callbacks(callbacks))
self$rush = assert_class(rush, "Rush", null.ok = TRUE)
assert_flag(start_workers)
self$freeze_archive = FALSE

# set search space
domain_search_space = self$objective$domain$search_space()
self$search_space = if (is.null(search_space) && domain_search_space$length == 0) {
# use whole domain as search space
self$objective$domain
} else if (is.null(search_space) && domain_search_space$length > 0) {
# create search space from tune token in domain
domain_search_space
} else if (!is.null(search_space) && domain_search_space$length == 0) {
# use supplied search space
assert_param_set(search_space)
} else {
stop("If the domain contains TuneTokens, you cannot supply a search_space.")
}

# use minimal archive if only best points are needed
self$archive = if (!is.null(self$rush)) {
ArchiveRush$new(
search_space = self$search_space,
codomain = objective$codomain,
check_values = check_values,
rush = self$rush)
} else if (keep_evals == "all") {
Archive$new(search_space = self$search_space, codomain = objective$codomain, check_values = check_values)
} else if (keep_evals == "best") {
ArchiveBest$new(search_space = self$search_space, codomain = objective$codomain, check_values = check_values)
}
Archive = if (keep_evals == "all") Archive else ArchiveBest
self$archive = Archive$new(
search_space = self$search_space,
codomain = objective$codomain,
check_values = check_values)

# disable objective function if search space is not all numeric
if (!self$search_space$all_numeric) {
private$.objective_function = objective_error
} else {
private$.objective_function = objective_function
}
private$.objective_function = if (!self$search_space$all_numeric) objective_error else objective_function
self$objective_multiplicator = self$objective$codomain$maximization_to_minimization

# start rush
if (!is.null(self$rush) && start_workers) self$start_workers()
},

#' @description
Expand Down Expand Up @@ -147,59 +104,6 @@ OptimInstance = R6Class("OptimInstance",
}
},

#' @description
#' Start workers with `future`.
#'
#' @param n_workers (`integer(1)`)\cr
#' Number of workers to be started.
#' If `NULL` the maximum number of free workers is used.
#' @param await_workers (`logical(1)`)\cr
#' Whether to wait until all workers are available.
#'
#' @template param_packages
#' @template param_host
#' @template param_heartbeat_period
#' @template param_heartbeat_expire
#' @template param_lgr_thresholds
#' @template param_freeze_archive
#' @template param_detect_lost_tasks
start_workers = function(
n_workers = NULL,
packages = NULL,
host = "local",
heartbeat_period = NULL,
heartbeat_expire = NULL,
lgr_thresholds = NULL,
await_workers = TRUE,
detect_lost_tasks = FALSE,
freeze_archive = FALSE) {

# Both options must be single logical flags; they are stored on the instance.
self$detect_lost_tasks = assert_flag(detect_lost_tasks)
self$freeze_archive = assert_flag(freeze_archive)

# Local copies whose names match the entries of `globals` below.
# NOTE(review): presumably rush looks these globals up by name, so the
# locals should not be renamed - confirm against the rush docs.
objective = self$objective
search_space = self$search_space

# Delegate to rush with `bbotk_worker_loop` as the worker loop.
# "bbotk" is always appended to the user-supplied `packages`.
self$rush$start_workers(
worker_loop = bbotk_worker_loop,
n_workers = n_workers,
globals = c("objective", "search_space"),
packages = c(packages, "bbotk"),
host = host,
heartbeat_period = heartbeat_period,
heartbeat_expire = heartbeat_expire,
lgr_thresholds = lgr_thresholds,
objective = objective,
search_space = search_space,
await_workers = await_workers)
},

#' @description
#' Create a script to start workers.
#' Placeholder: no script is generated yet, the method simply returns `NULL`.
create_worker_script = function() NULL,

#' @description
#' Evaluates all input values in `xdt` by calling
#' the [Objective]. Applies possible transformations to the input values
Expand Down Expand Up @@ -244,50 +148,6 @@ OptimInstance = R6Class("OptimInstance",
return(invisible(ydt[, self$archive$cols_y, with = FALSE]))
},

#' @description
#' Evaluate xdt asynchronously.
#'
#' @param xdt (`data.table::data.table()`)\cr
#' x values as `data.table()` with one point per row.
#' Contains the value in the *search space* of the [OptimInstance] object.
#' Can contain additional columns for extra information.
#' @param wait (`logical(1)`)\cr
#' If `TRUE`, wait for all evaluations to finish.
eval_async = function(xdt, wait = FALSE) {

# Refuse to queue new points once the instance reports termination.
if (self$is_terminated) stop(terminated_error(self))

# `xdt` must be a data.table that contains a column for every search space id.
assert_data_table(xdt)
assert_names(colnames(xdt), must.include = self$search_space$ids())

# Log the points that are about to be queued; the reported count is never below 1.
lg$info("Evaluating %i configuration(s):", max(1, nrow(xdt)))
lg$info(capture.output(print(xdt,
class = FALSE, row.names = FALSE, print.keys = FALSE)))

# One list entry per row, restricted to the search space columns.
xss = transpose_list(xdt[, self$search_space$ids(), with = FALSE])
# NOTE: `:=` adds the column by reference, so the table passed in by the caller gains `timestamp_xs` as well.
xdt[, timestamp_xs := Sys.time()]
# Every column that is not a search space id (including `timestamp_xs`) is passed along as extra information.
extra = transpose_list(xdt[, !self$search_space$ids(), with = FALSE])

# A `priority_id` column switches to the priority variant of pushing tasks.
if (!is.null(xdt$priority_id)) {
keys = self$rush$push_priority_tasks(xss, extra, priority = xdt$priority_id)
} else {
keys = self$rush$push_tasks(xss, extra)
}

# optimizer can request to wait for all evaluations to finish
if (wait) {
self$rush$await_tasks(keys, detect_lost_tasks = self$detect_lost_tasks)
}

# terminate optimization if all workers crashed
if (self$rush$n_running_workers == 0) {
lg$warn("Optimization terminated because %i workers crashed.", length(self$rush$lost_worker_ids))
stop(terminated_error(self))
}

# Optional check for lost tasks, enabled through the `detect_lost_tasks` field.
if (self$detect_lost_tasks) self$rush$detect_lost_tasks()
},

#' @description
#' The [Optimizer] object writes the best found point
#' and estimated performance value here. For internal use.
Expand Down Expand Up @@ -390,3 +250,21 @@ objective_error = function(x, inst, maximization_to_minimization) {
stop("$objective_function can only be called if search_space only
contains numeric values")
}

# Shared by OptimInstance and OptimInstanceRush.
#
# Decides which parameter set acts as the search space:
# * no `search_space` supplied, no tune tokens in the domain -> the whole domain
# * no `search_space` supplied, tune tokens in the domain    -> the space built from the tokens
# * `search_space` supplied, no tune tokens in the domain    -> the supplied space (after assertion)
# * `search_space` supplied and tune tokens in the domain    -> error
choose_search_space = function(objective, search_space) {
  token_space = objective$domain$search_space()
  n_tokens = token_space$length

  if (is.null(search_space)) {
    # nothing supplied by the user: derive the search space from the domain
    if (n_tokens == 0) return(objective$domain)
    if (n_tokens > 0) return(token_space)
  } else if (n_tokens == 0) {
    # user-supplied search space, only valid without tune tokens
    return(assert_param_set(search_space))
  }

  stop("If the domain contains TuneTokens, you cannot supply a search_space.")
}
Loading

0 comments on commit c44d72f

Please sign in to comment.