Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add multi arg to survival_prob_coxnet() #279

Merged
merged 5 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

* `extract_fit_engine()` now works properly for proportional hazards models fitted with the `"glmnet"` engine (#266).

* `survival_time_coxnet()` gained a `multi` argument to allow multiple values for `penalty` (#278).
* `survival_time_coxnet()` and `survival_prob_coxnet()` gain a `multi` argument to allow multiple values for `penalty` (#278, #279).


# censored 0.2.0
Expand Down
15 changes: 13 additions & 2 deletions R/aaa_survival_prob.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,23 @@ prob_template <- tibble::tibble(
)


predict_survival_na <- function(eval_time, interval = "none") {
ret <- tibble(.eval_time = eval_time, .pred_survival = NA_real_)
predict_survival_na <- function(eval_time, interval = "none", penalty = NULL) {
if (!is.null(penalty)) {
n_penalty <- length(penalty)
ret <- tibble(
penalty = rep(penalty, each = length(eval_time)),
.eval_time = rep(eval_time, times = n_penalty),
.pred_survival = NA_real_
)
hfrick marked this conversation as resolved.
Show resolved Hide resolved
} else {
ret <- tibble(.eval_time = eval_time, .pred_survival = NA_real_)
hfrick marked this conversation as resolved.
Show resolved Hide resolved
}

if (interval == "confidence") {
ret <- ret %>%
dplyr::mutate(.pred_lower = NA_real_, .pred_upper = NA_real_)
}

ret
}

Expand Down
40 changes: 28 additions & 12 deletions R/proportional_hazards-glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ get_missings_coxnet <- function(new_x, new_strata) {
#' @param time Deprecated in favor of `eval_time`. A vector of integers for prediction times.
#' @param output One of "surv" or "haz".
#' @param penalty Penalty value(s).
#' @param multi Allow multiple penalty values? Defaults to `FALSE`, in which case supplying more than one `penalty` value is an error.
hfrick marked this conversation as resolved.
Show resolved Hide resolved
#' @param ... Options to pass to [survival::survfit()].
#' @return A tibble with a list column of nested tibbles.
#' @keywords internal
Expand All @@ -535,6 +536,7 @@ survival_prob_coxnet <- function(object,
time = deprecated(),
output = "surv",
penalty = NULL,
multi = FALSE,
...) {
if (lifecycle::is_present(time)) {
lifecycle::deprecate_warn(
Expand All @@ -549,14 +551,18 @@ survival_prob_coxnet <- function(object,
penalty <- object$spec$args$penalty
}

n_penalty <- length(penalty)
if (n_penalty > 1 & !multi) {
rlang::abort("Cannot use multiple penalty values with `multi = FALSE`.")
}

output <- match.arg(output, c("surv", "haz"))
multi <- length(penalty) > 1

new_x <- coxnet_prepare_x(new_data, object)

went_through_formula_interface <- !is.null(object$preproc$coxnet)
if (went_through_formula_interface &&
has_strata(object$formula, object$training_data)) {
has_strata(object$formula, object$training_data)) {
new_strata <- get_strata_glmnet(
object$formula,
data = new_data,
Expand All @@ -573,7 +579,11 @@ survival_prob_coxnet <- function(object,
n_missing <- length(missings_in_new_data)
all_missing <- n_missing == n_obs
if (all_missing) {
ret <- predict_survival_na(eval_time, interval = "none")
if (multi) {
ret <- predict_survival_na(eval_time, interval = "none", penalty = penalty)
} else {
ret <- predict_survival_na(eval_time, interval = "none")
}
ret <- tibble(.pred = rep(list(ret), n_missing))
return(ret)
}
Expand All @@ -593,33 +603,39 @@ survival_prob_coxnet <- function(object,
...
)

if (multi) {
if (length(penalty) > 1) {
res_patched <- purrr::map(
y,
survfit_summary_to_patched_tibble,
index_missing = missings_in_new_data,
eval_time = eval_time,
n_obs = n_obs
)
} else {
res_patched <- survfit_summary_to_patched_tibble(
y,
index_missing = missings_in_new_data,
eval_time = eval_time,
n_obs = n_obs
)
}

if (multi) {
res <- tibble::tibble(
penalty = penalty,
res_patched = res_patched
) %>%
tidyr::unnest(cols = res_patched) %>%
tidyr::unnest(cols = res_patched)
res_formatted <- res %>%
keep_cols(output, keep_penalty = TRUE) %>%
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)
hfrick marked this conversation as resolved.
Show resolved Hide resolved
} else {
res <- survfit_summary_to_patched_tibble(
y,
index_missing = missings_in_new_data,
eval_time = eval_time,
n_obs = n_obs
) %>%
res_formatted <- res_patched %>%
keep_cols(output) %>%
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)
}

res
res_formatted
}
3 changes: 3 additions & 0 deletions man/survival_prob_coxnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 38 additions & 9 deletions tests/testthat/test-proportional_hazards-glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,6 @@ test_that("survival_time_coxnet() works for multiple penalty values", {
pred %>% tidyr::unnest(cols = .pred) %>% dplyr::arrange(penalty) %>% dplyr::pull(.pred_time),
exp_pred
)

})

# prediction: survival ----------------------------------------------------
Expand Down Expand Up @@ -475,6 +474,7 @@ test_that("survival probabilities without strata", {
dplyr::select(penalty, .eval_time, .pred_survival)


skip("temporarily until we call `survival_prob_coxnet()` directly from `multi_predict()`")
pred_multi <- multi_predict(
f_fit,
new_data = new_data_3,
Expand Down Expand Up @@ -569,6 +569,7 @@ test_that("survival probabilities with strata", {
dplyr::arrange(.row, penalty, .eval_time) %>%
dplyr::select(penalty, .eval_time, .pred_survival)

skip("temporarily until we call `survival_prob_coxnet()` directly from `multi_predict()`")
pred_multi <- multi_predict(
f_fit,
new_data = new_data_3,
Expand Down Expand Up @@ -729,8 +730,6 @@ test_that("survival prediction with NA in strata", {
expect_true(all(is.na(f_pred$.pred[[1]]$.pred_survival)))
})



test_that("survival_prob_coxnet() works for single penalty value", {
# single penalty value
pred_penalty <- 0.1
Expand Down Expand Up @@ -819,7 +818,6 @@ test_that("survival_prob_coxnet() works for single penalty value", {
expect_true(all(is.na(prob$.pred_survival)))
})


test_that("survival_prob_coxnet() works for multiple penalty values", {
# multiple penalty values
pred_penalty <- c(0.1, 0.2)
Expand All @@ -845,7 +843,13 @@ test_that("survival_prob_coxnet() works for multiple penalty values", {
surv_fit <- survfit(exp_f_fit, newx = as.matrix(lung_pred), s = pred_penalty, x = lung_x, y = lung_y)
surv_fit_summary <- purrr::map(surv_fit, summary, times = pred_time, extend = TRUE)

prob <- survival_prob_coxnet(f_fit, new_data = lung_pred, eval_time = pred_time, penalty = pred_penalty)
prob <- survival_prob_coxnet(
f_fit,
new_data = lung_pred,
eval_time = pred_time,
penalty = pred_penalty,
multi = TRUE
)
prob_na <- prob$.pred[[2]]
prob_non_na <- prob$.pred[[3]]
# observation in row 15
Expand All @@ -872,7 +876,14 @@ test_that("survival_prob_coxnet() works for multiple penalty values", {
surv_fit <- survfit(exp_f_fit, newx = as.matrix(lung_pred), s = pred_penalty, x = lung_x, y = lung_y)
surv_fit_summary <- purrr::map(surv_fit, summary, times = pred_time, extend = TRUE)

prob <- survival_prob_coxnet(f_fit, new_data = lung_pred, eval_time = pred_time, penalty = pred_penalty)
prob <- survival_prob_coxnet(
f_fit,
new_data = lung_pred,
eval_time = pred_time,
penalty = pred_penalty,
multi = TRUE
)

prob <- tidyr::unnest(prob, cols = .pred)
exp_prob <- purrr::map(surv_fit_summary, purrr::pluck, "surv") %>% unlist()

Expand All @@ -892,9 +903,27 @@ test_that("survival_prob_coxnet() works for multiple penalty values", {
# all observations with missings
lung_pred <- lung[c(14, 14), ]

prob <- survival_prob_coxnet(f_fit, new_data = lung_pred, eval_time = pred_time, penalty = pred_penalty)
prob <- tidyr::unnest(prob, cols = .pred)
expect_true(all(is.na(prob$.pred_survival)))
pred <- survival_prob_coxnet(
f_fit,
new_data = lung_pred,
eval_time = pred_time,
penalty = pred_penalty,
multi = TRUE
)
exp_pred <- rep(
NA_real_,
times = length(pred_penalty) * length(pred_time) * nrow(lung_pred)
)

expect_named(pred, ".pred")
expect_named(pred$.pred[[1]], c("penalty", ".eval_time", ".pred_survival"))
expect_identical(
pred %>%
tidyr::unnest(cols = .pred) %>%
dplyr::arrange(penalty) %>%
dplyr::pull(.pred_survival),
exp_pred
)
})

test_that("can predict for out-of-domain timepoints", {
Expand Down
Loading