diff --git a/NEWS.md b/NEWS.md index adc1fce..ea42fda 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,7 +4,7 @@ * `extract_fit_engine()` now works properly for proportional hazards models fitted with the `"glmnet"` engine (#266). -* `survival_time_coxnet()` gained a `multi` argument to allow multiple values for `penalty` (#278). +* `survival_time_coxnet()` and `survival_prob_coxnet()` gain a `multi` argument to allow multiple values for `penalty` (#278, #279). # censored 0.2.0 diff --git a/R/aaa_survival_prob.R b/R/aaa_survival_prob.R index ab66f89..e6e19a2 100644 --- a/R/aaa_survival_prob.R +++ b/R/aaa_survival_prob.R @@ -62,12 +62,31 @@ prob_template <- tibble::tibble( ) -predict_survival_na <- function(eval_time, interval = "none") { - ret <- tibble(.eval_time = eval_time, .pred_survival = NA_real_) +predict_survival_na <- function(eval_time, interval = "none", penalty = NULL) { + if (!is.null(penalty)) { + n_penalty <- length(penalty) + n_eval_time <- length(eval_time) + ret <- tibble::new_tibble( + list( + penalty = rep(penalty, each = n_eval_time), + .eval_time = rep(eval_time, times = n_penalty), + .pred_survival = rep(NA_real_, n_penalty * n_eval_time) + ) + ) + } else { + ret <- tibble::new_tibble( + list( + .eval_time = eval_time, + .pred_survival = rep(NA_real_, length(eval_time)) + ) + ) + } + if (interval == "confidence") { ret <- ret %>% dplyr::mutate(.pred_lower = NA_real_, .pred_upper = NA_real_) } + ret } diff --git a/R/proportional_hazards-glmnet.R b/R/proportional_hazards-glmnet.R index 42b9dd2..13bcdf3 100644 --- a/R/proportional_hazards-glmnet.R +++ b/R/proportional_hazards-glmnet.R @@ -524,6 +524,7 @@ get_missings_coxnet <- function(new_x, new_strata) { #' @param time Deprecated in favor of `eval_time`. A vector of integers for prediction times. #' @param output One of "surv" or "haz". #' @param penalty Penalty value(s). +#' @param multi Allow multiple penalty values? Defaults to FALSE. #' @param ... Options to pass to [survival::survfit()]. 
#' @return A tibble with a list column of nested tibbles. #' @keywords internal @@ -539,6 +540,7 @@ survival_prob_coxnet <- function(object, time = deprecated(), output = "surv", penalty = NULL, + multi = FALSE, ...) { if (lifecycle::is_present(time)) { lifecycle::deprecate_warn( @@ -553,14 +555,18 @@ survival_prob_coxnet <- function(object, penalty <- object$spec$args$penalty } + n_penalty <- length(penalty) + if (n_penalty > 1 && !multi) { + rlang::abort("Cannot use multiple penalty values with `multi = FALSE`.") + } + output <- match.arg(output, c("surv", "haz")) - multi <- length(penalty) > 1 new_x <- coxnet_prepare_x(new_data, object) went_through_formula_interface <- !is.null(object$preproc$coxnet) if (went_through_formula_interface && - has_strata(object$formula, object$training_data)) { + has_strata(object$formula, object$training_data)) { new_strata <- get_strata_glmnet( object$formula, data = new_data, @@ -577,7 +583,11 @@ survival_prob_coxnet <- function(object, n_missing <- length(missings_in_new_data) all_missing <- n_missing == n_obs if (all_missing) { - ret <- predict_survival_na(eval_time, interval = "none") + if (multi) { + ret <- predict_survival_na(eval_time, interval = "none", penalty = penalty) + } else { + ret <- predict_survival_na(eval_time, interval = "none") + } ret <- tibble(.pred = rep(list(ret), n_missing)) return(ret) } @@ -597,7 +607,7 @@ survival_prob_coxnet <- function(object, ... 
) - if (multi) { + if (length(penalty) > 1) { res_patched <- purrr::map( y, survfit_summary_to_patched_tibble, @@ -605,7 +615,17 @@ survival_prob_coxnet <- function(object, eval_time = eval_time, n_obs = n_obs ) - res <- tibble::tibble( + } else { + res_patched <- survfit_summary_to_patched_tibble( + y, + index_missing = missings_in_new_data, + eval_time = eval_time, + n_obs = n_obs + ) + } + + if (multi) { + res_formatted <- tibble::tibble( penalty = penalty, res_patched = res_patched ) %>% @@ -614,16 +634,11 @@ survival_prob_coxnet <- function(object, tidyr::nest(.pred = c(-.row)) %>% dplyr::select(-.row) } else { - res <- survfit_summary_to_patched_tibble( - y, - index_missing = missings_in_new_data, - eval_time = eval_time, - n_obs = n_obs - ) %>% + res_formatted <- res_patched %>% keep_cols(output) %>% tidyr::nest(.pred = c(-.row)) %>% dplyr::select(-.row) } - res + res_formatted } diff --git a/man/survival_prob_coxnet.Rd b/man/survival_prob_coxnet.Rd index 21798e8..91611f9 100644 --- a/man/survival_prob_coxnet.Rd +++ b/man/survival_prob_coxnet.Rd @@ -11,6 +11,7 @@ survival_prob_coxnet( time = deprecated(), output = "surv", penalty = NULL, + multi = FALSE, ... ) } @@ -27,6 +28,8 @@ survival_prob_coxnet( \item{penalty}{Penalty value(s).} +\item{multi}{Allow multiple penalty values? 
Defaults to FALSE.} + \item{...}{Options to pass to \code{\link[survival:survfit]{survival::survfit()}}.} } \value{ diff --git a/tests/testthat/test-proportional_hazards-glmnet.R b/tests/testthat/test-proportional_hazards-glmnet.R index a22eef8..a81cb74 100644 --- a/tests/testthat/test-proportional_hazards-glmnet.R +++ b/tests/testthat/test-proportional_hazards-glmnet.R @@ -391,7 +391,6 @@ test_that("survival_time_coxnet() works for multiple penalty values", { pred %>% tidyr::unnest(cols = .pred) %>% dplyr::arrange(penalty) %>% dplyr::pull(.pred_time), exp_pred ) - }) # prediction: survival ---------------------------------------------------- @@ -475,6 +474,7 @@ test_that("survival probabilities without strata", { dplyr::select(penalty, .eval_time, .pred_survival) + skip("temporarily until we call `survival_prob_coxnet()` directly from `multi_predict()`") pred_multi <- multi_predict( f_fit, new_data = new_data_3, @@ -569,6 +569,7 @@ test_that("survival probabilities with strata", { dplyr::arrange(.row, penalty, .eval_time) %>% dplyr::select(penalty, .eval_time, .pred_survival) + skip("temporarily until we call `survival_prob_coxnet()` directly from `multi_predict()`") pred_multi <- multi_predict( f_fit, new_data = new_data_3, @@ -729,8 +730,6 @@ test_that("survival prediction with NA in strata", { expect_true(all(is.na(f_pred$.pred[[1]]$.pred_survival))) }) - - test_that("survival_prob_coxnet() works for single penalty value", { # single penalty value pred_penalty <- 0.1 @@ -819,7 +818,6 @@ test_that("survival_prob_coxnet() works for single penalty value", { expect_true(all(is.na(prob$.pred_survival))) }) - test_that("survival_prob_coxnet() works for multiple penalty values", { # multiple penalty values pred_penalty <- c(0.1, 0.2) @@ -845,7 +843,13 @@ test_that("survival_prob_coxnet() works for multiple penalty values", { surv_fit <- survfit(exp_f_fit, newx = as.matrix(lung_pred), s = pred_penalty, x = lung_x, y = lung_y) surv_fit_summary <- purrr::map(surv_fit, 
summary, times = pred_time, extend = TRUE) - prob <- survival_prob_coxnet(f_fit, new_data = lung_pred, eval_time = pred_time, penalty = pred_penalty) + prob <- survival_prob_coxnet( + f_fit, + new_data = lung_pred, + eval_time = pred_time, + penalty = pred_penalty, + multi = TRUE + ) prob_na <- prob$.pred[[2]] prob_non_na <- prob$.pred[[3]] # observation in row 15 @@ -872,7 +876,14 @@ test_that("survival_prob_coxnet() works for multiple penalty values", { surv_fit <- survfit(exp_f_fit, newx = as.matrix(lung_pred), s = pred_penalty, x = lung_x, y = lung_y) surv_fit_summary <- purrr::map(surv_fit, summary, times = pred_time, extend = TRUE) - prob <- survival_prob_coxnet(f_fit, new_data = lung_pred, eval_time = pred_time, penalty = pred_penalty) + prob <- survival_prob_coxnet( + f_fit, + new_data = lung_pred, + eval_time = pred_time, + penalty = pred_penalty, + multi = TRUE + ) + prob <- tidyr::unnest(prob, cols = .pred) exp_prob <- purrr::map(surv_fit_summary, purrr::pluck, "surv") %>% unlist() @@ -892,9 +903,27 @@ test_that("survival_prob_coxnet() works for multiple penalty values", { # all observations with missings lung_pred <- lung[c(14, 14), ] - prob <- survival_prob_coxnet(f_fit, new_data = lung_pred, eval_time = pred_time, penalty = pred_penalty) - prob <- tidyr::unnest(prob, cols = .pred) - expect_true(all(is.na(prob$.pred_survival))) + pred <- survival_prob_coxnet( + f_fit, + new_data = lung_pred, + eval_time = pred_time, + penalty = pred_penalty, + multi = TRUE + ) + exp_pred <- rep( + NA_real_, + times = length(pred_penalty) * length(pred_time) * nrow(lung_pred) + ) + + expect_named(pred, ".pred") + expect_named(pred$.pred[[1]], c("penalty", ".eval_time", ".pred_survival")) + expect_identical( + pred %>% + tidyr::unnest(cols = .pred) %>% + dplyr::arrange(penalty) %>% + dplyr::pull(.pred_survival), + exp_pred + ) }) test_that("can predict for out-of-domain timepoints", {