Skip to content

Commit

Permalink
Merge pull request #337 from mlr-org/fix_rcll_bottleneck
Browse files Browse the repository at this point in the history
fix measure bottlenecks
  • Loading branch information
RaphaelS1 authored Nov 20, 2023
2 parents e557cd7 + 148387f commit 083e685
Show file tree
Hide file tree
Showing 10 changed files with 327 additions and 66 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.5.3
Version: 0.5.4
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -43,7 +43,7 @@ Depends:
Imports:
checkmate,
data.table,
distr6 (>= 1.8.3),
distr6 (>= 1.8.4),
ggplot2,
mlr3misc (>= 0.7.0),
mlr3viz,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mlr3proba 0.5.4

* Fix bottlenecks in Dcalib and RCLL

# mlr3proba 0.5.3

* Add support for learners that can predict multiple posterior distributions by using `distr6::Arrdist`
Expand Down
83 changes: 58 additions & 25 deletions R/MeasureSurvDCalibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,25 @@
#' @templateVar fullname MeasureSurvDCalibration
#'
#' @description
#' This calibration method is defined by calculating
#' This calibration method is defined by calculating the following statistic:
#' \deqn{s = B/n \sum_i (P_i - n/B)^2}
#' where \eqn{B} is number of 'buckets', \eqn{n} is the number of predictions,
#' and \eqn{P_i} is the predicted number of deaths in the \eqn{i}th interval
#' [0, 100/B), [100/B, 50/B),....,[(B - 100)/B, 1).
#' where \eqn{B} is number of 'buckets' (that equally divide \eqn{[0,1]} into intervals),
#' \eqn{n} is the number of predictions, and \eqn{P_i} is the observed number
#' of observations in the \eqn{i}th interval. An observation is assigned to the
#' \eqn{i}th bucket, if its predicted survival probability at the time of event
#' falls within the corresponding interval.
#' This statistic assumes that censoring time is independent of death time.
#'
#' A model is well-calibrated if `s ~ Unif(B)`, tested with `chisq.test`
#' (`p > 0.05` if well-calibrated).
#' Model `i` is better calibrated than model `j` if `s_i < s_j`.
#' A model is well-calibrated if the \eqn{P_i} are uniform across the \eqn{B} buckets
#' (so that \eqn{s \sim \chi^2_{B-1}}), tested with `chisq.test`
#' (\eqn{p > 0.05} if well-calibrated).
#' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)},
#' meaning that *lower values* of this measure are preferred.
#'
#' @details
#' This measure can either return the test statistic or the p-value from the `chisq.test`.
#' The former is useful for model comparison whereas the latter is useful for determining if a model
#' is well-calibration. If `chisq = FALSE` and `m` is the predicted value then you can manually
#' compute the p.value with `pchisq(m, B - 1, lower.tail = FALSE)`.
#' is well-calibrated. If `chisq = FALSE` and `s` is the predicted value then you can manually
#' compute the p.value with `pchisq(s, B - 1, lower.tail = FALSE)`.
#'
#' NOTE: This measure is still experimental both theoretically and in implementation. Results
#' should therefore only be taken as an indicator of performance and not for
Expand All @@ -34,18 +38,29 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @param B (`integer(1)`) \cr
#' Number of buckets to test for uniform predictions over. Default of `10` is recommended by
#' Haider et al. (2020).
#' Number of buckets to test for uniform predictions over.
#' Default of `10` is recommended by Haider et al. (2020).
#' Changing this parameter affects `truncate`.
#' @param chisq (`logical(1)`) \cr
#' If `TRUE` returns the p.value of the corresponding chisq.test instead of the measure.
#' Otherwise this can be performed manually with `pchisq(m, B - 1, lower.tail = FALSE)`.
#' `p > 0.05` indicates well-calibrated.
#' If `TRUE` returns the p-value of the corresponding chisq.test instead of the measure.
#' Default is `FALSE` and returns the statistic `s`.
#' You can manually get the p-value by executing `pchisq(s, B - 1, lower.tail = FALSE)`.
#' `p > 0.05` indicates a well-calibrated model.
#' @param truncate (`double(1)`) \cr
#' This parameter controls the upper bound of the output statistic,
#' when `chisq` is `FALSE`. The default `truncate` value of \eqn{10}
#' corresponds to a p-value of 0.35 for the chisq.test using \eqn{B = 10} buckets.
#' Values \eqn{>10} translate to even lower p-values and thus less calibrated
#' models. If the number of buckets \eqn{B} changes, you will probably want to
#' change the `truncate` value as well so that it corresponds to the same p-value.
#' Initialize with `truncate = Inf` if no truncation is desired.
initialize = function() {
ps = ps(
B = p_int(1, default = 10),
chisq = p_lgl(default = FALSE)
chisq = p_lgl(default = FALSE),
truncate = p_dbl(lower = 0, upper = Inf, default = 10)
)
ps$values = list(B = 10L, chisq = FALSE)
ps$values = list(B = 10L, chisq = FALSE, truncate = 10)

super$initialize(
id = "surv.dcalib",
Expand All @@ -62,18 +77,36 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
private = list(
.score = function(prediction, ...) {
ps = self$param_set$values
B = ps$B

# initialize buckets
bj = numeric(ps$B)
bj = numeric(B)
true_times = prediction$truth[, 1L]

# predict individual probability of death at observed event time
if (inherits(prediction$distr, "VectorDistribution")) {
si = as.numeric(prediction$distr$survival(data = matrix(prediction$truth[, 1L], nrow = 1L)))
# bypass distr6 construction if possible
if (inherits(prediction$data$distr, "array")) {
surv = prediction$data$distr
if (length(dim(surv)) == 3) {
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
times = as.numeric(colnames(surv))

si = diag(distr6:::C_Vec_WeightedDiscreteCdf(true_times, times,
cdf = t(1 - surv), FALSE, FALSE))
} else {
si = diag(prediction$distr$survival(prediction$truth[, 1L]))
distr = prediction$distr
if (inherits(distr, c("Matdist", "Arrdist"))) {
si = diag(distr$survival(true_times))
} else { # VectorDistribution or single Distribution, e.g. WeightDisc()
si = as.numeric(distr$survival(data = matrix(true_times, nrow = 1L)))
}
}
# remove zeros
si = map_dbl(si, function(.x) max(.x, 1e-5))
# index of associated bucket
js = ceiling(ps$B * si)
js = ceiling(B * si)

# could remove loop for dead observations but needed for censored ones and minimal overhead
# in combining both
Expand All @@ -83,18 +116,18 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
# dead observations contribute 1 to their index
bj[ji] = bj[ji] + 1
} else {
# uncensored observations spread across buckets with most weighting on penultimate
# censored observations spread across buckets with most weighting on penultimate
for (k in seq.int(ji - 1)) {
bj[k] = bj[k] + 1 / (ps$B * si[[i]])
bj[k] = bj[k] + 1 / (B * si[[i]])
}
bj[ji] = bj[ji] + (1 - (ji - 1) / (ps$B * si[[i]]))
bj[ji] = bj[ji] + (1 - (ji - 1) / (B * si[[i]]))
}
}

if (ps$chisq) {
return(stats::chisq.test(bj)$p.value)
} else {
return((ps$B / length(si)) * sum((bj - length(si) / ps$B)^2))
return(min(ps$truncate, (B / length(si)) * sum((bj - length(si) / B)^2)))
}
}
)
Expand Down
56 changes: 46 additions & 10 deletions R/MeasureSurvRCLL.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,53 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL",
event = truth[, 2] == 1
event_times = truth[event, 1]
cens_times = truth[!event, 1]
distr = prediction$distr

if (!any(event)) { # all censored
# survival at outcome time (survived *at least* this long)
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
} else if (all(event)) { # all uncensored
# pdf at outcome time (survived *this* long)
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
} else { # mix
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
# Bypass distr6 construction if underlying distr represented by array
if (inherits(prediction$data$distr, "array")) {
surv = prediction$data$distr
if (length(dim(surv)) == 3) {
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
times = as.numeric(colnames(surv))

if (any(!event)) {
if (sum(!event) == 1) { # fix subsetting issue in case of 1 censored
cdf = as.matrix(1 - surv[!event, ])
} else {
cdf = t(1 - surv[!event, ])
}

out[!event] = diag(
distr6:::C_Vec_WeightedDiscreteCdf(cens_times, times, cdf = cdf, FALSE, FALSE)
)
}
if (any(event)) {
pdf = distr6:::cdfpdf(1 - surv)
if (sum(event) == 1) { # fix subsetting issue in case of 1 event
pdf = as.matrix(pdf[event, ])
} else {
pdf = t(pdf[event, ])
}

out[event] = diag(
distr6:::C_Vec_WeightedDiscretePdf(event_times, times, pdf = pdf)
)
}
} else {
distr = prediction$distr

# Splitting in this way bypasses unnecessary distr extraction
if (!any(event)) { # all censored
# survival at outcome time (survived *at least* this long)
out = diag(as.matrix(distr$survival(cens_times)))
} else if (all(event)) { # all uncensored
# pdf at outcome time (survived *this* long)
out = diag(as.matrix(distr$pdf(event_times)))
} else { # mix
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
}
}

stopifnot(!any(out == -99L)) # safety check
Expand Down
21 changes: 16 additions & 5 deletions R/PredictionDataSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,23 @@ filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) {
}

if (!is.null(pdata$distr)) {
if (inherits(pdata$distr, "matrix")) {
pdata$distr = pdata$distr[keep, , drop = FALSE]
} else { # array
pdata$distr = pdata$distr[keep, , , drop = FALSE]
distr = pdata$distr

if (testDistribution(distr)) { # distribution
ok = inherits(distr, c("VectorDistribution", "Matdist", "Arrdist")) &&
length(keep) > 1 # e.g.: Arrdist(1xYxZ) and keep = FALSE
if (ok) {
pdata$distr = distr[keep] # we can subset row/samples like this
} else {
pdata$distr = base::switch(keep, distr) # one distribution only
}
} else {
if (length(dim(distr)) == 2) { # 2d matrix
pdata$distr = distr[keep, , drop = FALSE]
} else { # 3d array
pdata$distr = distr[keep, , , drop = FALSE]
}
}

}

pdata
Expand Down
4 changes: 3 additions & 1 deletion R/PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,12 @@ PredictionSurv = R6Class("PredictionSurv",
}
},
.distrify_survarray = function(x) {
if (inherits(x, "array")) { # can be matrix as well
if (inherits(x, "array") && nrow(x) > 0) { # can be matrix as well
# create Matdist or Arrdist (default => median curve)
distr6::as.Distribution(1 - x, fun = "cdf",
decorators = c("CoreStatistics", "ExoticStatistics"))
} else {
NULL
}
}
)
Expand Down
4 changes: 2 additions & 2 deletions inst/testthat/helper_expectations.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ expect_prediction_surv = function(p) {
"response", "distr", "lp", "crank"))
checkmate::expect_data_table(data.table::as.data.table(p), nrows = length(p$row_ids))
checkmate::expect_atomic_vector(p$missing)
if ("distr" %in% p$predict_types) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist"))
if ("distr" %in% p$predict_types && !is.null(p$distr)) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist", "WeightedDiscrete"))
}
expect_true(inherits(p, "PredictionSurv"))
}
43 changes: 29 additions & 14 deletions man/mlr_measures_surv.dcalib.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 083e685

Please sign in to comment.