Skip to content

Commit

Permalink
feat: add to fsi
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 15, 2023
1 parent b07e732 commit e75c519
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 9 deletions.
2 changes: 1 addition & 1 deletion R/fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ fselect = function(
check_values = check_values,
callbacks = callbacks,
ties_method = ties_method)
} else {
} else {
FSelectInstanceMultiCrit$new(
task = task,
learner = learner,
Expand Down
1 change: 1 addition & 0 deletions R/mlr_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ load_callback_svm_rfe = function() {
#'
#' @description
#' Selects the smallest feature set within one standard error of the best as the result.
#' If there are multiple feature sets with the same performance and number of features, the first one is selected.
#'
#' @examples
#' clbk("mlr3fselect.one_se_rule")
Expand Down
40 changes: 37 additions & 3 deletions R/sugar.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,47 @@ fss = function(.keys, ...) {
#' @template param_store_models
#' @template param_check_values
#' @template param_callbacks
#' @template param_ties_method
#'
#' @inheritSection FSelectInstanceSingleCrit Resources
#' @inheritSection FSelectInstanceSingleCrit Default Measures
#'
#' @export
#' @inherit FSelectInstanceSingleCrit examples
fsi = function(task, learner, resampling, measures = NULL, terminator, store_benchmark_result = TRUE, store_models = FALSE, check_values = FALSE, callbacks = list()) {
FSelectInstance = if (!is.list(measures)) FSelectInstanceSingleCrit else FSelectInstanceMultiCrit
FSelectInstance$new(task, learner, resampling, measures, terminator, store_benchmark_result, store_models, check_values, callbacks)
fsi = function(
task,
learner,
resampling,
measures = NULL,
terminator,
store_benchmark_result = TRUE,
store_models = FALSE,
check_values = FALSE,
callbacks = list(),
ties_method = "n_features"
) {
if (!is.list(measures)) {
FSelectInstanceSingleCrit$new(
task = task,
learner = learner,
resampling = resampling,
measure = measures,
terminator = terminator,
store_benchmark_result = store_benchmark_result,
store_models = store_models,
check_values = check_values,
callbacks = callbacks,
ties_method = ties_method)
} else {
FSelectInstanceMultiCrit$new(
task = task,
learner = learner,
resampling = resampling,
measures = measures,
terminator = terminator,
store_benchmark_result = store_benchmark_result,
store_models = store_models,
check_values = check_values,
callbacks = callbacks)
}
}
12 changes: 11 additions & 1 deletion man/fsi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr3fselect.one_se_rule.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 57 additions & 3 deletions tests/testthat/test_ArchiveFSelect.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,61 @@ test_that("ArchiveFSelect as.data.table function works", {
expect_equal(tab$batch_nr, 1:10)
})

test_that("best method works with ties and maximization", {
test_that("global ties method works", {
design = mlr3misc::rowwise_table(
~x1, ~x2, ~x3, ~x4,
FALSE, TRUE, FALSE, TRUE,
TRUE, FALSE, FALSE, TRUE,
TRUE, FALSE, FALSE, FALSE,
FALSE, TRUE, FALSE, FALSE
)

score_design = data.table(
score = c(0.1, 0.2, 0.2, 0.1),
features = list(c("x2", "x4"), c("x1", "x4"), "x1", c("x1", "x2"))
)
measure = msr("dummy", score_design = score_design, minimize = FALSE)

# n_features
instance = fselect(
fselector = fs("design_points", design = design),
task = TEST_MAKE_TSK(),
learner = lrn("regr.rpart"),
resampling = rsmp("cv", folds = 3),
measures = measure,
ties_method = "n_features"
)

expect_equal(instance$result_feature_set, "x1")

# first
instance$clear()
instance = fselect(
fselector = fs("design_points", design = design),
task = TEST_MAKE_TSK(),
learner = lrn("regr.rpart"),
resampling = rsmp("cv", folds = 3),
measures = measure,
ties_method = "first"
)

expect_equal(instance$result_feature_set, c("x1", "x4"))

# random
instance$clear()
instance = fselect(
fselector = fs("design_points", design = design),
task = TEST_MAKE_TSK(),
learner = lrn("regr.rpart"),
resampling = rsmp("cv", folds = 3),
measures = measure,
ties_method = "random"
)

expect_names(instance$result_feature_set, must.include = "x1")
})

test_that("local ties method works when maximize measure", {

design = mlr3misc::rowwise_table(
~x1, ~x2, ~x3, ~x4,
Expand Down Expand Up @@ -169,7 +223,7 @@ test_that("best method works with ties and maximization", {
expect_features(instance$archive$best(ties_method = "n_features")[, list(x1, x2, x3, x4)], identical_to = "x1")
})

test_that("best method works with ties and minimization", {
test_that("local ties method works when minimize measure", {

design = mlr3misc::rowwise_table(
~x1, ~x2, ~x3, ~x4,
Expand Down Expand Up @@ -198,7 +252,7 @@ test_that("best method works with ties and minimization", {
expect_features(instance$archive$best(ties_method = "n_features")[, list(x1, x2, x3, x4)], identical_to = "x2")
})

test_that("best method works with batches and ties", {
test_that("local ties method works with batches", {

design = mlr3misc::rowwise_table(
~x1, ~x2, ~x3, ~x4,
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ test_that("fselect interface is equal to FSelectInstanceSingleCrit", {

test_that("fselect interface is equal to FSelectInstanceMultiCrit", {
fselect_args = formalArgs(fselect)
fselect_args = fselect_args[fselect_args != "fselector"]
fselect_args = fselect_args[fselect_args %nin% c("fselector", "ties_method")]

instance_args = formalArgs(FSelectInstanceMultiCrit$public_methods$initialize)
instance_args = c(instance_args, "term_evals", "term_time")
Expand Down

0 comments on commit e75c519

Please sign in to comment.