diff --git a/DESCRIPTION b/DESCRIPTION index 3fa5a5b71..dbd8d014c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: mlr3proba Title: Probabilistic Supervised Learning for 'mlr3' -Version: 0.5.3 +Version: 0.5.4 Authors@R: c(person(given = "Raphael", family = "Sonabend", @@ -43,7 +43,7 @@ Depends: Imports: checkmate, data.table, - distr6 (>= 1.8.3), + distr6 (>= 1.8.4), ggplot2, mlr3misc (>= 0.7.0), mlr3viz, diff --git a/NEWS.md b/NEWS.md index 245a29e93..fccf182aa 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# mlr3proba 0.5.4 + +* Fix bottlenecks in Dcalib and RCLL + # mlr3proba 0.5.3 * Add support for learners that can predict multiple posterior distributions by using `distr6::Arrdist` diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index 954548e29..e7d2d5c99 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -3,21 +3,25 @@ #' @templateVar fullname MeasureSurvDCalibration #' #' @description -#' This calibration method is defined by calculating +#' This calibration method is defined by calculating the following statistic: #' \deqn{s = B/n \sum_i (P_i - n/B)^2} -#' where \eqn{B} is number of 'buckets', \eqn{n} is the number of predictions, -#' and \eqn{P_i} is the predicted number of deaths in the \eqn{i}th interval -#' [0, 100/B), [100/B, 50/B),....,[(B - 100)/B, 1). +#' where \eqn{B} is number of 'buckets' (that equally divide \eqn{[0,1]} into intervals), +#' \eqn{n} is the number of predictions, and \eqn{P_i} is the observed proportion +#' of observations in the \eqn{i}th interval. An observation is assigned to the +#' \eqn{i}th bucket, if its predicted survival probability at the time of event +#' falls within the corresponding interval. +#' This statistic assumes that censoring time is independent of death time. #' -#' A model is well-calibrated if `s ~ Unif(B)`, tested with `chisq.test` -#' (`p > 0.05` if well-calibrated). -#' Model `i` is better calibrated than model `j` if `s_i < s_j`. +#' A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test` +#' (\eqn{p > 0.05} if well-calibrated). +#' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}, +#' meaning that *lower values* of this measure are preferred. #' #' @details #' This measure can either return the test statistic or the p-value from the `chisq.test`. #' The former is useful for model comparison whereas the latter is useful for determining if a model -#' is well-calibration. If `chisq = FALSE` and `m` is the predicted value then you can manually -#' compute the p.value with `pchisq(m, B - 1, lower.tail = FALSE)`. +#' is well-calibrated. If `chisq = FALSE` and `s` is the predicted value then you can manually +#' compute the p.value with `pchisq(s, B - 1, lower.tail = FALSE)`. #' #' NOTE: This measure is still experimental both theoretically and in implementation. Results #' should therefore only be taken as an indicator of performance and not for @@ -34,18 +38,29 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", public = list( #' @description Creates a new instance of this [R6][R6::R6Class] class. #' @param B (`integer(1)`) \cr - #' Number of buckets to test for uniform predictions over. Default of `10` is recommended by - #' Haider et al. (2020). + #' Number of buckets to test for uniform predictions over. + #' Default of `10` is recommended by Haider et al. (2020). + #' Changing this parameter affects `truncate`. 
#' @param chisq (`logical(1)`) \cr - #' If `TRUE` returns the p.value of the corresponding chisq.test instead of the measure. - #' Otherwise this can be performed manually with `pchisq(m, B - 1, lower.tail = FALSE)`. - #' `p > 0.05` indicates well-calibrated. + #' If `TRUE` returns the p-value of the corresponding chisq.test instead of the measure. + #' Default is `FALSE` and returns the statistic `s`. + #' You can manually get the p-value by executing `pchisq(s, B - 1, lower.tail = FALSE)`. + #' `p > 0.05` indicates a well-calibrated model. + #' @param truncate (`double(1)`) \cr + #' This parameter controls the upper bound of the output statistic, + #' when `chisq` is `FALSE`. The default `truncate` value of \eqn{10} + #' corresponds to a p-value of 0.35 for the chisq.test using \eqn{B = 10} buckets. + #' Values \eqn{>10} translate to even lower p-values and thus less calibrated + #' models. If the number of buckets \eqn{B} changes, you probably will want to + #' change the `truncate` value as well to correspond to the same p-value significance. + #' Initialize with `truncate = Inf` if no truncation is desired. initialize = function() { ps = ps( B = p_int(1, default = 10), - chisq = p_lgl(default = FALSE) + chisq = p_lgl(default = FALSE), + truncate = p_dbl(lower = 0, upper = Inf, default = 10) ) - ps$values = list(B = 10L, chisq = FALSE) + ps$values = list(B = 10L, chisq = FALSE, truncate = 10) super$initialize( id = "surv.dcalib", @@ -62,18 +77,36 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", private = list( .score = function(prediction, ...) { ps = self$param_set$values + B = ps$B + # initialize buckets - bj = numeric(ps$B) + bj = numeric(B) + true_times = prediction$truth[, 1L] + # predict individual probability of death at observed event time - if (inherits(prediction$distr, "VectorDistribution")) { - si = as.numeric(prediction$distr$survival(data = matrix(prediction$truth[, 1L], nrow = 1L))) + # bypass distr6 construction if possible + if (inherits(prediction$data$distr, "array")) { + surv = prediction$data$distr + if (length(dim(surv)) == 3) { + # survival 3d array, extract median + surv = .ext_surv_mat(arr = surv, which.curve = 0.5) + } + times = as.numeric(colnames(surv)) + + si = diag(distr6:::C_Vec_WeightedDiscreteCdf(true_times, times, + cdf = t(1 - surv), FALSE, FALSE)) } else { - si = diag(prediction$distr$survival(prediction$truth[, 1L])) + distr = prediction$distr + if (inherits(distr, c("Matdist", "Arrdist"))) { + si = diag(distr$survival(true_times)) + } else { # VectorDistribution or single Distribution, e.g. 
WeightDisc() + si = as.numeric(distr$survival(data = matrix(true_times, nrow = 1L))) + } } # remove zeros si = map_dbl(si, function(.x) max(.x, 1e-5)) # index of associated bucket - js = ceiling(ps$B * si) + js = ceiling(B * si) # could remove loop for dead observations but needed for censored ones and minimal overhead # in combining both @@ -83,18 +116,18 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", # dead observations contribute 1 to their index bj[ji] = bj[ji] + 1 } else { - # uncensored observations spread across buckets with most weighting on penultimate + # censored observations spread across buckets with most weighting on penultimate for (k in seq.int(ji - 1)) { - bj[k] = bj[k] + 1 / (ps$B * si[[i]]) + bj[k] = bj[k] + 1 / (B * si[[i]]) } - bj[ji] = bj[ji] + (1 - (ji - 1) / (ps$B * si[[i]])) + bj[ji] = bj[ji] + (1 - (ji - 1) / (B * si[[i]])) } } if (ps$chisq) { return(stats::chisq.test(bj)$p.value) } else { - return((ps$B / length(si)) * sum((bj - length(si) / ps$B)^2)) + return(min(ps$truncate, (B / length(si)) * sum((bj - length(si) / B)^2))) } } ) diff --git a/R/MeasureSurvRCLL.R b/R/MeasureSurvRCLL.R index b78225317..0a4ffe051 100644 --- a/R/MeasureSurvRCLL.R +++ b/R/MeasureSurvRCLL.R @@ -73,17 +73,53 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL", event = truth[, 2] == 1 event_times = truth[event, 1] cens_times = truth[!event, 1] - distr = prediction$distr - if (!any(event)) { # all censored - # survival at outcome time (survived *at least* this long) - out[!event] = diag(as.matrix(distr[!event]$survival(cens_times))) - } else if (all(event)) { # all uncensored - # pdf at outcome time (survived *this* long) - out[event] = diag(as.matrix(distr[event]$pdf(event_times))) - } else { # mix - out[event] = diag(as.matrix(distr[event]$pdf(event_times))) - out[!event] = diag(as.matrix(distr[!event]$survival(cens_times))) + # Bypass distr6 construction if underlying distr represented by array + if (inherits(prediction$data$distr, "array")) { + surv = prediction$data$distr + if (length(dim(surv)) == 3) { + # survival 3d array, extract median + surv = .ext_surv_mat(arr = surv, which.curve = 0.5) + } + times = as.numeric(colnames(surv)) + + if (any(!event)) { + if (sum(!event) == 1) { # fix subsetting issue in case of 1 censored + cdf = as.matrix(1 - surv[!event, ]) + } else { + cdf = t(1 - surv[!event, ]) + } + + out[!event] = diag( + distr6:::C_Vec_WeightedDiscreteCdf(cens_times, times, cdf = cdf, FALSE, FALSE) + ) + } + if (any(event)) { + pdf = distr6:::cdfpdf(1 - surv) + if (sum(event) == 1) { # fix subsetting issue in case of 1 event + pdf = as.matrix(pdf[event, ]) + } else { + pdf = t(pdf[event, ]) + } + + out[event] = diag( + distr6:::C_Vec_WeightedDiscretePdf(event_times, times, pdf = pdf) + ) + } + } else { + distr = prediction$distr + + # Splitting in this way bypasses unnecessary distr extraction + if (!any(event)) { # all censored + # survival at outcome time (survived *at least* this long) + out = diag(as.matrix(distr$survival(cens_times))) + } else if (all(event)) { # all uncensored + # pdf at outcome time (survived *this* long) + out = diag(as.matrix(distr$pdf(event_times))) + } else { # mix + out[event] = diag(as.matrix(distr[event]$pdf(event_times))) + out[!event] = diag(as.matrix(distr[!event]$survival(cens_times))) + } } stopifnot(!any(out == -99L)) # safety check diff --git a/R/PredictionDataSurv.R b/R/PredictionDataSurv.R index b355ec609..acb7ec95c 100644 --- a/R/PredictionDataSurv.R +++ b/R/PredictionDataSurv.R @@ -129,12 +129,23 @@ 
filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) { } if (!is.null(pdata$distr)) { - if (inherits(pdata$distr, "matrix")) { - pdata$distr = pdata$distr[keep, , drop = FALSE] - } else { # array - pdata$distr = pdata$distr[keep, , , drop = FALSE] + distr = pdata$distr + + if (testDistribution(distr)) { # distribution + ok = inherits(distr, c("VectorDistribution", "Matdist", "Arrdist")) && + length(keep) > 1 # e.g.: Arrdist(1xYxZ) and keep = FALSE + if (ok) { + pdata$distr = distr[keep] # we can subset row/samples like this + } else { + pdata$distr = base::switch(keep, distr) # one distribution only + } + } else { + if (length(dim(distr)) == 2) { # 2d matrix + pdata$distr = distr[keep, , drop = FALSE] + } else { # 3d array + pdata$distr = distr[keep, , , drop = FALSE] + } } - } pdata diff --git a/R/PredictionSurv.R b/R/PredictionSurv.R index 37c50b1ac..205a0300d 100644 --- a/R/PredictionSurv.R +++ b/R/PredictionSurv.R @@ -171,10 +171,12 @@ PredictionSurv = R6Class("PredictionSurv", } }, .distrify_survarray = function(x) { - if (inherits(x, "array")) { # can be matrix as well + if (inherits(x, "array") && nrow(x) > 0) { # can be matrix as well # create Matdist or Arrdist (default => median curve) distr6::as.Distribution(1 - x, fun = "cdf", decorators = c("CoreStatistics", "ExoticStatistics")) + } else { + NULL } } ) diff --git a/inst/testthat/helper_expectations.R b/inst/testthat/helper_expectations.R index 65177d1bb..ea2ece8a8 100644 --- a/inst/testthat/helper_expectations.R +++ b/inst/testthat/helper_expectations.R @@ -30,8 +30,8 @@ expect_prediction_surv = function(p) { "response", "distr", "lp", "crank")) checkmate::expect_data_table(data.table::as.data.table(p), nrows = length(p$row_ids)) checkmate::expect_atomic_vector(p$missing) - if ("distr" %in% p$predict_types) { - expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist")) + if ("distr" %in% p$predict_types && !is.null(p$distr)) { + expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist", "WeightedDiscrete")) } expect_true(inherits(p, "PredictionSurv")) } diff --git a/man/mlr_measures_surv.dcalib.Rd b/man/mlr_measures_surv.dcalib.Rd index 411931e45..6694eed71 100644 --- a/man/mlr_measures_surv.dcalib.Rd +++ b/man/mlr_measures_surv.dcalib.Rd @@ -5,21 +5,25 @@ \alias{MeasureSurvDCalibration} \title{D-Calibration Survival Measure} \description{ -This calibration method is defined by calculating +This calibration method is defined by calculating the following statistic: \deqn{s = B/n \sum_i (P_i - n/B)^2} -where \eqn{B} is number of 'buckets', \eqn{n} is the number of predictions, -and \eqn{P_i} is the predicted number of deaths in the \eqn{i}th interval -[0, 100/B), [100/B, 50/B),....,[(B - 100)/B, 1). +where \eqn{B} is number of 'buckets' (that equally divide \eqn{[0,1]} into intervals), +\eqn{n} is the number of predictions, and \eqn{P_i} is the observed proportion +of observations in the \eqn{i}th interval. An observation is assigned to the +\eqn{i}th bucket, if its predicted survival probability at the time of event +falls within the corresponding interval. +This statistic assumes that censoring time is independent of death time. -A model is well-calibrated if \code{s ~ Unif(B)}, tested with \code{chisq.test} -(\code{p > 0.05} if well-calibrated). -Model \code{i} is better calibrated than model \code{j} if \code{s_i < s_j}. +A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with \code{chisq.test} +(\eqn{p > 0.05} if well-calibrated). 
+Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}, +meaning that \emph{lower values} of this measure are preferred. } \details{ This measure can either return the test statistic or the p-value from the \code{chisq.test}. The former is useful for model comparison whereas the latter is useful for determining if a model -is well-calibration. If \code{chisq = FALSE} and \code{m} is the predicted value then you can manually -compute the p.value with \code{pchisq(m, B - 1, lower.tail = FALSE)}. +is well-calibrated. If \code{chisq = FALSE} and \code{s} is the predicted value then you can manually +compute the p.value with \code{pchisq(s, B - 1, lower.tail = FALSE)}. NOTE: This measure is still experimental both theoretically and in implementation. Results should therefore only be taken as an indicator of performance and not for @@ -126,13 +130,24 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \if{html}{\out{
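A minimal usage sketch (not part of the patch) showing how the measures touched by this diff could be scored after the change. It assumes the `rats` task and the `surv.kaplan` learner that ship with mlr3proba, and uses the `B` and `truncate` parameters exactly as documented above; treat it as an illustration, not as code from the commit.

library(mlr3)
library(mlr3proba)

task = tsk("rats")
learner = lrn("surv.kaplan")
prediction = learner$train(task)$predict(task)

# D-Calibration statistic (lower is better); truncate = Inf disables the new
# truncation so that the manual p-value below is exact
s = prediction$score(msr("surv.dcalib", B = 10, truncate = Inf))

# manual p-value from the statistic; p > 0.05 indicates a well-calibrated model.
# For reference, pchisq(10, df = 9, lower.tail = FALSE) is roughly 0.35, which is
# why the default truncate = 10 pairs with the default B = 10 buckets.
pchisq(unname(s), df = 10 - 1, lower.tail = FALSE)

# right-censored log-loss, whose internals now bypass distr6 construction when
# the distr prediction is stored as a survival matrix or 3d array
prediction$score(msr("surv.rcll"))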