From 5c24ba8ac1990b7e336cc6a0c378fc54001700d9 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 6 Jan 2025 17:04:59 +0100 Subject: [PATCH] Fix/predict newdata (#1240) * ci: fail on note * fix(predict): type conversion when predicting on new data * ... --------- Co-authored-by: Marc Becker <33069354+be-marc@users.noreply.github.com> --- NEWS.md | 4 +++- R/Learner.R | 11 +++++++++++ R/assertions.R | 2 +- man/Learner.Rd | 4 +++- tests/testthat/test_Learner.R | 27 +++++++++++++++++++++------ 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/NEWS.md b/NEWS.md index 69e7292de..441f1976e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,9 @@ # mlr3 (development version) +* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions if the input is a `data.frame` (#685) +* BREAKING CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning. * Column names with UTF-8 characters are now allowed by default. -The option `mlr3.allow_utf8_names` is removed. + The option `mlr3.allow_utf8_names` is removed. * BREAKING CHANGE: `Learner$predict_types` is read-only now. * docs: Clear up behavior of `Learner$predict_type` after training. diff --git a/R/Learner.R b/R/Learner.R index 84fd7a075..de23f51e2 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -377,6 +377,8 @@ Learner = R6Class("Learner", #' `data.frame()` or [DataBackend]. #' If a [DataBackend] is provided as `newdata`, the row ids are preserved, #' otherwise they are set to to the sequence `1:nrow(newdata)`. + #' If the input is a `data.frame`, [`auto_convert`] is used for type-conversions to ensure compatibility + #' of features between `$train()` and `$predict()`. #' #' @param task ([Task]).
#' @@ -393,6 +395,14 @@ Learner = R6Class("Learner", task = task_rm_backend(task) } + if (is.data.frame(newdata)) { + keep_cols = intersect(names(newdata), task$col_info$id) + ci = task$col_info[list(keep_cols), on = "id"] + newdata = do.call(data.table, Map(auto_convert, + value = as.list(newdata)[ci$id], + id = ci$id, type = ci$type, levels = ci$levels)) + } + newdata = as_data_backend(newdata) assert_names(newdata$colnames, must.include = task$feature_names) @@ -409,6 +419,7 @@ Learner = R6Class("Learner", # do some type conversions if necessary task$backend = newdata + task$col_info = col_info(task$backend) task$row_roles$use = task$backend$rownames self$predict(task) }, diff --git a/R/assertions.R b/R/assertions.R index d86b5ce43..e14494bf4 100644 --- a/R/assertions.R +++ b/R/assertions.R @@ -189,7 +189,7 @@ assert_predictable = function(task, learner) { all(pmap_lgl(list(x = ci_train$levels, y = ci_predict$levels), identical)) if (!ok) { - lg$warn("Learner '%s' received task with different column info (feature type or level ordering) during train and predict.", learner$id) + stopf("Learner '%s' received task with different column info (feature type or factor level ordering) during train and predict.", learner$id) } } diff --git a/man/Learner.Rd b/man/Learner.Rd index 817222e0f..c57dcbf14 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -518,7 +518,9 @@ New data to predict on. All data formats convertible by \code{\link[=as_data_backend]{as_data_backend()}} are supported, e.g. \code{data.frame()} or \link{DataBackend}. If a \link{DataBackend} is provided as \code{newdata}, the row ids are preserved, -otherwise they are set to to the sequence \code{1:nrow(newdata)}.} +otherwise they are set to to the sequence \code{1:nrow(newdata)}. 
+If the input is a \code{data.frame}, \code{\link{auto_convert}} is used for type-conversions to ensure compatibility +of features between \verb{$train()} and \verb{$predict()}.} \item{\code{task}}{(\link{Task}).} } diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 8917f4452..901c6c271 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -255,13 +255,15 @@ test_that("learner cannot be trained with TuneToken present", { test_that("integer<->numeric conversion in newdata (#533)", { data = data.table(y = runif(10), x = 1:10) - newdata = data.table(y = runif(10), x = 1:10 + 0.1) + newdata1 = data.table(y = runif(10), x = as.double(1:10)) + newdata2 = data.table(y = runif(10), x = 1:10 + 0.1) task = TaskRegr$new("test", data, "y") learner = lrn("regr.featureless") learner$train(task) expect_prediction(learner$predict_newdata(data)) - expect_prediction(learner$predict_newdata(newdata)) + expect_prediction(learner$predict_newdata(newdata1)) + expect_error(learner$predict_newdata(newdata2), "failed to convert from class 'numeric'") }) test_that("weights", { @@ -575,10 +577,7 @@ test_that("column info is compared during predict", { task_other = as_task_classif(dother, target = "y") l = lrn("classif.rpart") l$train(task) - old_threshold = lg$threshold - lg$set_threshold("warn") - expect_output(l$predict(task_flip), "task with different column info") - lg$set_threshold(old_threshold) + expect_error(l$predict(task_flip), "task with different column info") expect_error(l$predict(task_other), "with different columns") }) @@ -663,3 +662,19 @@ test_that("configure method works", { expect_equal(learner$param_set$values$xval, 10) expect_equal(learner$predict_sets, "train") }) + +test_that("predict_newdata auto conversion (#685)", { + l = lrn("classif.debug", save_tasks = TRUE)$train(tsk("iris")$select(c("Sepal.Length", "Sepal.Width"))) + expect_error(l$predict_newdata(data.table(Sepal.Length = 1, Sepal.Width = "abc")), +
"Incompatible types during auto-converting column 'Sepal.Width'", fixed = TRUE) + expect_error(l$predict_newdata(data.table(Sepal.Length = 1L)), + "but is missing elements") + + # New test for integerish value conversion to double + p1 = l$predict_newdata(data.table(Sepal.Length = 1, Sepal.Width = 2)) + p2 = l$predict_newdata(data.table(Sepal.Length = 1L, Sepal.Width = 2)) + expect_equal(l$model$task_predict$col_info[list("Sepal.Length")]$type, "numeric") + expect_double(l$model$task_predict$data(cols = "Sepal.Length")[[1]]) + + expect_equal(p1, p2) +})