From e75c5199d49f48938b4872ee6384e612468c2f49 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 15 Dec 2023 18:18:02 +0100 Subject: [PATCH] feat: add to fsi --- R/fselect.R | 2 +- R/mlr_callbacks.R | 1 + R/sugar.R | 40 +++++++++++++++++-- man/fsi.Rd | 12 +++++- man/mlr3fselect.one_se_rule.Rd | 1 + tests/testthat/test_ArchiveFSelect.R | 60 ++++++++++++++++++++++++++-- tests/testthat/test_fselect.R | 2 +- 7 files changed, 109 insertions(+), 9 deletions(-) diff --git a/R/fselect.R b/R/fselect.R index c2474841..31164a84 100644 --- a/R/fselect.R +++ b/R/fselect.R @@ -91,7 +91,7 @@ fselect = function( check_values = check_values, callbacks = callbacks, ties_method = ties_method) - } else { + } else { FSelectInstanceMultiCrit$new( task = task, learner = learner, diff --git a/R/mlr_callbacks.R b/R/mlr_callbacks.R index 07217bcb..320a1e24 100644 --- a/R/mlr_callbacks.R +++ b/R/mlr_callbacks.R @@ -116,6 +116,7 @@ load_callback_svm_rfe = function() { #' #' @description #' Selects the smallest feature set within one standard error of the best as the result. +#' If there are multiple feature sets with the same performance and number of features, the first one is selected. #' #' @examples #' clbk("mlr3fselect.one_se_rule") diff --git a/R/sugar.R b/R/sugar.R index 916cdabb..a4193b54 100644 --- a/R/sugar.R +++ b/R/sugar.R @@ -46,13 +46,47 @@ fss = function(.keys, ...) { #' @template param_store_models #' @template param_check_values #' @template param_callbacks +#' @template param_ties_method #' #' @inheritSection FSelectInstanceSingleCrit Resources #' @inheritSection FSelectInstanceSingleCrit Default Measures #' #' @export #' @inherit FSelectInstanceSingleCrit examples -fsi = function(task, learner, resampling, measures = NULL, terminator, store_benchmark_result = TRUE, store_models = FALSE, check_values = FALSE, callbacks = list()) { - FSelectInstance = if (!is.list(measures)) FSelectInstanceSingleCrit else FSelectInstanceMultiCrit - FSelectInstance$new(task, learner, resampling, measures, terminator, store_benchmark_result, store_models, check_values, callbacks) +fsi = function( + task, + learner, + resampling, + measures = NULL, + terminator, + store_benchmark_result = TRUE, + store_models = FALSE, + check_values = FALSE, + callbacks = list(), + ties_method = "n_features" + ) { + if (!is.list(measures)) { + FSelectInstanceSingleCrit$new( + task = task, + learner = learner, + resampling = resampling, + measure = measures, + terminator = terminator, + store_benchmark_result = store_benchmark_result, + store_models = store_models, + check_values = check_values, + callbacks = callbacks, + ties_method = ties_method) + } else { + FSelectInstanceMultiCrit$new( + task = task, + learner = learner, + resampling = resampling, + measures = measures, + terminator = terminator, + store_benchmark_result = store_benchmark_result, + store_models = store_models, + check_values = check_values, + callbacks = callbacks) + } } diff --git a/man/fsi.Rd b/man/fsi.Rd index 6ef6990c..b38ff294 100644 --- a/man/fsi.Rd +++ b/man/fsi.Rd @@ -13,7 +13,8 @@ fsi( store_benchmark_result = TRUE, store_models = FALSE, check_values = FALSE, - callbacks = list() + callbacks = list(), + ties_method = "n_features" ) } \arguments{ @@ -47,6 +48,15 @@ validity?} \item{callbacks}{(list of \link{CallbackFSelect})\cr List of callbacks.} + +\item{ties_method}{(\code{character(1)})\cr +The method to break ties when selecting sets while optimizing and when selecting the best set. +Can be one of \code{n_features}, \code{first}, \code{random}. +The option \code{n_features} (default) selects the feature set with the least features. +If there are multiple best feature sets with the same number of features, the first one is selected. +The \code{first} method returns the first added best feature set. +The \code{random} method returns a random feature set from the best feature sets. +Ignored if multiple measures are used.} } \description{ Function to construct a \link{FSelectInstanceSingleCrit} or \link{FSelectInstanceMultiCrit}. diff --git a/man/mlr3fselect.one_se_rule.Rd b/man/mlr3fselect.one_se_rule.Rd index cde3ab3f..2be08be4 100644 --- a/man/mlr3fselect.one_se_rule.Rd +++ b/man/mlr3fselect.one_se_rule.Rd @@ -5,6 +5,7 @@ \title{One Standard Error Rule Callback} \description{ Selects the smallest feature set within one standard error of the best as the result. +If there are multiple feature sets with the same performance and number of features, the first one is selected. } \examples{ clbk("mlr3fselect.one_se_rule") diff --git a/tests/testthat/test_ArchiveFSelect.R b/tests/testthat/test_ArchiveFSelect.R index 0822f5bf..ff041cad 100644 --- a/tests/testthat/test_ArchiveFSelect.R +++ b/tests/testthat/test_ArchiveFSelect.R @@ -140,7 +140,61 @@ test_that("ArchiveFSelect as.data.table function works", { expect_equal(tab$batch_nr, 1:10) }) -test_that("best method works with ties and maximization", { +test_that("global ties method works", { + design = mlr3misc::rowwise_table( + ~x1, ~x2, ~x3, ~x4, + FALSE, TRUE, FALSE, TRUE, + TRUE, FALSE, FALSE, TRUE, + TRUE, FALSE, FALSE, FALSE, + FALSE, TRUE, FALSE, FALSE + ) + + score_design = data.table( + score = c(0.1, 0.2, 0.2, 0.1), + features = list(c("x2", "x4"), c("x1", "x4"), "x1", c("x1", "x2")) + ) + measure = msr("dummy", score_design = score_design, minimize = FALSE) + + # n_features + instance = fselect( + fselector = fs("design_points", design = design), + task = TEST_MAKE_TSK(), + learner = lrn("regr.rpart"), + resampling = rsmp("cv", folds = 3), + measures = measure, + ties_method = "n_features" + ) + + expect_equal(instance$result_feature_set, "x1") + + # first + instance$clear() + instance = fselect( + fselector = fs("design_points", design = design), + task = TEST_MAKE_TSK(), + learner = lrn("regr.rpart"), + resampling = rsmp("cv", folds = 3), + measures = measure, + ties_method = "first" + ) + + expect_equal(instance$result_feature_set, c("x1", "x4")) + + # random + instance$clear() + instance = fselect( + fselector = fs("design_points", design = design), + task = TEST_MAKE_TSK(), + learner = lrn("regr.rpart"), + resampling = rsmp("cv", folds = 3), + measures = measure, + ties_method = "random" + ) + + expect_names(instance$result_feature_set, must.include = "x1") +}) + +test_that("local ties method works when maximize measure", { design = mlr3misc::rowwise_table( ~x1, ~x2, ~x3, ~x4, @@ -169,7 +223,7 @@ test_that("best method works with ties and maximization", { expect_features(instance$archive$best(ties_method = "n_features")[, list(x1, x2, x3, x4)], identical_to = "x1") }) -test_that("best method works with ties and minimization", { +test_that("local ties method works when minimize measure", { design = mlr3misc::rowwise_table( ~x1, ~x2, ~x3, ~x4, @@ -198,7 +252,7 @@ test_that("best method works with ties and minimization", { expect_features(instance$archive$best(ties_method = "n_features")[, list(x1, x2, x3, x4)], identical_to = "x2") }) -test_that("best method works with batches and ties", { +test_that("local ties method works with batches", { design = mlr3misc::rowwise_table( ~x1, ~x2, ~x3, ~x4, diff --git a/tests/testthat/test_fselect.R b/tests/testthat/test_fselect.R index 82ed5ac1..6e83c196 100644 --- a/tests/testthat/test_fselect.R +++ b/tests/testthat/test_fselect.R @@ -38,7 +38,7 @@ test_that("fselect interface is equal to FSelectInstanceSingleCrit", { test_that("fselect interface is equal to FSelectInstanceMultiCrit", { fselect_args = formalArgs(fselect) - fselect_args = fselect_args[fselect_args != "fselector"] + fselect_args = fselect_args[fselect_args %nin% c("fselector", "ties_method")] instance_args = formalArgs(FSelectInstanceMultiCrit$public_methods$initialize) instance_args = c(instance_args, "term_evals", "term_time")