From 25b7b1bcaa2ef3d1a370348793ac81cae6a4dbb2 Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 13 Mar 2024 12:36:53 +0100 Subject: [PATCH] fix: transformation functions in random search --- R/OptimizerRandomSearchV2.R | 16 +++++++++++----- .../testthat/test_OptimInstanceRushSingleCrit.R | 6 +++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/R/OptimizerRandomSearchV2.R b/R/OptimizerRandomSearchV2.R index 3f73fff0..4f9c60cd 100644 --- a/R/OptimizerRandomSearchV2.R +++ b/R/OptimizerRandomSearchV2.R @@ -25,19 +25,25 @@ OptimizerRandomSearchV2 = R6Class("OptimizerRandomSearchV2", private = list( .optimize = function(inst) { + search_space = inst$search_space + rush = inst$rush while(!inst$is_terminated) { # ask - sampler = SamplerUnif$new(inst$search_space) + sampler = SamplerUnif$new(search_space) xdt = sampler$sample(1)$data - xs = transform_xdt_to_xss(xdt, inst$search_space)[[1]] - key = inst$rush$push_running_task(list(xs)) + xss = transpose_list(xdt) + xs = xss[[1]][inst$archive$cols_x] + xs_trafoed = trafo_xs(xs, search_space) + keys = inst$rush$push_running_task(list(xs), extra = list(list(timestamp_xs = Sys.time()))) # eval - ys = inst$objective$eval(xs) + ys = inst$objective$eval(xs_trafoed) # tell - inst$rush$push_results(key, list(ys)) + rush$push_results(keys, yss = list(ys), extra = list(list( + x_domain = list(xs_trafoed), + timestamp_ys = Sys.time()))) } } ) diff --git a/tests/testthat/test_OptimInstanceRushSingleCrit.R b/tests/testthat/test_OptimInstanceRushSingleCrit.R index 9db10bbe..754ee36a 100644 --- a/tests/testthat/test_OptimInstanceRushSingleCrit.R +++ b/tests/testthat/test_OptimInstanceRushSingleCrit.R @@ -68,18 +68,18 @@ test_that("starting workers and evaluating points works in a decentralized netwo test_that("random search v2 works", { skip_on_cran() flush_redis() + options(bbotk_local = TRUE) rush_plan(n_workers = 2) instance = OptimInstanceRushSingleCrit$new( objective = OBJ_2D, search_space = PS_2D, - terminator = trm("evals", n_evals = 1L), + terminator = trm("evals", n_evals = 20L), ) optimizer = opt("random_search_v2") - instance$rush = RushWorker$new(instance$rush$network_id, host = "local") - get_private(optimizer)$.optimize_remote(instance) + optimizer$optimize(instance) })