Skip to content

Commit

Permalink
Merge pull request #278 from tidymodels/survival_time_coxnet-multi
Browse files Browse the repository at this point in the history
Make `survival_time_coxnet()` work with multiple penalty values
  • Loading branch information
hfrick authored Dec 15, 2023
2 parents e4e2497 + 344352d commit 7b2613d
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 9 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

* `extract_fit_engine()` now works properly for proportional hazards models fitted with the `"glmnet"` engine (#266).

* `survival_time_coxnet()` gained a `multi` argument to allow multiple values for `penalty` (#278).


# censored 0.2.0

Expand Down
52 changes: 44 additions & 8 deletions R/proportional_hazards-glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ multi_predict_coxnet_linear_pred <- function(object, new_data, opts, penalty) {
#' @param object A fitted `_coxnet` object.
#' @param new_data Data for prediction.
#' @param penalty Penalty value(s).
#' @param multi Allow multiple penalty values? Defaults to `FALSE`, in which
#'   case supplying more than one value for `penalty` is an error.
#' @param ... Options to pass to [survival::survfit()].
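#' @return A vector with one value per row of `new_data`. If `multi = TRUE`,
#'   a tibble with a list column `.pred` instead, where each element is a
#'   tibble with the columns `penalty` and `.pred_time`.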
#' @keywords internal
Expand All @@ -409,12 +410,18 @@ multi_predict_coxnet_linear_pred <- function(object, new_data, opts, penalty) {
#' set_engine("glmnet") %>%
#' fit(Surv(time, status) ~ ., data = lung)
#' survival_time_coxnet(cox_mod, new_data = lung[1:3, ], penalty = 0.1)
survival_time_coxnet <- function(object, new_data, penalty = NULL, ...) {
survival_time_coxnet <- function(object, new_data, penalty = NULL, multi = FALSE, ...) {
n_obs <- nrow(new_data)
n_penalty <- length(penalty)
if (n_penalty > 1 & !multi) {
rlang::abort("Cannot use multiple penalty values with `multi = FALSE`.")
}

new_x <- coxnet_prepare_x(new_data, object)

went_through_formula_interface <- !is.null(object$preproc$coxnet)
if (went_through_formula_interface &&
has_strata(object$formula, object$training_data)) {
has_strata(object$formula, object$training_data)) {
new_strata <- get_strata_glmnet(
object$formula,
data = new_data,
Expand All @@ -426,11 +433,20 @@ survival_time_coxnet <- function(object, new_data, penalty = NULL, ...) {

missings_in_new_data <- get_missings_coxnet(new_x, new_strata)
if (!is.null(missings_in_new_data)) {
n_total <- nrow(new_data)
n_missing <- length(missings_in_new_data)
all_missing <- n_missing == n_total
all_missing <- n_missing == n_obs
if (all_missing) {
ret <- rep(NA, n_missing)
if (multi) {
ret <- tibble::tibble(
penalty = rep(penalty, each = n_obs),
.pred_time = NA,
.row = rep(seq_len(n_obs), times = n_penalty)
) %>%
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)
} else {
ret <- rep(NA, n_missing)
}
return(ret)
}
new_x <- new_x[-missings_in_new_data, , drop = FALSE]
Expand All @@ -449,7 +465,28 @@ survival_time_coxnet <- function(object, new_data, penalty = NULL, ...) {
...
)

tabs <- summary(y)$table
if (length(penalty) > 1) {
res <- purrr::map(y, extract_patched_survival_time, missings_in_new_data, n_obs) %>%
purrr::list_c()
} else {
res <- extract_patched_survival_time(y, missings_in_new_data, n_obs)
}

if (multi) {
res <- tibble::tibble(
penalty = rep(penalty, each = n_obs),
.pred_time = res,
.row = rep(seq_len(n_obs), times = n_penalty)
) %>%
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)
}

res
}

extract_patched_survival_time <- function(survfit_object, missings_in_new_data, n_obs) {
tabs <- summary(survfit_object)$table
if (is.matrix(tabs)) {
colnames(tabs) <- gsub("[[:punct:]]", "", colnames(tabs))
res <- unname(tabs[, "rmean"])
Expand All @@ -458,14 +495,13 @@ survival_time_coxnet <- function(object, new_data, penalty = NULL, ...) {
res <- unname(tabs["rmean"])
}
if (!is.null(missings_in_new_data)) {
index_with_na <- rep(NA, n_total)
index_with_na <- rep(NA, n_obs)
index_with_na[-missings_in_new_data] <- seq_along(res)
res <- res[index_with_na]
}
res
}


get_missings_coxnet <- function(new_x, new_strata) {
missings_logical <- apply(cbind(new_x, new_strata), MARGIN = 1, anyNA)
if (!any(missings_logical)) {
Expand Down
4 changes: 3 additions & 1 deletion man/survival_time_coxnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

180 changes: 180 additions & 0 deletions tests/testthat/test-proportional_hazards-glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,186 @@ test_that("time predictions with NA in strata", {
expect_identical(which(is.na(f_pred$.pred_time)), 1L)
})

test_that("survival_time_coxnet() works for single penalty value", {
  # Reference values come from a plain glmnet fit on the same data; a single
  # penalty value is used for all predictions in this test.
  penalty_value <- 0.1

  # Row 14 is left out of the training data and reused further down as the
  # observation with missing values.
  train_data <- lung[-14, ]
  train_x <- as.matrix(train_data[, c("age", "ph.ecog")])
  train_y <- Surv(train_data$time, train_data$status)

  glmnet_fit <- suppressWarnings(
    glmnet::glmnet(x = train_x, y = train_y, family = "cox")
  )
  parsnip_fit <- suppressWarnings({
    cox_spec <- set_engine(proportional_hazards(penalty = 0.1), "glmnet")
    fit(cox_spec, Surv(time, status) ~ age + ph.ecog, data = train_data)
  })

  # Scenario 1: several observations, the second one has a missing value ----
  new_obs <- lung[13:15, c("age", "ph.ecog")]
  reference <- survfit(
    glmnet_fit,
    newx = as.matrix(new_obs),
    s = penalty_value,
    x = train_x,
    y = train_y
  )

  observed <- survival_time_coxnet(
    parsnip_fit,
    new_data = new_obs,
    penalty = penalty_value
  )
  expected <- extract_patched_survival_time(
    reference,
    missings_in_new_data = 2,
    n_obs = 3
  )

  expect_identical(observed, expected)
  expect_identical(which(is.na(observed)), 2L)

  # Scenario 2: exactly one observation, nothing missing ----
  new_obs <- lung[13, c("age", "ph.ecog")]
  reference <- survfit(
    glmnet_fit,
    newx = as.matrix(new_obs),
    s = penalty_value,
    x = train_x,
    y = train_y
  )

  observed <- survival_time_coxnet(
    parsnip_fit,
    new_data = new_obs,
    penalty = penalty_value
  )
  expected <- extract_patched_survival_time(
    reference,
    missings_in_new_data = NULL,
    n_obs = 1
  )

  expect_identical(observed, expected)

  # Scenario 3: every observation has a missing value ----
  new_obs <- lung[c(14, 14), ]

  observed <- survival_time_coxnet(
    parsnip_fit,
    new_data = new_obs,
    penalty = penalty_value
  )
  expected <- rep(NA, 2)

  expect_identical(observed, expected)
})

test_that("survival_time_coxnet() works for multiple penalty values", {
  # Reference values come from a plain glmnet fit on the same data; two
  # penalty values are requested at once via `multi = TRUE`.
  penalty_values <- c(0.1, 0.2)

  # Flatten the nested prediction tibble into one vector of predicted times,
  # ordered by penalty so it lines up with the per-penalty reference values.
  times_by_penalty <- function(nested) {
    flat <- tidyr::unnest(nested, cols = .pred)
    flat <- dplyr::arrange(flat, penalty)
    dplyr::pull(flat, .pred_time)
  }

  # Row 14 is left out of the training data and reused further down as the
  # observation with missing values.
  train_data <- lung[-14, ]
  train_x <- as.matrix(train_data[, c("age", "ph.ecog")])
  train_y <- Surv(train_data$time, train_data$status)

  glmnet_fit <- suppressWarnings(
    glmnet::glmnet(x = train_x, y = train_y, family = "cox")
  )
  parsnip_fit <- suppressWarnings({
    cox_spec <- set_engine(proportional_hazards(penalty = 0.1), "glmnet")
    fit(cox_spec, Surv(time, status) ~ age + ph.ecog, data = train_data)
  })

  # Scenario 1: several observations, the second one has a missing value ----
  new_obs <- lung[13:15, c("age", "ph.ecog")]
  reference <- survfit(
    glmnet_fit,
    newx = as.matrix(new_obs),
    s = penalty_values,
    x = train_x,
    y = train_y
  )

  observed <- survival_time_coxnet(
    parsnip_fit,
    new_data = new_obs,
    penalty = penalty_values,
    multi = TRUE
  )
  expected <- purrr::list_c(
    purrr::map(
      reference,
      extract_patched_survival_time,
      missings_in_new_data = 2,
      n_obs = 3
    )
  )

  expect_named(observed, ".pred")
  expect_named(observed$.pred[[1]], c("penalty", ".pred_time"))
  expect_identical(times_by_penalty(observed), expected)
  # The observation with the missing value is NA for both penalty values.
  expect_identical(
    dplyr::pull(observed$.pred[[2]], .pred_time),
    rep(NA_real_, 2)
  )

  # Scenario 2: exactly one observation, nothing missing ----
  new_obs <- lung[13, c("age", "ph.ecog")]
  reference <- survfit(
    glmnet_fit,
    newx = as.matrix(new_obs),
    s = penalty_values,
    x = train_x,
    y = train_y
  )

  observed <- survival_time_coxnet(
    parsnip_fit,
    new_data = new_obs,
    penalty = penalty_values,
    multi = TRUE
  )
  expected <- purrr::list_c(
    purrr::map(
      reference,
      extract_patched_survival_time,
      missings_in_new_data = NULL,
      n_obs = 1
    )
  )

  expect_named(observed, ".pred")
  expect_named(observed$.pred[[1]], c("penalty", ".pred_time"))
  expect_identical(times_by_penalty(observed), expected)

  # Scenario 3: every observation has a missing value ----
  new_obs <- lung[c(14, 14), ]

  observed <- survival_time_coxnet(
    parsnip_fit,
    new_data = new_obs,
    penalty = penalty_values,
    multi = TRUE
  )
  expected <- rep(NA, 4)

  expect_named(observed, ".pred")
  expect_named(observed$.pred[[1]], c("penalty", ".pred_time"))
  expect_identical(times_by_penalty(observed), expected)
})

# prediction: survival ----------------------------------------------------

test_that("survival probabilities without strata", {
Expand Down

0 comments on commit 7b2613d

Please sign in to comment.