From e21782a288f2dfbe38eb7034bdcf7321c133ef2d Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 6 Jan 2025 17:12:03 +0100 Subject: [PATCH] feat: improve docs for converters and better checks (#1231) * feat: improve docs for converters and better checks * ... * fix failing tests --- NAMESPACE | 2 ++ R/BenchmarkResult.R | 18 +++++++++++++++--- R/Prediction.R | 12 ++++++++++-- R/ResampleResult.R | 18 +++++++++++++++--- R/as_learner.R | 1 + R/as_measure.R | 2 ++ R/as_resampling.R | 3 ++- R/as_task.R | 8 ++++++++ R/as_task_classif.R | 2 +- R/as_task_regr.R | 2 +- R/assertions.R | 26 ++++++++++++++++++++++++++ man/as_resampling.Rd | 1 + man/as_task.Rd | 2 ++ man/as_task_classif.Rd | 2 +- man/as_task_regr.Rd | 2 +- man/assert_empty_ellipsis.Rd | 20 ++++++++++++++++++++ tests/testthat/test_Learner.R | 3 +-- tests/testthat/test_as_learner.R | 4 ++++ tests/testthat/test_as_measure.R | 4 ++++ tests/testthat/test_as_resampling.R | 4 ++++ tests/testthat/test_as_task.R | 4 ++++ tests/testthat/test_assertions.R | 8 ++++++++ 22 files changed, 133 insertions(+), 15 deletions(-) create mode 100644 man/assert_empty_ellipsis.Rd create mode 100644 tests/testthat/test_assertions.R diff --git a/NAMESPACE b/NAMESPACE index 8946cdd5e..d6625ea4e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -42,6 +42,7 @@ S3method(as_resampling,Resampling) S3method(as_resamplings,default) S3method(as_resamplings,list) S3method(as_task,Task) +S3method(as_task,default) S3method(as_task_classif,DataBackend) S3method(as_task_classif,Matrix) S3method(as_task_classif,TaskClassif) @@ -200,6 +201,7 @@ export(as_tasks) export(as_tasks_unsupervised) export(assert_backend) export(assert_benchmark_result) +export(assert_empty_ellipsis) export(assert_learnable) export(assert_learner) export(assert_learners) diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index 9ee35c3ce..d3abea9bc 100644 --- a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -175,7 +175,11 @@ BenchmarkResult = R6Class("BenchmarkResult", #' #' @return [data.table::data.table()]. score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) { - measures = as_measures(measures, task_type = self$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } assert_flag(ids) assert_flag(conditions) assert_flag(predictions) @@ -230,7 +234,11 @@ BenchmarkResult = R6Class("BenchmarkResult", #' @param predict_sets (`character()`)\cr #' The predict sets. obs_loss = function(measures = NULL, predict_sets = "test") { - measures = as_measures(measures, task_type = private$.data$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } map_dtr(self$resample_results$resample_result, function(rr) { rr$obs_loss(measures, predict_sets) @@ -276,7 +284,11 @@ BenchmarkResult = R6Class("BenchmarkResult", #' #' @return [data.table::data.table()]. aggregate = function(measures = NULL, ids = TRUE, uhashes = FALSE, params = FALSE, conditions = FALSE) { - measures = assert_measures(as_measures(measures, task_type = self$task_type)) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } assert_flag(ids) assert_flag(uhashes) assert_flag(params) diff --git a/R/Prediction.R b/R/Prediction.R index 243397ea6..ad7a0c8ce 100644 --- a/R/Prediction.R +++ b/R/Prediction.R @@ -90,7 +90,11 @@ Prediction = R6Class("Prediction", #' #' @return [Prediction]. score = function(measures = NULL, task = NULL, learner = NULL, train_set = NULL) { - measures = as_measures(measures, task_type = self$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } scores = map_dbl(measures, function(m) m$score(prediction = self, task = task, learner = learner, train_set = train_set)) set_names(scores, ids(measures)) }, @@ -105,7 +109,11 @@ Prediction = R6Class("Prediction", #' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an #' additional transformation after aggregation, in this example taking the square-root. obs_loss = function(measures = NULL) { - measures = as_measures(measures, task_type = self$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } get_obs_loss(as.data.table(self), measures) }, diff --git a/R/ResampleResult.R b/R/ResampleResult.R index 83b79ada3..895696d54 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -143,7 +143,11 @@ ResampleResult = R6Class("ResampleResult", #' #' @return [data.table::data.table()]. score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) { - measures = as_measures(measures, task_type = private$.data$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } assert_flag(ids) assert_flag(conditions) assert_flag(predictions) @@ -196,7 +200,11 @@ ResampleResult = R6Class("ResampleResult", #' @param predict_sets (`character()`)\cr #' The predict sets. obs_loss = function(measures = NULL, predict_sets = "test") { - measures = as_measures(measures, task_type = self$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } tab = map_dtr(self$predictions(predict_sets), as.data.table, .idcol = "iteration") get_obs_loss(tab, measures) }, @@ -208,7 +216,11 @@ ResampleResult = R6Class("ResampleResult", #' #' @return Named `numeric()`. aggregate = function(measures = NULL) { - measures = as_measures(measures, task_type = private$.data$task_type) + measures = if (is.null(measures)) { + default_measures(self$task_type) + } else { + assert_measures(as_measures(measures)) + } resample_result_aggregate(self, measures) }, diff --git a/R/as_learner.R b/R/as_learner.R index 4c303f511..3ec806845 100644 --- a/R/as_learner.R +++ b/R/as_learner.R @@ -16,6 +16,7 @@ as_learner = function(x, ...) { # nolint #' Whether to discard the state. #' @rdname as_learner as_learner.Learner = function(x, clone = FALSE, discard_state = FALSE, ...) { # nolint + assert_empty_ellipsis(...) if (isTRUE(clone) && isTRUE(discard_state)) { clone_without(x, "state") } else if (isTRUE(clone)) { diff --git a/R/as_measure.R b/R/as_measure.R index d266d8f27..97f2798ef 100644 --- a/R/as_measure.R +++ b/R/as_measure.R @@ -17,12 +17,14 @@ as_measure = function(x, ...) { # nolint #' @export #' @rdname as_measure as_measure.NULL = function(x, task_type = NULL, ...) { # nolint + assert_empty_ellipsis(...) default_measures(task_type)[[1L]] } #' @export #' @rdname as_measure as_measure.Measure = function(x, clone = FALSE, ...) { # nolint + assert_empty_ellipsis(...) if (isTRUE(clone)) x$clone() else x } diff --git a/R/as_resampling.R b/R/as_resampling.R index 03c4e6cba..60e2453e6 100644 --- a/R/as_resampling.R +++ b/R/as_resampling.R @@ -2,7 +2,7 @@ #' #' @description #' Convert object to a [Resampling] or a list of [Resampling]. -#' +#' This method e.g. allows to convert an [`mlr3oml::OMLTask`] to a [`Resampling`]. #' @inheritParams as_task #' @export as_resampling = function(x, ...) { # nolint @@ -12,6 +12,7 @@ as_resampling = function(x, ...) { # nolint #' @export #' @rdname as_resampling as_resampling.Resampling = function(x, clone = FALSE, ...) { # nolint + assert_empty_ellipsis(...) if (isTRUE(clone)) x$clone() else x } diff --git a/R/as_task.R b/R/as_task.R index a6841dc39..1f7017f70 100644 --- a/R/as_task.R +++ b/R/as_task.R @@ -2,6 +2,8 @@ #' #' @description #' Convert object to a [Task] or a list of [Task]. +#' This method e.g. allows to convert an [`mlr3oml::OMLTask`] to a [`Task`] and additionally supports cloning. +#' In order to construct a [Task] from a `data.frame`, use task-specific converters such as [`as_task_classif()`] or [`as_task_regr()`]. #' #' @param x (any)\cr #' Object to convert. @@ -12,11 +14,17 @@ as_task = function(x, ...) { UseMethod("as_task") } +#' @export +as_task.default = function(x, ...) { + stopf("No method for class '%s'. To create a task from a `data.frame`, use dedicated converters such as `as_task_classif()` or `as_task_regr()`.", class(x)[1L]) +} + #' @rdname as_task #' @param clone (`logical(1)`)\cr #' If `TRUE`, ensures that the returned object is not the same as the input `x`. #' @export as_task.Task = function(x, clone = FALSE, ...) { # nolint + assert_empty_ellipsis(...) if (isTRUE(clone)) x$clone(deep = TRUE) else x } diff --git a/R/as_task_classif.R b/R/as_task_classif.R index cf75a0396..4f1c39bbd 100644 --- a/R/as_task_classif.R +++ b/R/as_task_classif.R @@ -4,7 +4,7 @@ #' Convert object to a [TaskClassif]. #' This is a S3 generic. mlr3 ships with methods for the following objects: #' -#' 1. [TaskClassif]: ensure the identity +#' 1. [TaskClassif]: returns the object as-is, possibly cloned. #' 2. [`formula`], [data.frame()], [matrix()], [Matrix::Matrix()] and [DataBackend]: provides an alternative to the constructor of [TaskClassif]. #' 3. [TaskRegr]: Calls [convert_task()]. #' diff --git a/R/as_task_regr.R b/R/as_task_regr.R index ce4f90d1a..ad1e682e9 100644 --- a/R/as_task_regr.R +++ b/R/as_task_regr.R @@ -4,7 +4,7 @@ #' Convert object to a [TaskRegr]. #' This is a S3 generic. mlr3 ships with methods for the following objects: #' -#' 1. [TaskRegr]: ensure the identity +#' 1. [TaskRegr]: returns the object as-is, possibly cloned. #' 2. [`formula`], [data.frame()], [matrix()], [Matrix::Matrix()] and [DataBackend]: provides an alternative to the constructor of [TaskRegr]. #' 3. [TaskClassif]: Calls [convert_task()]. #' diff --git a/R/assertions.R b/R/assertions.R index e14494bf4..bd1a529ba 100644 --- a/R/assertions.R +++ b/R/assertions.R @@ -405,3 +405,29 @@ assert_param_values = function(x, n_learners = NULL, .var.name = vname(x)) { } invisible(x) } + +#' @title Assert Empty Ellipsis +#' @description +#' Assert that `...` arguments are empty. +#' Use this function in S3-methods to ensure that misspelling of arguments does not go unnoticed. +#' @param ... (any)\cr +#' Ellipsis arguments to check. +#' @keywords internal +#' @return `NULL` +#' @export +assert_empty_ellipsis = function(...) { + if (...length()) { + names = ...names() + if (is.null(names)) { + stopf("Received %i unnamed argument that was not used.", ...length()) + } else { + names2 = names[names != ""] + if (length(names2) == length(names)) { + stopf("Received the following named arguments that were unused: %s.", paste0(names2, collapse = ", ")) + } else { + stopf("Received unused arguments: %i unnamed, as well as named arguments %s.", length(names) - length(names2), paste0(names2, collapse = ", ")) + } + } + } + NULL +} diff --git a/man/as_resampling.Rd b/man/as_resampling.Rd index 846ea9ac3..02784a7ec 100644 --- a/man/as_resampling.Rd +++ b/man/as_resampling.Rd @@ -30,4 +30,5 @@ If \code{TRUE}, ensures that the returned object is not the same as the input \c } \description{ Convert object to a \link{Resampling} or a list of \link{Resampling}. +This method e.g. allows to convert an \code{\link[mlr3oml:oml_task]{mlr3oml::OMLTask}} to a \code{\link{Resampling}}. } diff --git a/man/as_task.Rd b/man/as_task.Rd index eba52ac8a..8aff6071e 100644 --- a/man/as_task.Rd +++ b/man/as_task.Rd @@ -30,4 +30,6 @@ If \code{TRUE}, ensures that the returned object is not the same as the input \c } \description{ Convert object to a \link{Task} or a list of \link{Task}. +This method e.g. allows to convert an \code{\link[mlr3oml:oml_task]{mlr3oml::OMLTask}} to a \code{\link{Task}} and additionally supports cloning. +In order to construct a \link{Task} from a \code{data.frame}, use task-specific converters such as \code{\link[=as_task_classif]{as_task_classif()}} or \code{\link[=as_task_regr]{as_task_regr()}}. } diff --git a/man/as_task_classif.Rd b/man/as_task_classif.Rd index 823cffaa8..4d633963c 100644 --- a/man/as_task_classif.Rd +++ b/man/as_task_classif.Rd @@ -106,7 +106,7 @@ Data frame containing all columns referenced in formula \code{x}.} Convert object to a \link{TaskClassif}. This is a S3 generic. mlr3 ships with methods for the following objects: \enumerate{ -\item \link{TaskClassif}: ensure the identity +\item \link{TaskClassif}: returns the object as-is, possibly cloned. \item \code{\link{formula}}, \code{\link[=data.frame]{data.frame()}}, \code{\link[=matrix]{matrix()}}, \code{\link[Matrix:Matrix]{Matrix::Matrix()}} and \link{DataBackend}: provides an alternative to the constructor of \link{TaskClassif}. \item \link{TaskRegr}: Calls \code{\link[=convert_task]{convert_task()}}. } diff --git a/man/as_task_regr.Rd b/man/as_task_regr.Rd index 35e57e84d..f2b77330c 100644 --- a/man/as_task_regr.Rd +++ b/man/as_task_regr.Rd @@ -100,7 +100,7 @@ Data frame containing all columns referenced in formula \code{x}.} Convert object to a \link{TaskRegr}. This is a S3 generic. mlr3 ships with methods for the following objects: \enumerate{ -\item \link{TaskRegr}: ensure the identity +\item \link{TaskRegr}: returns the object as-is, possibly cloned. \item \code{\link{formula}}, \code{\link[=data.frame]{data.frame()}}, \code{\link[=matrix]{matrix()}}, \code{\link[Matrix:Matrix]{Matrix::Matrix()}} and \link{DataBackend}: provides an alternative to the constructor of \link{TaskRegr}. \item \link{TaskClassif}: Calls \code{\link[=convert_task]{convert_task()}}. } diff --git a/man/assert_empty_ellipsis.Rd b/man/assert_empty_ellipsis.Rd new file mode 100644 index 000000000..21568d8b8 --- /dev/null +++ b/man/assert_empty_ellipsis.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/assertions.R +\name{assert_empty_ellipsis} +\alias{assert_empty_ellipsis} +\title{Assert Empty Ellipsis} +\usage{ +assert_empty_ellipsis(...) +} +\arguments{ +\item{...}{(any)\cr +Ellipsis arguments to check.} +} +\value{ +\code{NULL} +} +\description{ +Assert that \code{...} arguments are empty. +Use this function in S3-methods to ensure that misspelling of arguments does not go unnoticed. +} +\keyword{internal} diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 901c6c271..5ed9d803a 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -551,8 +551,7 @@ test_that("learner state contains internal valid task information", { test_that("validation task with 0 observations", { learner = lrn("classif.debug", validate = "predefined") task = tsk("iris") - task$internal_valid_task = integer(0) - expect_error({learner$train(task)}, "has 0 observations") + expect_warning({task$internal_valid_task = integer(0)}) }) test_that("column info is compared during predict", { diff --git a/tests/testthat/test_as_learner.R b/tests/testthat/test_as_learner.R index 3d34685a1..59cdc865a 100644 --- a/tests/testthat/test_as_learner.R +++ b/tests/testthat/test_as_learner.R @@ -21,3 +21,7 @@ test_that("discard_state", { as_learner(learner3, clone = FALSE, discard_state = TRUE) expect_null(learner3$state) }) + +test_that("error when arguments are misspelled", { + expect_error(as_learner(lrn("classif.rpart"), clone2 = TRUE), "Received the following") +}) diff --git a/tests/testthat/test_as_measure.R b/tests/testthat/test_as_measure.R index 647774fed..584962d66 100644 --- a/tests/testthat/test_as_measure.R +++ b/tests/testthat/test_as_measure.R @@ -14,3 +14,7 @@ test_that("as_measure conversion", { default = as_measures(NULL, task_type = "classif") expect_list(default, types = "Measure") }) + +test_that("error when arguments are misspelled", { + expect_error(as_measure(msr("classif.acc"), clone2 = TRUE), "Received the following") +}) diff --git a/tests/testthat/test_as_resampling.R b/tests/testthat/test_as_resampling.R index 40143a0c4..0718a7ff0 100644 --- a/tests/testthat/test_as_resampling.R +++ b/tests/testthat/test_as_resampling.R @@ -10,3 +10,7 @@ test_that("as_resampling conversion", { expect_list(as_resamplings(resampling), types = "Resampling") expect_list(as_resamplings(list(resampling)), types = "Resampling") }) + +test_that("error when arguments are misspelled", { + expect_error(as_resampling(rsmp("holdout"), clone2 = TRUE), "Received the following") +}) diff --git a/tests/testthat/test_as_task.R b/tests/testthat/test_as_task.R index 13b00ed82..9d763f221 100644 --- a/tests/testthat/test_as_task.R +++ b/tests/testthat/test_as_task.R @@ -22,3 +22,7 @@ test_that("as_task_xx error messages (#944)", { "subset of" ) }) + +test_that("error when arguments are misspelled", { + expect_error(as_task(tsk("iris"), clone2 = TRUE), "Received the following") +}) diff --git a/tests/testthat/test_assertions.R b/tests/testthat/test_assertions.R new file mode 100644 index 000000000..1c77d92b9 --- /dev/null +++ b/tests/testthat/test_assertions.R @@ -0,0 +1,8 @@ +test_that("assert_empty_ellipsis works", { + expect_error(assert_empty_ellipsis(1), "Received 1 unnamed argument") + expect_error(assert_empty_ellipsis(1, 2), "Received 2 unnamed argument") + expect_error(assert_empty_ellipsis(a = 1), "that were unused: a") + expect_error(assert_empty_ellipsis(a = 1, b = 2), "that were unused: a, b") + expect_error(assert_empty_ellipsis(a = 1, b = 1, 2), "1 unnamed, as well as named arguments a, b") + expect_null(assert_empty_ellipsis()) +})