Skip to content

Commit

Permalink
refactor: async
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Apr 16, 2024
1 parent 7965d39 commit 2117841
Show file tree
Hide file tree
Showing 22 changed files with 179 additions and 134 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
Collate:
'Archive.R'
'ArchiveBest.R'
'ArchiveAsync.R'
'ArchiveBest.R'
'CallbackOptimization.R'
'Codomain.R'
'ContextOptimization.R'
Expand Down
6 changes: 3 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ S3method(as_terminators,list)
S3method(bb_optimize,"function")
S3method(bb_optimize,Objective)
export(Archive)
export(ArchiveBest)
export(ArchiveAsync)
export(ArchiveBest)
export(CallbackOptimization)
export(Codomain)
export(ContextOptimization)
Expand Down Expand Up @@ -58,8 +58,7 @@ export(assert_terminators)
export(assign_result_default)
export(bb_optimize)
export(bbotk_reflections)
export(bbotk_worker_loop_centralized)
export(bbotk_worker_loop_decentralized)
export(bbotk_worker_loop)
export(branin)
export(branin_wu)
export(callback_optimization)
Expand All @@ -72,6 +71,7 @@ export(mlr_optimizers)
export(mlr_terminators)
export(nds_selection)
export(oi)
export(oi_async)
export(opt)
export(optimize_default)
export(opts)
Expand Down
2 changes: 0 additions & 2 deletions R/OptimInstanceAsync.R
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@ OptimInstanceAsync = R6Class("OptimInstanceAsync",
.result = NULL,
.objective_function = NULL,
.context = NULL,
.freeze_archive = NULL,
.detect_lost_tasks = NULL,

.assign_result = function(xdt, y) {
stop("Abstract class")
Expand Down
2 changes: 1 addition & 1 deletion R/OptimizerAsync.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ start_async_optimize = function(inst, self, private) {
packages = c(self$packages, "bbotk") # add packages from objective

inst$rush$start_workers(
worker_loop = bbotk_worker_loop_decentralized,
worker_loop = bbotk_worker_loop,
packages = packages,
optimizer = self,
instance = inst,
Expand Down
11 changes: 11 additions & 0 deletions R/OptimizerRandomSearch.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ OptimizerRandomSearch = R6Class("OptimizerRandomSearch",
man = "bbotk::mlr_optimizers_random_search"
)
}
),

private = list(
.optimize = function(inst) {
batch_size = self$param_set$values$batch_size
sampler = SamplerUnif$new(inst$search_space)
repeat { # iterate until we have an exception from eval_batch
design = sampler$sample(batch_size)
inst$eval_batch(design$data)
}
}
)
)

Expand Down
2 changes: 1 addition & 1 deletion R/TerminatorClockTime.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ TerminatorClockTime = R6Class("TerminatorClockTime",
#'
#' @return `logical(1)`.
is_terminated = function(archive) {
assert_r6(archive, "Archive")
assert_multi_class(archive, c("Archive", "ArchiveAsync"))
return(Sys.time() >= self$param_set$values$stop_time)
}
),
Expand Down
2 changes: 1 addition & 1 deletion R/TerminatorCombo.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ TerminatorCombo = R6Class("TerminatorCombo",
#'
#' @return `logical(1)`.
is_terminated = function(archive) {
assert_r6(archive, "Archive")
assert_multi_class(archive, c("Archive", "ArchiveAsync"))
g = if (self$param_set$values$any) any else all
g(map_lgl(self$terminators, function(t) t$is_terminated(archive)))
},
Expand Down
2 changes: 1 addition & 1 deletion R/TerminatorEvals.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ TerminatorEvals = R6Class("TerminatorEvals",
#'
#' @return `logical(1)`.
is_terminated = function(archive) {
# assert_r6(archive, "Archive")
assert_multi_class(archive, c("Archive", "ArchiveAsync"))
pv = self$param_set$values
archive$n_evals >= pv$n_evals + pv$k * archive$search_space$length
}
Expand Down
2 changes: 1 addition & 1 deletion R/TerminatorNone.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ TerminatorNone = R6Class("TerminatorNone",
#'
#' @return `logical(1)`.
is_terminated = function(archive) {
assert_r6(archive, "Archive")
assert_multi_class(archive, c("Archive", "ArchiveAsync"))
return(FALSE)
}
)
Expand Down
2 changes: 1 addition & 1 deletion R/TerminatorPerfReached.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ TerminatorPerfReached = R6Class("TerminatorPerfReached",
#'
#' @return `logical(1)`.
is_terminated = function(archive) {
assert_r6(archive, "Archive")
assert_multi_class(archive, c("Archive", "ArchiveAsync"))
level = self$param_set$values$level
ycol = archive$cols_y
minimize = "minimize" %in% archive$codomain$tags
Expand Down
2 changes: 1 addition & 1 deletion R/TerminatorRunTime.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ TerminatorRunTime = R6Class("TerminatorRunTime",
#'
#' @return `logical(1)`.
is_terminated = function(archive) {
#assert_r6(archive, "Archive")
assert_multi_class(archive, c("Archive", "ArchiveAsync"))
if (is.null(archive$start_time)) return(FALSE)
d = as.numeric(difftime(Sys.time(), archive$start_time, units = "secs"))
return(d >= self$param_set$values$secs)
Expand Down
2 changes: 1 addition & 1 deletion R/TerminatorStagnation.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TerminatorStagnation = R6Class("TerminatorStagnation",
#'
#' @return `logical(1)`.
is_terminated = function(archive) {
assert_r6(archive, "Archive")
assert_multi_class(archive, c("Archive", "ArchiveAsync"))
pv = self$param_set$values
iters = pv$iters
ycol = archive$cols_y
Expand Down
38 changes: 20 additions & 18 deletions R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ is_dominated = function(ymat) {
}

#' @title Calculates the transformed x-values
#'
#' @description
#' Transforms a given `data.table()` to a list with transformed x values.
#' If no trafo is defined it will just convert the `data.table()` to a list.
Expand All @@ -51,6 +52,25 @@ transform_xdt_to_xss = function(xdt, search_space) {
design$transpose(trafo = TRUE, filter_na = TRUE)
}

#' @title Calculate the transformed x-values
#'
#' @description
#' Transforms a given `list()` to a list with transformed x values.
#'
#' @param xs (`list()`) \cr
#' List of x-values.
#' @param search_space [paradox::ParamSet]\cr
#' Search space.
#'
#' @export
trafo_xs = function(xs, search_space) {
xs = discard(xs, is_scalar_na)
if (search_space$has_trafo) {
xs = search_space$trafo(xs, search_space)
}
return(xs)
}

#' @title Get start values for optimizers
#'
#' @description
Expand Down Expand Up @@ -111,22 +131,4 @@ allow_partial_matching = list(
warnPartialMatchDollar = FALSE
)

#' @title Calculate the transformed x-values
#'
#' @description
#' Transforms a given `list()` to a `list()`` with transformed x values.
#'
#' @param xs (`list()`) \cr
#' List of x-values.
#' @param search_space [paradox::ParamSet]\cr
#' Search space.
#'
#' @export
trafo_xs = function(xs, search_space) {
xs = discard(xs, is_scalar_na)
if (search_space$has_trafo) {
xs = search_space$trafo(xs, search_space)
}
return(xs)
}

57 changes: 39 additions & 18 deletions R/sugar.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ opts = function(.keys, ...) {
#' @title Syntactic Sugar for Optimization Instance Construction
#'
#' @description
#' Function to construct a [OptimInstanceSingleCrit], [OptimInstanceMultiCrit], [OptimInstanceAsyncSingleCrit] or [OptimInstanceAsyncMultiCrit].
#' Function to construct a [OptimInstanceSingleCrit] and [OptimInstanceMultiCrit].
#'
#' @template param_objective
#' @template param_search_space
Expand All @@ -90,21 +90,42 @@ oi = function(
) {
assert_r6(objective, "Objective")

if (rush_available()) {
Instance = if (objective$codomain$target_length == 1) OptimInstanceSingleCrit else OptimInstanceMultiCrit
Instance$new(
objective = objective,
search_space = search_space,
terminator = terminator,
keep_evals = keep_evals,
check_values = check_values,
callbacks = callbacks)
} else {
Instance = if (objective$codomain$target_length == 1) OptimInstanceAsyncSingleCrit else OptimInstanceAsyncMultiCrit
Instance$new(
objective = objective,
search_space = search_space,
terminator = terminator,
callbacks = callbacks)
}
Instance = if (objective$codomain$target_length == 1) OptimInstanceAsyncSingleCrit else OptimInstanceAsyncMultiCrit
Instance$new(
objective = objective,
search_space = search_space,
terminator = terminator,
callbacks = callbacks)
}

#' @title Syntactic Sugar for Asynchronous Optimization Instance Construction
#'
#' @description
#' Function to construct an [OptimInstanceAsyncSingleCrit] and [OptimInstanceAsyncMultiCrit].
#'
#' @template param_objective
#' @template param_search_space
#' @template param_terminator
#' @template param_callbacks
#' @template param_rush
#'
#' @export
oi_async = function(
objective,
search_space = NULL,
terminator,
callbacks = list(),
rush = NULL
) {
assert_r6(objective, "Objective")

Instance = if (objective$codomain$target_length == 1) OptimInstanceSingleCrit else OptimInstanceMultiCrit
Instance$new(
objective = objective,
search_space = search_space,
terminator = terminator,
keep_evals = keep_evals,
check_values = check_values,
callbacks = callbacks,
rush = rush)
}
26 changes: 4 additions & 22 deletions R/worker_loops.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,12 @@
#' Pushes the results back to the data base.
#'
#' @template param_rush
#' @template param_objective
#' @template param_search_space
#'
#' @param optimizer [OptimizerAsync].
#' @param instance [OptimInstanceAsync].
#'
#' @export
bbotk_worker_loop_centralized = function(rush, objective, search_space) {
while(!rush$terminated) {
task = rush$pop_task(fields = c("xs", "seed"))
xs_trafoed = trafo_xs(task$xs, search_space)

if (!is.null(task)) {
tryCatch({
ys = with_rng_state(objective$eval, args = list(xs = xs_trafoed), seed = task$seed)
rush$push_results(task$key, yss = list(ys), extra = list(list(x_domain = list(xs_trafoed), timestamp_ys = Sys.time())))
}, error = function(e) {
condition = list(message = e$message)
rush$push_failed(task$key, conditions = list(condition))
})
}
}
return(NULL)
}

#' @export
bbotk_worker_loop_decentralized = function(rush, optimizer, instance) {
bbotk_worker_loop = function(rush, optimizer, instance) {
# replace controller with worker
instance$rush = rush
instance$archive$rush = rush
Expand Down
2 changes: 1 addition & 1 deletion man/Archive.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2117841

Please sign in to comment.