From 5d556f53df3ce850e064f2b392f7e20240f8579e Mon Sep 17 00:00:00 2001 From: be-marc Date: Mon, 18 Dec 2023 11:52:49 +0100 Subject: [PATCH] refactor: optimize the runtime of archive$best --- NAMESPACE | 3 ++ NEWS.md | 2 ++ R/Archive.R | 45 +++++++++++++++++------------- man/Archive.Rd | 16 +++++------ tests/testthat/test_Archive.R | 52 +++++++++++++++++++++++++++++++++++ 5 files changed, 91 insertions(+), 27 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index ebfc18f3b..93177b1c2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -66,6 +66,9 @@ import(mlr3misc) import(paradox) importFrom(R6,R6Class) importFrom(methods,formalArgs) +importFrom(mlr3misc,clbk) +importFrom(mlr3misc,clbks) +importFrom(mlr3misc,mlr_callbacks) importFrom(utils,bibentry) importFrom(utils,capture.output) importFrom(utils,head) diff --git a/NEWS.md b/NEWS.md index fd556a360..575c7039e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # bbotk (development version) +* refactor: Optimize the runtime of `archive$best()` method with partial sorting. + # bbotk 0.7.3 * fix: `OptimInstance$print()` errored when the search space was empty. diff --git a/R/Archive.R b/R/Archive.R index ccfc9c7b8..81c4e90f9 100644 --- a/R/Archive.R +++ b/R/Archive.R @@ -82,41 +82,48 @@ Archive = R6Class("Archive", }, #' @description - #' Returns the best scoring evaluation(s). For single-crit optimization, - #' the solution that minimizes / maximizes the objective function. + #' Returns the best scoring evaluation(s). + #' For single-crit optimization, the solution that minimizes / maximizes the objective function. #' For multi-crit optimization, the Pareto set / front. #' #' @param batch (`integer()`)\cr - #' The batch number(s) to limit the best results to. Default is - #' all batches. + #' The batch number(s) to limit the best results to. + #' Default is all batches. #' @param n_select (`integer(1L)`)\cr - #' Amount of points to select. Ignored for multi-crit optimization. + #' Amount of points to select. + #' Ignored for multi-crit optimization. #' #' @return [data.table::data.table()] - best = function(batch = NULL, n_select = 1) { - if (self$n_batch == 0L) return(data.table()) - if (is.null(batch)) batch = seq_len(self$n_batch) + best = function(batch = NULL, n_select = 1L) { + if (!self$n_batch) return(data.table()) assert_subset(batch, seq_len(self$n_batch)) + assert_int(n_select, lower = 1L) - tab = self$data[get("batch_nr") %in% batch, ] - assert_int(n_select, lower = 1L, upper = nrow(tab)) + tab = if (is.null(batch)) self$data else self$data[list(batch), , on = "batch_nr"] - max_to_min = self$codomain$maximization_to_minimization if (self$codomain$target_length == 1L) { - setorderv(tab, self$cols_y, order = max_to_min, na.last = TRUE) - res = tab[seq_len(n_select), ] + if (n_select == 1L) { + # use which_max to find the best point + y = tab[[self$cols_y]] * -self$codomain$maximization_to_minimization + ii = which_max(y, ties_method = "random") + tab[ii] + } else { + # use partial sort to find the best points + y = tab[[self$cols_y]] * self$codomain$maximization_to_minimization + i = sort(y, partial = n_select)[n_select] + ii = which(y <= i) + tab[ii] + } } else { + # use non-dominated sorting to find the best points ymat = t(as.matrix(tab[, self$cols_y, with = FALSE])) - ymat = max_to_min * ymat - res = tab[!is_dominated(ymat)] + ymat = self$codomain$maximization_to_minimization * ymat + tab[!is_dominated(ymat)] } - - return(res) }, #' @description - #' Calculate best points w.r.t. non dominated sorting with hypervolume - #' contribution. + #' Calculate best points w.r.t. non dominated sorting with hypervolume contribution. #' #' @param batch (`integer()`)\cr #' The batch number(s) to limit the best points to. Default is diff --git a/man/Archive.Rd b/man/Archive.Rd index 4c4008fb8..8294d7d79 100644 --- a/man/Archive.Rd +++ b/man/Archive.Rd @@ -134,22 +134,23 @@ Optimal outcome.} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-Archive-best}{}}} \subsection{Method \code{best()}}{ -Returns the best scoring evaluation(s). For single-crit optimization, -the solution that minimizes / maximizes the objective function. +Returns the best scoring evaluation(s). +For single-crit optimization, the solution that minimizes / maximizes the objective function. For multi-crit optimization, the Pareto set / front. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{Archive$best(batch = NULL, n_select = 1)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{Archive$best(batch = NULL, n_select = 1L)}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ \item{\code{batch}}{(\code{integer()})\cr -The batch number(s) to limit the best results to. Default is -all batches.} +The batch number(s) to limit the best results to. +Default is all batches.} \item{\code{n_select}}{(\code{integer(1L)})\cr -Amount of points to select. Ignored for multi-crit optimization.} +Amount of points to select. +Ignored for multi-crit optimization.} } \if{html}{\out{
}} } @@ -161,8 +162,7 @@ Amount of points to select. Ignored for multi-crit optimization.} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-Archive-nds_selection}{}}} \subsection{Method \code{nds_selection()}}{ -Calculate best points w.r.t. non dominated sorting with hypervolume -contribution. +Calculate best points w.r.t. non dominated sorting with hypervolume contribution. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{Archive$nds_selection(batch = NULL, n_select = 1, ref_point = NULL)}\if{html}{\out{
}} } diff --git a/tests/testthat/test_Archive.R b/tests/testthat/test_Archive.R index 83eb69ab3..36dece5b8 100644 --- a/tests/testthat/test_Archive.R +++ b/tests/testthat/test_Archive.R @@ -120,3 +120,55 @@ test_that("deep clone works", { expect_different_address(a1$search_space, a2$search_space) expect_different_address(a1$codomain, a2$codomain) }) + +test_that("best method works with maximization", { + codomain = FUN_2D_CODOMAIN + codomain$params$y$tags = "maximize" + + archive = Archive$new(PS_2D, FUN_2D_CODOMAIN) + xdt = data.table(x1 = runif(5), x2 = runif(5)) + xss_trafoed = list(list(x1 = runif(5), x2 = runif(5))) + ydt = data.table(y = c(1, 0.25, 2, 0.5, 0.3)) + archive$add_evals(xdt, xss_trafoed, ydt) + + expect_equal(archive$best()$y, 2) +}) + +test_that("best method works with minimization", { + codomain = FUN_2D_CODOMAIN + codomain$params$y$tags = "minimize" + + archive = Archive$new(PS_2D, FUN_2D_CODOMAIN) + xdt = data.table(x1 = runif(5), x2 = runif(5)) + xss_trafoed = list(list(x1 = runif(5), x2 = runif(5))) + ydt = data.table(y = c(1, 0.25, 2, 0.5, 0.3)) + archive$add_evals(xdt, xss_trafoed, ydt) + + expect_equal(archive$best()$y, 0.25) +}) + +test_that("best method returns top n results with maximization", { + codomain = FUN_2D_CODOMAIN + codomain$params$y$tags = "maximize" + + archive = Archive$new(PS_2D, FUN_2D_CODOMAIN) + xdt = data.table(x1 = runif(5), x2 = runif(5)) + xss_trafoed = list(list(x1 = runif(5), x2 = runif(5))) + ydt = data.table(y = c(1, 0.25, 2, 0.5, 0.3)) + archive$add_evals(xdt, xss_trafoed, ydt) + + expect_equal(archive$best(n_select = 2)$y, c(1, 2)) +}) + +test_that("best method returns top n results with minimization", { + codomain = FUN_2D_CODOMAIN + codomain$params$y$tags = "minimize" + + archive = Archive$new(PS_2D, FUN_2D_CODOMAIN) + xdt = data.table(x1 = runif(5), x2 = runif(5)) + xss_trafoed = list(list(x1 = runif(5), x2 = runif(5))) + ydt = data.table(y = c(1, 0.25, 2, 0.5, 0.3)) + archive$add_evals(xdt, xss_trafoed, ydt) + + expect_equal(archive$best(n_select = 2)$y, c(0.25, 0.3)) +})