Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix measure bottlenecks #337

Merged
merged 17 commits into from
Nov 20, 2023
Merged
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.5.3
Version: 0.5.4
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -43,7 +43,7 @@ Depends:
Imports:
checkmate,
data.table,
distr6 (>= 1.8.3),
distr6 (>= 1.8.4),
ggplot2,
mlr3misc (>= 0.7.0),
mlr3viz,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mlr3proba 0.5.4

* Fix performance bottlenecks in the D-Calibration (`MeasureSurvDCalibration`) and RCLL (`MeasureSurvRCLL`) measures

# mlr3proba 0.5.3

* Add support for learners that can predict multiple posterior distributions by using `distr6::Arrdist`
Expand Down
16 changes: 13 additions & 3 deletions R/MeasureSurvDCalibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,21 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
ps = self$param_set$values
# initialize buckets
bj = numeric(ps$B)
true_times = prediction$truth[, 1L]
# predict individual probability of death at observed event time
if (inherits(prediction$distr, "VectorDistribution")) {
si = as.numeric(prediction$distr$survival(data = matrix(prediction$truth[, 1L], nrow = 1L)))
# bypass distr6 construction if possible
if (inherits(prediction$data$distr, "array")) {
si = diag(t(distr6:::C_Vec_WeightedDiscreteCdf(true_times,
as.numeric(colnames(prediction$data$distr)),
t(1 - prediction$data$distr), FALSE, FALSE
)))
} else {
si = diag(prediction$distr$survival(prediction$truth[, 1L]))
distr = prediction$distr
if (inherits(distr, "VectorDistribution")) {
si = as.numeric(distr$survival(data = matrix(true_times, nrow = 1L)))
} else {
si = diag(distr$survival(true_times))
}
}
# remove zeros
si = map_dbl(si, function(.x) max(.x, 1e-5))
Expand Down
56 changes: 46 additions & 10 deletions R/MeasureSurvRCLL.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,53 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL",
event = truth[, 2] == 1
event_times = truth[event, 1]
cens_times = truth[!event, 1]
distr = prediction$distr

if (!any(event)) { # all censored
# survival at outcome time (survived *at least* this long)
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
} else if (all(event)) { # all uncensored
# pdf at outcome time (survived *this* long)
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
} else { # mix
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
# Bypass distr6 construction if underlying distr represented by array
if (inherits(prediction$data$distr, "array")) {
surv = prediction$data$distr
if (length(dim(surv)) == 3) {
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
times = as.numeric(colnames(surv))

if (any(!event)) {
if (sum(!event) == 1) { # fix subsetting issue in case of 1 censored
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
cdf = t(1 - surv)
} else {
cdf = t(1 - surv[!event, ])
}

out[!event] = diag(
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
distr6:::C_Vec_WeightedDiscreteCdf(cens_times, times, cdf = cdf, FALSE, FALSE)
)
}
if (any(event)) {
pdf = distr6:::cdfpdf(1 - surv)
if (sum(event) == 1) { # fix subsetting issue in case of 1 event
pdf = t(pdf)
} else {
pdf = t(pdf[event, ])
}

out[event] = diag(
distr6:::C_Vec_WeightedDiscretePdf(event_times, times, pdf = pdf)
)
}
} else {
distr = prediction$distr

# Splitting in this way bypasses unnecessary distr extraction
if (!any(event)) { # all censored
# survival at outcome time (survived *at least* this long)
out = diag(as.matrix(distr$survival(cens_times)))
} else if (all(event)) { # all uncensored
# pdf at outcome time (survived *this* long)
out = diag(as.matrix(distr$pdf(event_times)))
} else { # mix
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
}
}

stopifnot(!any(out == -99L)) # safety check
Expand Down
21 changes: 16 additions & 5 deletions R/PredictionDataSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,23 @@ filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) {
}

if (!is.null(pdata$distr)) {
if (inherits(pdata$distr, "matrix")) {
pdata$distr = pdata$distr[keep, , drop = FALSE]
} else { # array
pdata$distr = pdata$distr[keep, , , drop = FALSE]
distr = pdata$distr

if (testDistribution(distr)) { # distribution
ok = inherits(distr, c("VectorDistribution", "Matdist", "Arrdist")) &&
length(keep) > 1 # e.g.: Arrdist(1xYxZ) and keep = FALSE
if (ok) {
pdata$distr = distr[keep] # we can subset row/samples like this
} else {
pdata$distr = base::switch(keep, distr) # one distribution only
}
} else {
if (length(dim(distr)) == 2) { # 2d matrix
pdata$distr = distr[keep, , drop = FALSE]
} else { # 3d array
pdata$distr = distr[keep, , , drop = FALSE]
}
}

}

pdata
Expand Down
4 changes: 3 additions & 1 deletion R/PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,12 @@ PredictionSurv = R6Class("PredictionSurv",
}
},
# Convert an array-like `x` into a distr6 distribution object.
#
# `x` is passed on as `1 - x` with `fun = "cdf"`, i.e. it is presumably a
# matrix/array of survival probabilities -- TODO confirm against callers.
#
# Returns the result of `distr6::as.Distribution()` when `x` inherits from
# "array" and has at least one row; returns NULL otherwise (non-array input
# or an array with zero rows, e.g. after filtering away every observation).
.distrify_survarray = function(x) {
  # `&&` short-circuits, so nrow() is only evaluated for array-like input
  if (inherits(x, "array") && nrow(x) > 0) { # can be matrix as well
    # create Matdist or Arrdist (default => median curve)
    distr6::as.Distribution(1 - x, fun = "cdf",
      decorators = c("CoreStatistics", "ExoticStatistics"))
  } else {
    NULL
  }
}
)
Expand Down
4 changes: 2 additions & 2 deletions inst/testthat/helper_expectations.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ expect_prediction_surv = function(p) {
"response", "distr", "lp", "crank"))
checkmate::expect_data_table(data.table::as.data.table(p), nrows = length(p$row_ids))
checkmate::expect_atomic_vector(p$missing)
if ("distr" %in% p$predict_types) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist"))
if ("distr" %in% p$predict_types && !is.null(p$distr)) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist", "WeightedDiscrete"))
}
expect_true(inherits(p, "PredictionSurv"))
}
50 changes: 48 additions & 2 deletions tests/testthat/test_PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,21 +176,67 @@ test_that("as_prediction_surv", {
})

# Exercises `$filter()` on four predictions that differ only in how the
# predicted distribution is stored in `$data$distr`:
#   p  - survival matrix
#   p2 - survival array (3d)
#   p3 - Matdist object
#   p4 - Arrdist object
# The filters are applied cumulatively: 3 rows -> 1 row -> 0 rows.
# NOTE: `task` is not defined in this block; it is expected to exist in the
# enclosing test file.
test_that("filtering", {
  p = suppressWarnings(lrn("surv.coxph")$train(task)$predict(task)) # survival matrix
  p2 = reshape_distr_to_3d(p) # survival array
  p3 = p$clone()
  p4 = p2$clone()
  # store the distr6 objects themselves instead of the raw matrix/array
  p3$data$distr = p3$distr # Matdist
  p4$data$distr = p4$distr # Arrdist

  # stage 1: keep three observations
  p$filter(c(20, 37, 42))
  p2$filter(c(20, 37, 42))
  p3$filter(c(20, 37, 42))
  p4$filter(c(20, 37, 42))
  expect_prediction_surv(p)
  expect_prediction_surv(p2)
  expect_prediction_surv(p3)
  expect_prediction_surv(p4)

  expect_set_equal(p$data$row_ids, c(20, 37, 42))
  expect_set_equal(p2$data$row_ids, c(20, 37, 42))
  expect_set_equal(p3$data$row_ids, c(20, 37, 42))
  expect_set_equal(p4$data$row_ids, c(20, 37, 42))
  expect_numeric(p$data$crank, any.missing = FALSE, len = 3)
  expect_numeric(p2$data$crank, any.missing = FALSE, len = 3)
  expect_numeric(p3$data$crank, any.missing = FALSE, len = 3)
  expect_numeric(p4$data$crank, any.missing = FALSE, len = 3)
  expect_numeric(p$data$lp, any.missing = FALSE, len = 3)
  expect_numeric(p2$data$lp, any.missing = FALSE, len = 3)
  expect_numeric(p3$data$lp, any.missing = FALSE, len = 3)
  expect_numeric(p4$data$lp, any.missing = FALSE, len = 3)
  expect_matrix(p$data$distr, nrows = 3)
  expect_array(p2$data$distr, d = 3)
  expect_equal(nrow(p2$data$distr), 3)
  expect_true(inherits(p3$data$distr, "Matdist"))
  expect_true(inherits(p4$data$distr, "Arrdist"))

  # edge case: filter to 1 observation
  p$filter(20)
  p2$filter(20)
  p3$filter(20)
  p4$filter(20)
  expect_prediction_surv(p)
  expect_prediction_surv(p2)
  expect_prediction_surv(p3)
  expect_prediction_surv(p4)
  expect_matrix(p$data$distr, nrows = 1)
  expect_array(p2$data$distr, d = 3)
  expect_equal(nrow(p2$data$distr), 1)
  expect_true(inherits(p3$data$distr, "WeightedDiscrete")) # from Matdist!
  expect_true(inherits(p4$data$distr, "Arrdist")) # remains an Arrdist!

  # filter to 0 observations using non-existent (positive) id
  # (row 42 was already dropped by the previous filter to row 20)
  p$filter(42)
  p2$filter(42)
  p3$filter(42)
  p4$filter(42)

  expect_prediction_surv(p)
  expect_prediction_surv(p2)
  expect_prediction_surv(p3)
  expect_prediction_surv(p4)
  # with no rows left, every representation must yield a NULL distr
  expect_null(p$distr)
  expect_null(p2$distr)
  expect_null(p3$distr)
  expect_null(p4$distr)
})