Skip to content

Commit

Permalink
Merge branch 'master' into s3params_compat
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Jan 20, 2024
2 parents 2162b17 + 8114ae2 commit c7fe236
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 12 deletions.
21 changes: 10 additions & 11 deletions R/Archive.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,11 @@ Archive = R6Class("Archive",
ii = which_max(y, ties_method = "random")
tab[ii]
} else {
# use partial sort to find the best points
y = tab[[self$cols_y]] * self$codomain$maximization_to_minimization
i = sort(y, partial = n_select)[n_select]
ii = which(y <= i)
tab[ii]
# copy table to avoid changing the order of the archive
if (is.null(batch)) tab = copy(self$data)
# use data.table fast sort to find the best points
setorderv(tab, cols = self$cols_y, order = self$codomain$maximization_to_minimization)
head(tab, n_select)
}
} else {
# use non-dominated sorting to find the best points
Expand All @@ -134,17 +134,16 @@ Archive = R6Class("Archive",
#'
#' @return [data.table::data.table()]
nds_selection = function(batch = NULL, n_select = 1, ref_point = NULL) {
if (self$n_batch == 0L) stop("No results stored in archive")
if (is.null(batch)) batch = seq_len(self$n_batch)
assert_integerish(batch, lower = 1L, upper = self$n_batch, coerce = TRUE)
if (!self$n_batch) return(data.table())
assert_subset(batch, seq_len(self$n_batch))

tab = self$data[get("batch_nr") %in% batch, ]
tab = if (is.null(batch)) self$data else self$data[list(batch), , on = "batch_nr"]
assert_int(n_select, lower = 1L, upper = nrow(tab))

points = t(as.matrix(tab[, self$cols_y, with = FALSE]))
minimize = map_lgl(self$codomain$target_tags, has_element, "minimize")
inds = nds_selection(points, n_select, ref_point, minimize)
tab[inds, ]
ii = nds_selection(points, n_select, ref_point, minimize)
tab[ii, ]
},

#' @description
Expand Down
28 changes: 27 additions & 1 deletion tests/testthat/test_Archive.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,20 @@ test_that("best method returns top n results with maximization", {
ydt = data.table(y = c(1, 0.25, 2, 0.5, 0.3))
archive$add_evals(xdt, xss_trafoed, ydt)

expect_equal(archive$best(n_select = 2)$y, c(1, 2))
expect_equal(archive$best(n_select = 2)$y, c(2, 1))
})

test_that("best method returns top n results with maximization and ties", {
codomain = FUN_2D_CODOMAIN
codomain$params$y$tags = "maximize"

archive = Archive$new(PS_2D, FUN_2D_CODOMAIN)
xdt = data.table(x1 = runif(5), x2 = runif(5))
xss_trafoed = list(list(x1 = runif(5), x2 = runif(5)))
ydt = data.table(y = c(1, 1, 2, 0.5, 0.5))
archive$add_evals(xdt, xss_trafoed, ydt)

expect_equal(archive$best(n_select = 2)$y, c(2, 1))
})

test_that("best method returns top n results with minimization", {
Expand All @@ -194,3 +207,16 @@ test_that("best method returns top n results with minimization", {

expect_equal(archive$best(n_select = 2)$y, c(0.25, 0.3))
})

test_that("best method returns top n results with minimization and ties", {
codomain = FUN_2D_CODOMAIN
codomain$params$y$tags = "minimize"

archive = Archive$new(PS_2D, FUN_2D_CODOMAIN)
xdt = data.table(x1 = runif(5), x2 = runif(5))
xss_trafoed = list(list(x1 = runif(5), x2 = runif(5)))
ydt = data.table(y = c(1, 0.25, 0.5, 0.3, 0.3))
archive$add_evals(xdt, xss_trafoed, ydt)

expect_equal(archive$best(n_select = 2)$y, c(0.25, 0.3))
})

0 comments on commit c7fe236

Please sign in to comment.