diff --git a/NEWS.md b/NEWS.md index 4c598039..ba4703be 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # mlr3tuning (development version) +* fix: The `$predict_type` was written to the model even when the `AutoTuner` was not trained. + # mlr3tuning 1.3.0 * feat: Save `ArchiveAsyncTuning` to a `data.table` with `ArchiveAsyncTuningFrozen`. diff --git a/R/AutoTuner.R b/R/AutoTuner.R index 9164476b..0981abe9 100644 --- a/R/AutoTuner.R +++ b/R/AutoTuner.R @@ -314,6 +314,8 @@ AutoTuner = R6Class("AutoTuner", #' @field predict_type (`character(1)`)\cr #' Stores the currently active predict type, e.g. `"response"`. #' Must be an element of `$predict_types`. + #' A few learners already use the predict type during training. + #' So there is no guarantee that changing the predict type after tuning and training will have any effect or does not lead to errors. predict_type = function(rhs) { if (missing(rhs)) { return(private$.predict_type) @@ -322,10 +324,12 @@ AutoTuner = R6Class("AutoTuner", stopf("Learner '%s' does not support predict type '%s'", self$id, rhs) } - # Catches 'Error: Field/Binding is read-only' bug - tryCatch({ + self$instance_args$learner$predict_type = rhs + + + if (!is.null(self$model)) { self$model$learner$predict_type = rhs - }, error = function(cond){}) + } private$.predict_type = rhs }, diff --git a/man/AutoTuner.Rd b/man/AutoTuner.Rd index afb4408b..2624968c 100644 --- a/man/AutoTuner.Rd +++ b/man/AutoTuner.Rd @@ -168,7 +168,9 @@ Short-cut to \code{result} from tuning instance.} \item{\code{predict_type}}{(\code{character(1)})\cr Stores the currently active predict type, e.g. \code{"response"}. -Must be an element of \verb{$predict_types}.} +Must be an element of \verb{$predict_types}. +A few learners already use the predict type during training. +So there is no guarantee that changing the predict type after tuning and training will have any effect or does not lead to errors.} \item{\code{hash}}{(\code{character(1)})\cr Hash (unique identifier) for this object.} diff --git a/tests/testthat/test_AutoTuner.R b/tests/testthat/test_AutoTuner.R index 64e9b9ac..66c85987 100644 --- a/tests/testthat/test_AutoTuner.R +++ b/tests/testthat/test_AutoTuner.R @@ -231,22 +231,44 @@ test_that("store_tuning_instance, store_benchmark_result and store_models flags }) test_that("predict_type works", { - te = trm("evals", n_evals = 4) - task = tsk("iris") - ps = TEST_MAKE_PS1(n_dim = 1) - ms = msr("classif.ce") - tuner = tnr("grid_search", resolution = 3) + task = tsk("pima") - at = AutoTuner$new(lrn("classif.rpart"), rsmp("holdout"), ms, te, - tuner = tuner, ps) + # response predict type + at = auto_tuner( + tuner = tnr("random_search"), + learner = lrn("classif.rpart"), + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 4)) + + expect_equal(at$predict_type, "response") at$train(task) expect_equal(at$predict_type, "response") expect_equal(at$model$learner$predict_type, "response") + # change predict type after training at$predict_type = "prob" expect_equal(at$predict_type, "prob") expect_equal(at$model$learner$predict_type, "prob") + + # prob predict type + at = auto_tuner( + tuner = tnr("random_search"), + learner = lrn("classif.rpart", predict_type = "prob"), + resampling = rsmp("holdout"), + measure = msr("classif.ce"), + terminator = trm("evals", n_evals = 4)) + + expect_equal(at$predict_type, "prob") + + at$train(task) + + expect_equal(at$predict_type, "prob") + expect_equal(at$model$learner$predict_type, "prob") + + pred = at$predict(task) + expect_numeric(pred$score(msr("classif.auc"))) }) test_that("search space from TuneToken works", {