diff --git a/DESCRIPTION b/DESCRIPTION index 3fa5a5b71..dbd8d014c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: mlr3proba Title: Probabilistic Supervised Learning for 'mlr3' -Version: 0.5.3 +Version: 0.5.4 Authors@R: c(person(given = "Raphael", family = "Sonabend", @@ -43,7 +43,7 @@ Depends: Imports: checkmate, data.table, - distr6 (>= 1.8.3), + distr6 (>= 1.8.4), ggplot2, mlr3misc (>= 0.7.0), mlr3viz, diff --git a/NEWS.md b/NEWS.md index 245a29e93..fccf182aa 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# mlr3proba 0.5.4 + +* Fix performance bottlenecks in the D-Calibration (`surv.dcalib`) and RCLL (`surv.rcll`) measures + # mlr3proba 0.5.3 * Add support for learners that can predict multiple posterior distributions by using `distr6::Arrdist` diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R index 954548e29..e7d2d5c99 100644 --- a/R/MeasureSurvDCalibration.R +++ b/R/MeasureSurvDCalibration.R @@ -3,21 +3,25 @@ #' @templateVar fullname MeasureSurvDCalibration #' #' @description -#' This calibration method is defined by calculating +#' This calibration method is defined by calculating the following statistic: #' \deqn{s = B/n \sum_i (P_i - n/B)^2} -#' where \eqn{B} is number of 'buckets', \eqn{n} is the number of predictions, -#' and \eqn{P_i} is the predicted number of deaths in the \eqn{i}th interval -#' [0, 100/B), [100/B, 50/B),....,[(B - 100)/B, 1). +#' where \eqn{B} is the number of 'buckets' (which equally divide \eqn{[0,1]} into intervals), +#' \eqn{n} is the number of predictions, and \eqn{P_i} is the observed proportion +#' of observations in the \eqn{i}th interval. An observation is assigned to the +#' \eqn{i}th bucket if its predicted survival probability at the event time +#' falls within the corresponding interval. +#' This statistic assumes that censoring time is independent of death time. #' -#' A model is well-calibrated if `s ~ Unif(B)`, tested with `chisq.test` -#' (`p > 0.05` if well-calibrated). -#' Model `i` is better calibrated than model `j` if `s_i < s_j`. +#' A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test` +#' (\eqn{p > 0.05} if well-calibrated). +#' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}, +#' meaning that *lower values* of this measure are preferred. #' #' @details #' This measure can either return the test statistic or the p-value from the `chisq.test`. #' The former is useful for model comparison whereas the latter is useful for determining if a model -#' is well-calibration. If `chisq = FALSE` and `m` is the predicted value then you can manually -#' compute the p.value with `pchisq(m, B - 1, lower.tail = FALSE)`. +#' is well-calibrated. If `chisq = FALSE` and `s` is the returned statistic, then you can manually +#' compute the p-value with `pchisq(s, B - 1, lower.tail = FALSE)`. #' #' NOTE: This measure is still experimental both theoretically and in implementation. Results #' should therefore only be taken as an indicator of performance and not for @@ -34,18 +38,29 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", public = list( #' @description Creates a new instance of this [R6][R6::R6Class] class. #' @param B (`integer(1)`) \cr - #' Number of buckets to test for uniform predictions over. Default of `10` is recommended by - #' Haider et al. (2020). + #' Number of buckets to test for uniform predictions over. + #' Default of `10` is recommended by Haider et al. (2020). + #' Changing this parameter affects `truncate`. 
#' @param chisq (`logical(1)`) \cr - #' If `TRUE` returns the p.value of the corresponding chisq.test instead of the measure. - #' Otherwise this can be performed manually with `pchisq(m, B - 1, lower.tail = FALSE)`. - #' `p > 0.05` indicates well-calibrated. + #' If `TRUE`, returns the p-value of the corresponding chisq.test instead of the measure. + #' Default is `FALSE` and returns the statistic `s`. + #' You can manually get the p-value by executing `pchisq(s, B - 1, lower.tail = FALSE)`. + #' `p > 0.05` indicates a well-calibrated model. + #' @param truncate (`double(1)`) \cr + #' This parameter controls the upper bound of the output statistic + #' when `chisq` is `FALSE`. The default `truncate` value of \eqn{10} + #' corresponds to a p-value of about 0.35 for the chisq.test using \eqn{B = 10} buckets. + #' Values \eqn{>10} translate to even lower p-values and thus less calibrated + #' models. If the number of buckets \eqn{B} changes, you will probably want to + #' change the `truncate` value as well to keep the same p-value significance level. + #' Initialize with `truncate = Inf` if no truncation is desired. initialize = function() { ps = ps( B = p_int(1, default = 10), - chisq = p_lgl(default = FALSE) + chisq = p_lgl(default = FALSE), + truncate = p_dbl(lower = 0, upper = Inf, default = 10) ) - ps$values = list(B = 10L, chisq = FALSE) + ps$values = list(B = 10L, chisq = FALSE, truncate = 10) super$initialize( id = "surv.dcalib", @@ -62,18 +77,36 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", private = list( .score = function(prediction, ...) { ps = self$param_set$values + B = ps$B + # initialize buckets - bj = numeric(ps$B) + bj = numeric(B) + true_times = prediction$truth[, 1L] + # predict individual probability of death at observed event time - if (inherits(prediction$distr, "VectorDistribution")) { - si = as.numeric(prediction$distr$survival(data = matrix(prediction$truth[, 1L], nrow = 1L))) + # bypass distr6 construction if possible + if (inherits(prediction$data$distr, "array")) { + surv = prediction$data$distr + if (length(dim(surv)) == 3) { + # survival 3d array, extract the median curve + surv = .ext_surv_mat(arr = surv, which.curve = 0.5) + } + times = as.numeric(colnames(surv)) + + si = diag(distr6:::C_Vec_WeightedDiscreteCdf(true_times, times, + cdf = t(1 - surv), FALSE, FALSE)) } else { - si = diag(prediction$distr$survival(prediction$truth[, 1L])) + distr = prediction$distr + if (inherits(distr, c("Matdist", "Arrdist"))) { + si = diag(distr$survival(true_times)) + } else { # VectorDistribution or single Distribution, e.g. 
WeightDisc() + si = as.numeric(distr$survival(data = matrix(true_times, nrow = 1L))) + } } # remove zeros si = map_dbl(si, function(.x) max(.x, 1e-5)) # index of associated bucket - js = ceiling(ps$B * si) + js = ceiling(B * si) # could remove loop for dead observations but needed for censored ones and minimal overhead # in combining both @@ -83,18 +116,18 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration", # dead observations contribute 1 to their index bj[ji] = bj[ji] + 1 } else { - # uncensored observations spread across buckets with most weighting on penultimate + # censored observations spread across buckets with most weighting on the penultimate one for (k in seq.int(ji - 1)) { - bj[k] = bj[k] + 1 / (ps$B * si[[i]]) + bj[k] = bj[k] + 1 / (B * si[[i]]) } - bj[ji] = bj[ji] + (1 - (ji - 1) / (ps$B * si[[i]])) + bj[ji] = bj[ji] + (1 - (ji - 1) / (B * si[[i]])) } } if (ps$chisq) { return(stats::chisq.test(bj)$p.value) } else { - return((ps$B / length(si)) * sum((bj - length(si) / ps$B)^2)) + return(min(ps$truncate, (B / length(si)) * sum((bj - length(si) / B)^2))) } } ) diff --git a/R/MeasureSurvRCLL.R b/R/MeasureSurvRCLL.R index b78225317..0a4ffe051 100644 --- a/R/MeasureSurvRCLL.R +++ b/R/MeasureSurvRCLL.R @@ -73,17 +73,53 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL", event = truth[, 2] == 1 event_times = truth[event, 1] cens_times = truth[!event, 1] - distr = prediction$distr - if (!any(event)) { # all censored - # survival at outcome time (survived *at least* this long) - out[!event] = diag(as.matrix(distr[!event]$survival(cens_times))) - } else if (all(event)) { # all uncensored - # pdf at outcome time (survived *this* long) - out[event] = diag(as.matrix(distr[event]$pdf(event_times))) - } else { # mix - out[event] = diag(as.matrix(distr[event]$pdf(event_times))) - out[!event] = diag(as.matrix(distr[!event]$survival(cens_times))) + # Bypass distr6 construction if the underlying distr is represented by an array + if (inherits(prediction$data$distr, "array")) { + surv = prediction$data$distr + if (length(dim(surv)) == 3) { + # survival 3d array, extract the median curve + surv = .ext_surv_mat(arr = surv, which.curve = 0.5) + } + times = as.numeric(colnames(surv)) + + if (any(!event)) { if (sum(!event) == 1) { # fix subsetting issue in case of 1 censored + cdf = as.matrix(1 - surv[!event, ]) + } else { + cdf = t(1 - surv[!event, ]) + } + + out[!event] = diag( + distr6:::C_Vec_WeightedDiscreteCdf(cens_times, times, cdf = cdf, FALSE, FALSE) + ) + } + if (any(event)) { + pdf = distr6:::cdfpdf(1 - surv) + if (sum(event) == 1) { # fix subsetting issue in case of 1 event + pdf = as.matrix(pdf[event, ]) + } else { + pdf = t(pdf[event, ]) + } + + out[event] = diag( + distr6:::C_Vec_WeightedDiscretePdf(event_times, times, pdf = pdf) + ) + } + } else { + distr = prediction$distr + + # Splitting in this way bypasses unnecessary distr extraction + if (!any(event)) { # all censored + # survival at outcome time (survived *at least* this long) + out = diag(as.matrix(distr$survival(cens_times))) + } else if (all(event)) { # all uncensored + # pdf at outcome time (survived *this* long) + out = diag(as.matrix(distr$pdf(event_times))) + } else { # mix + out[event] = diag(as.matrix(distr[event]$pdf(event_times))) + out[!event] = diag(as.matrix(distr[!event]$survival(cens_times))) + } } stopifnot(!any(out == -99L)) # safety check diff --git a/R/PredictionDataSurv.R b/R/PredictionDataSurv.R index b355ec609..acb7ec95c 100644 --- a/R/PredictionDataSurv.R +++ b/R/PredictionDataSurv.R @@ -129,12 +129,23 @@ 
filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) { } if (!is.null(pdata$distr)) { - if (inherits(pdata$distr, "matrix")) { - pdata$distr = pdata$distr[keep, , drop = FALSE] - } else { # array - pdata$distr = pdata$distr[keep, , , drop = FALSE] + distr = pdata$distr + + if (testDistribution(distr)) { # distribution + ok = inherits(distr, c("VectorDistribution", "Matdist", "Arrdist")) && + length(keep) > 1 # e.g.: Arrdist(1xYxZ) and keep = FALSE + if (ok) { + pdata$distr = distr[keep] # we can subset row/samples like this + } else { + pdata$distr = base::switch(keep, distr) # one distribution only + } + } else { + if (length(dim(distr)) == 2) { # 2d matrix + pdata$distr = distr[keep, , drop = FALSE] + } else { # 3d array + pdata$distr = distr[keep, , , drop = FALSE] + } } - } pdata diff --git a/R/PredictionSurv.R b/R/PredictionSurv.R index 37c50b1ac..205a0300d 100644 --- a/R/PredictionSurv.R +++ b/R/PredictionSurv.R @@ -171,10 +171,12 @@ PredictionSurv = R6Class("PredictionSurv", } }, .distrify_survarray = function(x) { - if (inherits(x, "array")) { # can be matrix as well + if (inherits(x, "array") && nrow(x) > 0) { # can be matrix as well # create Matdist or Arrdist (default => median curve) distr6::as.Distribution(1 - x, fun = "cdf", decorators = c("CoreStatistics", "ExoticStatistics")) + } else { + NULL } } ) diff --git a/inst/testthat/helper_expectations.R b/inst/testthat/helper_expectations.R index 65177d1bb..ea2ece8a8 100644 --- a/inst/testthat/helper_expectations.R +++ b/inst/testthat/helper_expectations.R @@ -30,8 +30,8 @@ expect_prediction_surv = function(p) { "response", "distr", "lp", "crank")) checkmate::expect_data_table(data.table::as.data.table(p), nrows = length(p$row_ids)) checkmate::expect_atomic_vector(p$missing) - if ("distr" %in% p$predict_types) { - expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist")) + if ("distr" %in% p$predict_types && !is.null(p$distr)) { + expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist", "WeightedDiscrete")) } expect_true(inherits(p, "PredictionSurv")) } diff --git a/man/mlr_measures_surv.dcalib.Rd b/man/mlr_measures_surv.dcalib.Rd index 411931e45..6694eed71 100644 --- a/man/mlr_measures_surv.dcalib.Rd +++ b/man/mlr_measures_surv.dcalib.Rd @@ -5,21 +5,25 @@ \alias{MeasureSurvDCalibration} \title{D-Calibration Survival Measure} \description{ -This calibration method is defined by calculating +This calibration method is defined by calculating the following statistic: \deqn{s = B/n \sum_i (P_i - n/B)^2} -where \eqn{B} is number of 'buckets', \eqn{n} is the number of predictions, -and \eqn{P_i} is the predicted number of deaths in the \eqn{i}th interval -[0, 100/B), [100/B, 50/B),....,[(B - 100)/B, 1). +where \eqn{B} is the number of 'buckets' (which equally divide \eqn{[0,1]} into intervals), +\eqn{n} is the number of predictions, and \eqn{P_i} is the observed proportion +of observations in the \eqn{i}th interval. An observation is assigned to the +\eqn{i}th bucket if its predicted survival probability at the event time +falls within the corresponding interval. +This statistic assumes that censoring time is independent of death time. -A model is well-calibrated if \code{s ~ Unif(B)}, tested with \code{chisq.test} -(\code{p > 0.05} if well-calibrated). -Model \code{i} is better calibrated than model \code{j} if \code{s_i < s_j}. +A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with \code{chisq.test} +(\eqn{p > 0.05} if well-calibrated). 
+Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}, +meaning that \emph{lower values} of this measure are preferred. } \details{ This measure can either return the test statistic or the p-value from the \code{chisq.test}. The former is useful for model comparison whereas the latter is useful for determining if a model -is well-calibration. If \code{chisq = FALSE} and \code{m} is the predicted value then you can manually -compute the p.value with \code{pchisq(m, B - 1, lower.tail = FALSE)}. +is well-calibrated. If \code{chisq = FALSE} and \code{s} is the returned statistic, then you can manually +compute the p-value with \code{pchisq(s, B - 1, lower.tail = FALSE)}. NOTE: This measure is still experimental both theoretically and in implementation. Results should therefore only be taken as an indicator of performance and not for @@ -126,13 +130,24 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \if{html}{\out{
}} \describe{ \item{\code{B}}{(\code{integer(1)}) \cr -Number of buckets to test for uniform predictions over. Default of \code{10} is recommended by -Haider et al. (2020).} +Number of buckets to test for uniform predictions over. +Default of \code{10} is recommended by Haider et al. (2020). +Changing this parameter affects \code{truncate}.} \item{\code{chisq}}{(\code{logical(1)}) \cr -If \code{TRUE} returns the p.value of the corresponding chisq.test instead of the measure. -Otherwise this can be performed manually with \code{pchisq(m, B - 1, lower.tail = FALSE)}. -\code{p > 0.05} indicates well-calibrated.} +If \code{TRUE}, returns the p-value of the corresponding chisq.test instead of the measure. +Default is \code{FALSE} and returns the statistic \code{s}. +You can manually get the p-value by executing \code{pchisq(s, B - 1, lower.tail = FALSE)}. +\code{p > 0.05} indicates a well-calibrated model.} + +\item{\code{truncate}}{(\code{double(1)}) \cr +This parameter controls the upper bound of the output statistic +when \code{chisq} is \code{FALSE}. The default \code{truncate} value of \eqn{10} +corresponds to a p-value of about 0.35 for the chisq.test using \eqn{B = 10} buckets. +Values \eqn{>10} translate to even lower p-values and thus less calibrated +models. If the number of buckets \eqn{B} changes, you will probably want to +change the \code{truncate} value as well to keep the same p-value significance level. +Initialize with \code{truncate = Inf} if no truncation is desired.} } \if{html}{\out{
}} } diff --git a/tests/testthat/test_PredictionSurv.R b/tests/testthat/test_PredictionSurv.R index 20872bf08..0e12e137d 100644 --- a/tests/testthat/test_PredictionSurv.R +++ b/tests/testthat/test_PredictionSurv.R @@ -176,21 +176,67 @@ test_that("as_prediction_surv", { }) test_that("filtering", { - p = suppressWarnings(lrn("surv.coxph")$train(task)$predict(task)) - p2 = reshape_distr_to_3d(p) # survival array distr + p = suppressWarnings(lrn("surv.coxph")$train(task)$predict(task)) # survival matrix + p2 = reshape_distr_to_3d(p) # survival array + p3 = p$clone() + p4 = p2$clone() + p3$data$distr = p3$distr # Matdist + p4$data$distr = p4$distr # Arrdist p$filter(c(20, 37, 42)) p2$filter(c(20, 37, 42)) + p3$filter(c(20, 37, 42)) + p4$filter(c(20, 37, 42)) expect_prediction_surv(p) expect_prediction_surv(p2) + expect_prediction_surv(p3) + expect_prediction_surv(p4) expect_set_equal(p$data$row_ids, c(20, 37, 42)) expect_set_equal(p2$data$row_ids, c(20, 37, 42)) + expect_set_equal(p3$data$row_ids, c(20, 37, 42)) + expect_set_equal(p4$data$row_ids, c(20, 37, 42)) expect_numeric(p$data$crank, any.missing = FALSE, len = 3) expect_numeric(p2$data$crank, any.missing = FALSE, len = 3) + expect_numeric(p3$data$crank, any.missing = FALSE, len = 3) + expect_numeric(p4$data$crank, any.missing = FALSE, len = 3) expect_numeric(p$data$lp, any.missing = FALSE, len = 3) expect_numeric(p2$data$lp, any.missing = FALSE, len = 3) + expect_numeric(p3$data$lp, any.missing = FALSE, len = 3) + expect_numeric(p4$data$lp, any.missing = FALSE, len = 3) expect_matrix(p$data$distr, nrows = 3) expect_array(p2$data$distr, d = 3) expect_equal(nrow(p2$data$distr), 3) + expect_true(inherits(p3$data$distr, "Matdist")) + expect_true(inherits(p4$data$distr, "Arrdist")) + + # edge case: filter to 1 observation + p$filter(20) + p2$filter(20) + p3$filter(20) + p4$filter(20) + expect_prediction_surv(p) + expect_prediction_surv(p2) + expect_prediction_surv(p3) + expect_prediction_surv(p4) + expect_matrix(p$data$distr, nrows = 1) + expect_array(p2$data$distr, d = 3) + expect_equal(nrow(p2$data$distr), 1) + expect_true(inherits(p3$data$distr, "WeightedDiscrete")) # from Matdist! + expect_true(inherits(p4$data$distr, "Arrdist")) # remains an Arrdist! 
+ + # filter to 0 observations using non-existent (positive) id + p$filter(42) + p2$filter(42) + p3$filter(42) + p4$filter(42) + + expect_prediction_surv(p) + expect_prediction_surv(p2) + expect_prediction_surv(p3) + expect_prediction_surv(p4) + expect_null(p$distr) + expect_null(p2$distr) + expect_null(p3$distr) + expect_null(p4$distr) }) diff --git a/tests/testthat/test_mlr_measures.R b/tests/testthat/test_mlr_measures.R index 9503a1bb7..5f766256c 100644 --- a/tests/testthat/test_mlr_measures.R +++ b/tests/testthat/test_mlr_measures.R @@ -194,11 +194,15 @@ test_that("rcll works", { t = tsk("rats")$filter(sample(1:300, 50)) l = lrn("surv.kaplan") p = l$train(t)$predict(t) + p2 = p$clone() + p2$data$distr = p2$distr # hack: test score via distribution m = msr("surv.rcll") expect_true(m$minimize) expect_equal(m$range, c(0, Inf)) KMscore = p$score(m) expect_numeric(KMscore) + KMscore2 = p2$score(m) + expect_equal(KMscore, KMscore2) status = t$truth()[,2] row_ids = t$row_ids @@ -207,18 +211,128 @@ # only censored rats in test set p = l$predict(t, row_ids = cens_ids) - expect_numeric(p$score(m)) - expect_numeric(p$filter(row_ids = cens_ids[1])$score(m)) # 1 test rat + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr + score2 = p2$score(m) + expect_equal(score, score2) + + # 1 censored test rat + p = p$filter(row_ids = cens_ids[1]) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr + score2 = p2$score(m) + expect_equal(score, score2) # only dead rats in test set p = l$predict(t, row_ids = event_ids) - expect_numeric(p$score(m)) - expect_numeric(p$filter(row_ids = event_ids[1])$score(m)) # 1 test rat + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr # Matdist(1xY) + score2 = p2$score(m) + expect_equal(score, score2) + + # 1 dead rat + p = p$filter(row_ids = event_ids[1]) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr[1] # WeightDisc + score2 = p2$score(m) + expect_equal(score, score2) # Cox is better than baseline (Kaplan-Meier) + l2 = lrn("surv.coxph") + p2 = suppressWarnings(l2$train(t)$predict(t)) + expect_true(p2$score(m) < KMscore) + + # Another edge case: some dead rats and only 1 censored rat + p3 = p2$filter(row_ids = c(event_ids, cens_ids[1])) + score = p3$score(m) + expect_numeric(score) + p3$data$distr = p3$distr + score2 = p3$score(m) + expect_equal(score, score2) +}) + +test_that("dcal works", { + set.seed(1) + t = tsk("rats")$filter(sample(1:300, 50)) l = lrn("surv.coxph") p = suppressWarnings(l$train(t)$predict(t)) - expect_true(p$score(m) < KMscore) + p2 = p$clone() + p2$data$distr = p2$distr # hack: test score via distribution + m = msr("surv.dcalib", truncate = 20) + expect_true(m$minimize) + expect_equal(m$range, c(0, Inf)) + expect_equal(m$param_set$values$B, 10) + expect_equal(m$param_set$values$chisq, FALSE) + expect_equal(m$param_set$values$truncate, 20) + cox_score = p$score(m) + expect_numeric(cox_score) + cox_score2 = p2$score(m) + expect_equal(cox_score, cox_score2) + + status = t$truth()[,2] + row_ids = t$row_ids + cens_ids = row_ids[status == 0] + event_ids = row_ids[status == 1] + + # only censored rats in test set + p = l$predict(t, row_ids = cens_ids) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr + score2 = 
p2$score(m) + expect_equal(score, score2) + + # 1 censored test rat + p = p$filter(row_ids = cens_ids[1]) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr + score2 = p2$score(m) + expect_equal(score, score2) + + # only dead rats in test set + p = l$predict(t, row_ids = event_ids) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr # Matdist(1xY) + score2 = p2$score(m) + expect_equal(score, score2) + + # 1 dead rat + p = p$filter(row_ids = event_ids[1]) + score = p$score(m) + expect_numeric(score) + p2 = p$clone() # test score via distribution + p2$data$distr = p2$distr[1] # WeightDisc + score2 = p2$score(m) + expect_equal(score, score2) + + # Another edge case: some dead rats and only 1 censored rat + p = l$predict(t, row_ids = c(event_ids, cens_ids[1])) + score = p$score(m) + expect_numeric(score) + p$data$distr = p$distr + score2 = p$score(m) + expect_equal(score, score2) + expect_true(score > 10) + + score3 = p$score(msr("surv.dcalib")) # default truncate = 10 + expect_equal(unname(score3), 10) + score4 = p$score(msr("surv.dcalib", truncate = 5)) + expect_equal(unname(score4), 5) + score5 = p$score(msr("surv.dcalib", truncate = Inf, B = 20)) # B changes the scale of the statistic (and hence the appropriate truncate) + expect_true(score5 > score) }) test_that("distr measures work with 3d survival array", {
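
For reference, a minimal standalone sketch of the truncated D-Calibration statistic in the all-uncensored case, mirroring the bucketing logic in the new `.score()` above; the function name `dcalib_uncensored` and its arguments are illustrative only, not part of the patch:

# Sketch: truncated D-Calibration for uncensored observations only.
# si: predicted survival probabilities at each observation's own event time.
dcalib_uncensored = function(si, B = 10L, truncate = 10) {
  si = pmax(si, 1e-5)           # remove zeros, as in `.score()`
  js = ceiling(B * si)          # bucket index per observation
  bj = tabulate(js, nbins = B)  # each dead observation contributes 1 to its bucket
  s = (B / length(si)) * sum((bj - length(si) / B)^2)
  min(truncate, s)              # upper-bound the statistic
}

set.seed(1)
dcalib_uncensored(runif(1000))      # calibrated: statistic near its chi-squared mean of B - 1
dcalib_uncensored(rep(0.95, 1000))  # miscalibrated: capped at the truncation bound of 10

The documented link between `truncate` and the p-value can be checked directly: `pchisq(10, 10 - 1, lower.tail = FALSE)` is roughly 0.35, the p-value quoted for the default `truncate = 10` with `B = 10` buckets.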
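
Similarly, a sketch of what the new RCLL fast path computes directly from a survival matrix, using base-R step-function lookups in place of distr6's internal `C_Vec_WeightedDiscreteCdf`/`C_Vec_WeightedDiscretePdf` routines; `rcll_from_matrix` is a hypothetical helper, and the aggregation to a negative mean log-likelihood is assumed from the measure's definition rather than shown in this diff:

# Sketch: RCLL from an n x T survival matrix whose colnames are time points.
rcll_from_matrix = function(surv, time, status, eps = 1e-15) {
  times = as.numeric(colnames(surv))
  out = numeric(nrow(surv))
  for (i in seq_len(nrow(surv))) {
    s = c(1, surv[i, ])  # prepend S(t) = 1 before the first time point
    if (status[i] == 1) {
      # event: discrete pmf at the event time, i.e. the drop in survival there
      j = match(time[i], times)  # assumes event times are support points
      out[i] = s[j] - s[j + 1]
    } else {
      # censored: survival at the censoring time (right-continuous step function)
      out[i] = s[findInterval(time[i], times) + 1]
    }
  }
  -mean(log(pmax(out, eps)))  # assumed aggregation: negative mean log-likelihood
}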
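
Finally, the equivalence the tests above exercise (setting `p$data$distr` to the distr6 object and re-scoring) rests on `distr6::as.Distribution` reproducing the survival matrix as a `Matdist` with the same step survival values. A small self-contained check with illustrative data, using the same construction as `.distrify_survarray()` minus the decorators:

surv = matrix(c(0.9, 0.7, 0.5,
                0.8, 0.6, 0.3),
              nrow = 2, byrow = TRUE,
              dimnames = list(NULL, c(10, 20, 30)))
d = distr6::as.Distribution(1 - surv, fun = "cdf")  # a Matdist
diag(d$survival(c(15, 25)))  # c(0.9, 0.6): the same values the matrix fast path reads directly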