Skip to content

Commit

Permalink
add multi arg to survival_prob_coxnet() (#279)
Browse files Browse the repository at this point in the history
* add `multi` arg to `survival_prob_coxnet()`

* add PR number

* Apply suggestions from code review

Co-authored-by: Emil Hvitfeldt <[email protected]>

* `document()`

* Suggestion from code review

---------

Co-authored-by: Emil Hvitfeldt <[email protected]>
  • Loading branch information
hfrick and EmilHvitfeldt authored Dec 20, 2023
1 parent d47936b commit 31a3741
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 24 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

* `extract_fit_engine()` now works properly for proportional hazards models fitted with the `"glmnet"` engine (#266).

* `survival_time_coxnet()` gained a `multi` argument to allow multiple values for `penalty` (#278).
* `survival_time_coxnet()` and `survival_prob_coxnet()` gain a `multi` argument to allow multiple values for `penalty` (#278, #279).


# censored 0.2.0
Expand Down
23 changes: 21 additions & 2 deletions R/aaa_survival_prob.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,31 @@ prob_template <- tibble::tibble(
)


predict_survival_na <- function(eval_time, interval = "none") {
ret <- tibble(.eval_time = eval_time, .pred_survival = NA_real_)
predict_survival_na <- function(eval_time, interval = "none", penalty = NULL) {
if (!is.null(penalty)) {
n_penalty <- length(penalty)
n_eval_time <- length(eval_time)
ret <- tibble::new_tibble(
list(
penalty = rep(penalty, each = n_eval_time),
.eval_time = rep(eval_time, times = n_penalty),
.pred_survival = rep(NA_real_, n_penalty * n_eval_time)
)
)
} else {
ret <- tibble::new_tibble(
list(
.eval_time = eval_time,
.pred_survival = rep(NA_real_, length(eval_time))
)
)
}

if (interval == "confidence") {
ret <- ret %>%
dplyr::mutate(.pred_lower = NA_real_, .pred_upper = NA_real_)
}

ret
}

Expand Down
39 changes: 27 additions & 12 deletions R/proportional_hazards-glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ get_missings_coxnet <- function(new_x, new_strata) {
#' @param time Deprecated in favor of `eval_time`. A vector of integers for prediction times.
#' @param output One of "surv" or "haz".
#' @param penalty Penalty value(s).
#' @param multi Allow multiple penalty values? Defaults to FALSE.
#' @param ... Options to pass to [survival::survfit()].
#' @return A tibble with a list column of nested tibbles.
#' @keywords internal
Expand All @@ -539,6 +540,7 @@ survival_prob_coxnet <- function(object,
time = deprecated(),
output = "surv",
penalty = NULL,
multi = FALSE,
...) {
if (lifecycle::is_present(time)) {
lifecycle::deprecate_warn(
Expand All @@ -553,14 +555,18 @@ survival_prob_coxnet <- function(object,
penalty <- object$spec$args$penalty
}

n_penalty <- length(penalty)
if (n_penalty > 1 & !multi) {
rlang::abort("Cannot use multiple penalty values with `multi = FALSE`.")
}

output <- match.arg(output, c("surv", "haz"))
multi <- length(penalty) > 1

new_x <- coxnet_prepare_x(new_data, object)

went_through_formula_interface <- !is.null(object$preproc$coxnet)
if (went_through_formula_interface &&
has_strata(object$formula, object$training_data)) {
has_strata(object$formula, object$training_data)) {
new_strata <- get_strata_glmnet(
object$formula,
data = new_data,
Expand All @@ -577,7 +583,11 @@ survival_prob_coxnet <- function(object,
n_missing <- length(missings_in_new_data)
all_missing <- n_missing == n_obs
if (all_missing) {
ret <- predict_survival_na(eval_time, interval = "none")
if (multi) {
ret <- predict_survival_na(eval_time, interval = "none", penalty = penalty)
} else {
ret <- predict_survival_na(eval_time, interval = "none")
}
ret <- tibble(.pred = rep(list(ret), n_missing))
return(ret)
}
Expand All @@ -597,15 +607,25 @@ survival_prob_coxnet <- function(object,
...
)

if (multi) {
if (length(penalty) > 1) {
res_patched <- purrr::map(
y,
survfit_summary_to_patched_tibble,
index_missing = missings_in_new_data,
eval_time = eval_time,
n_obs = n_obs
)
res <- tibble::tibble(
} else {
res_patched <- survfit_summary_to_patched_tibble(
y,
index_missing = missings_in_new_data,
eval_time = eval_time,
n_obs = n_obs
)
}

if (multi) {
res_formatted <- tibble::tibble(
penalty = penalty,
res_patched = res_patched
) %>%
Expand All @@ -614,16 +634,11 @@ survival_prob_coxnet <- function(object,
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)
} else {
res <- survfit_summary_to_patched_tibble(
y,
index_missing = missings_in_new_data,
eval_time = eval_time,
n_obs = n_obs
) %>%
res_formatted <- res_patched %>%
keep_cols(output) %>%
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)
}

res
res_formatted
}
3 changes: 3 additions & 0 deletions man/survival_prob_coxnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 38 additions & 9 deletions tests/testthat/test-proportional_hazards-glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,6 @@ test_that("survival_time_coxnet() works for multiple penalty values", {
pred %>% tidyr::unnest(cols = .pred) %>% dplyr::arrange(penalty) %>% dplyr::pull(.pred_time),
exp_pred
)

})

# prediction: survival ----------------------------------------------------
Expand Down Expand Up @@ -475,6 +474,7 @@ test_that("survival probabilities without strata", {
dplyr::select(penalty, .eval_time, .pred_survival)


skip("temporarily until we call `survival_prob_coxnet()` directly from `multi_predict()`")
pred_multi <- multi_predict(
f_fit,
new_data = new_data_3,
Expand Down Expand Up @@ -569,6 +569,7 @@ test_that("survival probabilities with strata", {
dplyr::arrange(.row, penalty, .eval_time) %>%
dplyr::select(penalty, .eval_time, .pred_survival)

skip("temporarily until we call `survival_prob_coxnet()` directly from `multi_predict()`")
pred_multi <- multi_predict(
f_fit,
new_data = new_data_3,
Expand Down Expand Up @@ -729,8 +730,6 @@ test_that("survival prediction with NA in strata", {
expect_true(all(is.na(f_pred$.pred[[1]]$.pred_survival)))
})



test_that("survival_prob_coxnet() works for single penalty value", {
# single penalty value
pred_penalty <- 0.1
Expand Down Expand Up @@ -819,7 +818,6 @@ test_that("survival_prob_coxnet() works for single penalty value", {
expect_true(all(is.na(prob$.pred_survival)))
})


test_that("survival_prob_coxnet() works for multiple penalty values", {
# multiple penalty values
pred_penalty <- c(0.1, 0.2)
Expand All @@ -845,7 +843,13 @@ test_that("survival_prob_coxnet() works for multiple penalty values", {
surv_fit <- survfit(exp_f_fit, newx = as.matrix(lung_pred), s = pred_penalty, x = lung_x, y = lung_y)
surv_fit_summary <- purrr::map(surv_fit, summary, times = pred_time, extend = TRUE)

prob <- survival_prob_coxnet(f_fit, new_data = lung_pred, eval_time = pred_time, penalty = pred_penalty)
prob <- survival_prob_coxnet(
f_fit,
new_data = lung_pred,
eval_time = pred_time,
penalty = pred_penalty,
multi = TRUE
)
prob_na <- prob$.pred[[2]]
prob_non_na <- prob$.pred[[3]]
# observation in row 15
Expand All @@ -872,7 +876,14 @@ test_that("survival_prob_coxnet() works for multiple penalty values", {
surv_fit <- survfit(exp_f_fit, newx = as.matrix(lung_pred), s = pred_penalty, x = lung_x, y = lung_y)
surv_fit_summary <- purrr::map(surv_fit, summary, times = pred_time, extend = TRUE)

prob <- survival_prob_coxnet(f_fit, new_data = lung_pred, eval_time = pred_time, penalty = pred_penalty)
prob <- survival_prob_coxnet(
f_fit,
new_data = lung_pred,
eval_time = pred_time,
penalty = pred_penalty,
multi = TRUE
)

prob <- tidyr::unnest(prob, cols = .pred)
exp_prob <- purrr::map(surv_fit_summary, purrr::pluck, "surv") %>% unlist()

Expand All @@ -892,9 +903,27 @@ test_that("survival_prob_coxnet() works for multiple penalty values", {
# all observations with missings
lung_pred <- lung[c(14, 14), ]

prob <- survival_prob_coxnet(f_fit, new_data = lung_pred, eval_time = pred_time, penalty = pred_penalty)
prob <- tidyr::unnest(prob, cols = .pred)
expect_true(all(is.na(prob$.pred_survival)))
pred <- survival_prob_coxnet(
f_fit,
new_data = lung_pred,
eval_time = pred_time,
penalty = pred_penalty,
multi = TRUE
)
exp_pred <- rep(
NA_real_,
times = length(pred_penalty) * length(pred_time) * nrow(lung_pred)
)

expect_named(pred, ".pred")
expect_named(pred$.pred[[1]], c("penalty", ".eval_time", ".pred_survival"))
expect_identical(
pred %>%
tidyr::unnest(cols = .pred) %>%
dplyr::arrange(penalty) %>%
dplyr::pull(.pred_survival),
exp_pred
)
})

test_that("can predict for out-of-domain timepoints", {
Expand Down

0 comments on commit 31a3741

Please sign in to comment.