Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for survival 3.7-0 #321

Merged
merged 8 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ BugReports: https://github.com/tidymodels/censored/issues
Depends:
parsnip (>= 1.1.0),
R (>= 3.5.0),
survival (>= 3.3-1)
survival (>= 3.7-0)
Imports:
cli,
dials,
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# censored (development version)

* censored now depends on survival >= 3.7-0, which allows us to rely on it for predictions of survival probabilities at infinite evaluation time points as well. This means that survival probabilities at `eval_time = Inf` are no longer always set to 0, and confidence intervals at infinite evaluation times are no longer always set to `NA`. This applies to `proportional_hazards()` and `bag_tree()` models, as well as to `decision_tree()` and `rand_forest()` models with the `partykit` engine (#320).


# censored 0.3.1

* Internal changes to the `predict()` methods for flexsurv models, in preparation for the upcoming flexsurv release (#317).
Expand Down
8 changes: 2 additions & 6 deletions R/aaa_survival_prob.R
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,6 @@ survfit_summary_to_patched_tibble <- function(object, index_missing, eval_time,
object %>%
summary(times = eval_time, extend = TRUE) %>%
survfit_summary_typestable() %>%
survfit_summary_patch_infinite_time(eval_time = eval_time) %>%
survfit_summary_restore_time_order(eval_time = eval_time) %>%
survfit_summary_patch_missings(
index_missing = index_missing,
eval_time = eval_time,
Expand All @@ -275,7 +273,7 @@ survfit_summary_to_patched_tibble <- function(object, index_missing, eval_time,
}

combine_list_of_survfit_summary <- function(object, eval_time) {
n_time <- sum(is.finite(eval_time))
n_time <- length(eval_time)
elements <- available_survfit_summary_elements(object[[1]])

ret <- list()
Expand All @@ -288,11 +286,9 @@ combine_list_of_survfit_summary <- function(object, eval_time) {
ret
}

survfit_summary_patch <- function(object, index_missing, eval_time, n_obs) {
survfit_summary_patch <- function(object, index_missing, eval_time, n_obs) {
object %>%
survfit_summary_typestable() %>%
survfit_summary_patch_infinite_time(eval_time = eval_time) %>%
survfit_summary_restore_time_order(eval_time = eval_time) %>%
survfit_summary_patch_missings(
index_missing = index_missing,
eval_time = eval_time,
Expand Down
75 changes: 0 additions & 75 deletions tests/testthat/test-aaa_survival_prob.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,81 +175,6 @@ test_that("survfit_summary_typestable() works for survival prob - stratified (co
expect_equal(dim(prob), c(length(pred_time), 1))
})

test_that("survfit_summary_patch_infinite_time() works (coxph)", {
lung_pred <- tidyr::drop_na(lung)
pred_time <- c(-Inf, 0, Inf, 1022, -Inf)

mod <- coxph(Surv(time, status) ~ ., data = lung)
surv_fit <- survfit(mod, newdata = lung_pred)
surv_fit_summary <- summary(surv_fit, times = pred_time, extend = TRUE)

surv_fit_summary_patched <- surv_fit_summary %>%
survfit_summary_typestable() %>%
survfit_summary_patch_infinite_time(eval_time = pred_time)

prob <- surv_fit_summary_patched$surv
exp_prob <- surv_fit_summary$surv

expect_equal(prob[c(3, 4), ], exp_prob)
expect_equal(
prob[c(1, 2), ],
matrix(1, nrow = 2, ncol = nrow(lung_pred)),
ignore_attr = "dimnames"
)
expect_equal(unname(prob[5, ]), rep(0, nrow(lung_pred)))
})

test_that("survfit_summary_patch_infinite_time() works (coxnet)", {
skip_if_not_installed("glmnet")

pred_time <- c(-Inf, 0, Inf, 1022, -Inf)

lung2 <- lung[-14, ]
lung_x <- as.matrix(lung2[, c("age", "ph.ecog")])
lung_y <- Surv(lung2$time, lung2$status)
lung_pred <- lung_x[1:5, ]

mod <- suppressWarnings(
glmnet::glmnet(x = lung_x, y = lung_y, family = "cox")
)
surv_fit <- survfit(mod, newx = lung_pred, s = 0.1, x = lung_x, y = lung_y)
surv_fit_summary <- summary(surv_fit, times = pred_time, extend = TRUE)

surv_fit_summary_patched <- surv_fit_summary %>%
survfit_summary_typestable() %>%
survfit_summary_patch_infinite_time(eval_time = pred_time)

prob <- surv_fit_summary_patched$surv
exp_prob <- surv_fit_summary$surv

expect_equal(prob[c(3, 4), ], exp_prob)
expect_equal(
prob[c(1, 2), ],
matrix(1, nrow = 2, ncol = nrow(lung_pred)),
ignore_attr = "dimnames"
)
expect_equal(unname(prob[5, ]), rep(0, nrow(lung_pred)))
})

test_that("survfit_summary_restore_time_order() works", {
lung_pred <- tidyr::drop_na(lung)
pred_time <- c(300, 100, 200)

mod <- coxph(Surv(time, status) ~ ., data = lung)
surv_fit <- survfit(mod, newdata = lung_pred)
surv_fit_summary <- summary(surv_fit, times = pred_time, extend = TRUE)

surv_fit_summary_patched <- surv_fit_summary %>%
survfit_summary_typestable() %>%
survfit_summary_patch_infinite_time(eval_time = pred_time) %>%
survfit_summary_restore_time_order(eval_time = pred_time)

prob <- surv_fit_summary_patched$surv
exp_prob <- surv_fit_summary$surv

expect_equal(prob, exp_prob[c(3, 1:2), ])
})

test_that("survfit_summary_patch_missings() works", {
pred_time <- c(100, 200)
mod <- coxph(Surv(time, status) ~ age + ph.ecog, data = lung)
Expand Down
21 changes: 2 additions & 19 deletions tests/testthat/test-bag_tree-rpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,7 @@ test_that("survival_prob_survbagg() works", {
expect_true(all(is.na(prob_na$.pred_survival)))
# for non-missings, get probs right
expect_equal(prob_non_na$.eval_time, pred_time)
expect_equal(
prob_non_na$.pred_survival[c(1, 4)],
c(1, 0)
)
expect_equal(
prob_non_na %>%
dplyr::filter(is.finite(.eval_time)) %>%
dplyr::arrange(.eval_time) %>%
dplyr::pull(.pred_survival),
exp_prob_non_na
)
expect_equal(prob_non_na$.pred_survival, exp_prob_non_na)

# single observation
lung_pred <- lung[13, ]
Expand All @@ -209,14 +199,7 @@ test_that("survival_prob_survbagg() works", {
prob <- tidyr::unnest(prob, cols = .pred)
exp_prob <- surv_fit_summary$surv

expect_equal(
prob$.pred_survival[c(1, 4)],
c(1, 0)
)
expect_equal(
prob %>% dplyr::filter(is.finite(.eval_time)) %>% dplyr::pull(.pred_survival),
as.vector(exp_prob)
)
expect_equal(prob$.pred_survival, as.vector(exp_prob))

# all observations with missings
lung_pred <- lung[c(14, 14), ]
Expand Down
12 changes: 4 additions & 8 deletions tests/testthat/test-boost_tree-mboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ test_that("survival_curve_to_prob() works", {
event_times = surv_fit$time,
survival_prob = surv_fit$surv
)
expect_equal(prob[c(2, 3, 1), ], exp_prob)
expect_equal(prob, exp_prob)

# can handle out of range time (before and after events)
pred_time_extend <- c(-2, 0, 3000)
Expand All @@ -146,19 +146,15 @@ test_that("survival_curve_to_prob() works", {

# can handle infinite time
pred_time_inf <- c(-Inf, 0, Inf, 1022, -Inf)
exp_prob <- summary(surv_fit, time = pred_time_inf)$surv
exp_prob <- summary(surv_fit, time = pred_time_inf, extend = TRUE)$surv
prob <- survival_curve_to_prob(
eval_time = pred_time_inf,
event_times = surv_fit$time,
survival_prob = surv_fit$surv
)
expect_equal(nrow(prob), length(pred_time_inf))
expect_equal(prob[c(2, 4), ], exp_prob)
expect_equal(
prob[c(1, 5), ],
matrix(1, nrow = 2, ncol = nrow(lung_pred)),
ignore_attr = "dimnames"
)

expect_equal(prob[-3, ], exp_prob[-3, ])
expect_equal(
prob[3, ] %>% unname(),
rep(0, nrow(lung_pred))
Expand Down
24 changes: 4 additions & 20 deletions tests/testthat/test-partykit.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@ test_that("survival_prob_partykit() works for ctree", {
rep(pred_time, nrow(lung_pred))
)
expect_equal(
prob %>% dplyr::filter(is.infinite(.eval_time)) %>% dplyr::pull(.pred_survival),
rep(c(1, 0), nrow(lung_pred))
)
expect_equal(
prob %>% dplyr::filter(is.finite(.eval_time)) %>% dplyr::pull(.pred_survival),
prob$.pred_survival,
as.vector(exp_prob)
)

Expand All @@ -57,11 +53,7 @@ test_that("survival_prob_partykit() works for ctree", {
exp_prob <- surv_fit_summary$surv

expect_equal(
prob %>% dplyr::filter(is.infinite(.eval_time)) %>% dplyr::pull(.pred_survival),
c(1, 0)
)
expect_equal(
prob %>% dplyr::filter(is.finite(.eval_time)) %>% dplyr::pull(.pred_survival),
prob$.pred_survival,
as.vector(exp_prob)
)

Expand Down Expand Up @@ -111,11 +103,7 @@ test_that("survival_prob_partykit() works for cforest", {
rep(pred_time, nrow(lung_pred))
)
expect_equal(
prob %>% dplyr::filter(is.infinite(.eval_time)) %>% dplyr::pull(.pred_survival),
rep(c(1, 0), nrow(lung_pred))
)
expect_equal(
prob %>% dplyr::filter(is.finite(.eval_time)) %>% dplyr::pull(.pred_survival),
prob$.pred_survival,
as.vector(exp_prob)
)

Expand All @@ -137,11 +125,7 @@ test_that("survival_prob_partykit() works for cforest", {
exp_prob <- surv_fit_summary$surv

expect_equal(
prob %>% dplyr::filter(is.infinite(.eval_time)) %>% dplyr::pull(.pred_survival),
c(1, 0)
)
expect_equal(
prob %>% dplyr::filter(is.finite(.eval_time)) %>% dplyr::pull(.pred_survival),
prob$.pred_survival,
as.vector(exp_prob)
)

Expand Down
48 changes: 4 additions & 44 deletions tests/testthat/test-proportional_hazards-glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -776,17 +776,7 @@ test_that("survival_prob_coxnet() works for single penalty value", {
expect_true(all(is.na(prob_na$.pred_survival)))
# for non-missings, get probs right
expect_equal(prob_non_na$.eval_time, pred_time)
expect_equal(
prob_non_na$.pred_survival[c(1, 4)],
c(1, 0)
)
expect_equal(
prob_non_na %>%
dplyr::filter(is.finite(.eval_time)) %>%
dplyr::arrange(.eval_time) %>%
dplyr::pull(.pred_survival),
exp_prob_non_na
)
expect_equal(prob_non_na$.pred_survival, exp_prob_non_na)

# single observation
lung_pred <- lung[13, c("age", "ph.ecog")]
Expand All @@ -797,17 +787,7 @@ test_that("survival_prob_coxnet() works for single penalty value", {
prob <- tidyr::unnest(prob, cols = .pred)
exp_prob <- surv_fit_summary$surv

expect_equal(
prob$.pred_survival[c(1, 4)],
c(1, 0)
)
expect_equal(
prob %>%
dplyr::filter(is.finite(.eval_time)) %>%
dplyr::arrange(.eval_time) %>%
dplyr::pull(.pred_survival),
exp_prob
)
expect_equal(prob$.pred_survival, exp_prob)

# all observations with missings
lung_pred <- lung[c(14, 14), ]
Expand Down Expand Up @@ -858,17 +838,7 @@ test_that("survival_prob_coxnet() works for multiple penalty values", {
expect_true(all(is.na(prob_na$.pred_survival)))
# for non-missings, get probs right
expect_equal(prob_non_na$.eval_time, rep(pred_time, length(pred_penalty)))
expect_equal(
prob_non_na$.pred_survival[c(1, 4, 7, 10)],
c(1, 0, 1, 0)
)
expect_equal(
prob_non_na %>%
dplyr::filter(is.finite(.eval_time)) %>%
dplyr::arrange(penalty, .eval_time) %>%
dplyr::pull(.pred_survival),
exp_prob
)
expect_equal(prob_non_na$.pred_survival, exp_prob)

# single observation
lung_pred <- lung[13, c("age", "ph.ecog")]
Expand All @@ -887,17 +857,7 @@ test_that("survival_prob_coxnet() works for multiple penalty values", {
exp_prob <- purrr::map(surv_fit_summary, purrr::pluck, "surv") %>% unlist()

expect_equal(prob_non_na$.eval_time, rep(pred_time, length(pred_penalty)))
expect_equal(
prob$.pred_survival[c(1, 4, 7, 10)],
c(1, 0, 1, 0)
)
expect_equal(
prob %>%
dplyr::filter(is.finite(.eval_time)) %>%
dplyr::arrange(penalty, .eval_time) %>%
dplyr::pull(.pred_survival),
exp_prob
)
expect_equal(prob$.pred_survival, exp_prob)

# all observations with missings
lung_pred <- lung[c(14, 14), ]
Expand Down
36 changes: 2 additions & 34 deletions tests/testthat/test-proportional_hazards-survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -361,17 +361,7 @@ test_that("survival_prob_coxph() works", {
expect_true(all(is.na(prob_na$.pred_survival)))
# for non-missings, get probs right
expect_equal(prob_non_na$.eval_time, pred_time)
expect_equal(
prob_non_na$.pred_survival[c(1, 4)],
c(1, 0)
)
expect_equal(
prob_non_na %>%
dplyr::filter(is.finite(.eval_time)) %>%
dplyr::arrange(.eval_time) %>%
dplyr::pull(.pred_survival),
exp_prob_non_na
)
expect_equal(prob_non_na$.pred_survival, exp_prob_non_na)

# single observation
lung_pred <- lung[13, ]
Expand All @@ -382,17 +372,7 @@ test_that("survival_prob_coxph() works", {
prob <- tidyr::unnest(prob, cols = .pred)
exp_prob <- surv_fit_summary$surv

expect_equal(
prob$.pred_survival[c(1, 4)],
c(1, 0)
)
expect_equal(
prob %>%
dplyr::filter(is.finite(.eval_time)) %>%
dplyr::arrange(.eval_time) %>%
dplyr::pull(.pred_survival),
exp_prob
)
expect_equal(prob$.pred_survival, exp_prob)

# all observations with missings
lung_pred <- lung[c(14, 14), ]
Expand Down Expand Up @@ -428,25 +408,13 @@ test_that("survival_prob_coxph() works with confidence intervals", {
expect_true(all(is.na(pred_na$.pred_lower)))
expect_true(all(is.na(pred_na$.pred_upper)))
# for non-missings, get interval right
expect_equal(
pred_non_na$.pred_lower[c(1, 4)],
rep(NA_real_, 2)
)
expect_equal(
pred_non_na$.pred_upper[c(1, 4)],
rep(NA_real_, 2)
)
expect_equal(
pred_non_na %>%
dplyr::filter(is.finite(.eval_time)) %>%
dplyr::arrange(.eval_time) %>%
dplyr::pull(.pred_lower),
exp_pred$lower[, 2] # observation in row 15
)
expect_equal(
pred_non_na %>%
dplyr::filter(is.finite(.eval_time)) %>%
dplyr::arrange(.eval_time) %>%
dplyr::pull(.pred_upper),
exp_pred$upper[, 2] # observation in row 15
)
Expand Down
Loading