Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Harmonize inputs to survival_*_*() functions #302

Merged
merged 12 commits into from
Jan 24, 2024
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

* Bug fix: `proportional_hazards(engine = "glmnet")` models no longer claim to support sparse matrices, which they cannot handle (#291).

* Breaking change: The `survival_prob_*()`, `survival_time_*()`, and `hazard_*()` functions now all take a parsnip `model_fit` object as the main input, instead of an engine fit as was the case for some of them previously.


# censored 0.2.0

Expand Down
4 changes: 2 additions & 2 deletions R/bag_tree-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ make_bag_tree_rpart <- function() {
func = c(pkg = "censored", fun = "survival_time_survbagg"),
args =
list(
object = rlang::expr(object$fit),
object = rlang::expr(object),
new_data = rlang::expr(new_data)
)
)
Expand All @@ -75,7 +75,7 @@ make_bag_tree_rpart <- function() {
func = c(pkg = "censored", fun = "survival_prob_survbagg"),
args =
list(
object = rlang::expr(object$fit),
object = rlang::expr(object),
new_data = rlang::expr(new_data),
eval_time = rlang::expr(eval_time)
)
Expand Down
32 changes: 22 additions & 10 deletions R/bag_tree-rpart.R
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
#' A wrapper for survival times with `survbagg` models
#' @param object A model from `ipred::bagging()`.
#' @param object A parsnip `model_fit` object resulting from `bag_tree()` with `engine = "rpart"`.
hfrick marked this conversation as resolved.
Show resolved Hide resolved
#' @param new_data Data for prediction
#' @return A vector.
#' @keywords internal
#' @export
#' @examples
#' library(ipred)
#' bagged_tree <- bagging(Surv(time, status) ~ age + ph.ecog, data = lung)
#' bagged_tree <- bag_tree() %>%
#' set_engine("rpart") %>%
#' set_mode("censored regression") %>%
#' fit(Surv(time, status) ~ age + ph.ecog, data = lung)
#' survival_time_survbagg(bagged_tree, lung[1:3, ])
survival_time_survbagg <- function(object, new_data) {
missings_in_new_data <- get_missings_survbagg(object, new_data)
if (inherits(object, "survbagg")) {
cli::cli_abort("{.arg object} needs to be a parsnip {.cls model_fit} object, not a {.cls survbagg} object.")
}

missings_in_new_data <- get_missings_survbagg(object$fit, new_data)
if (!is.null(missings_in_new_data)) {
n_total <- nrow(new_data)
n_missing <- length(missings_in_new_data)
Expand All @@ -21,7 +27,7 @@ survival_time_survbagg <- function(object, new_data) {
new_data <- new_data[-missings_in_new_data, , drop = FALSE]
}

y <- predict(object, newdata = new_data)
y <- predict(object$fit, newdata = new_data)

res <- purrr::map_dbl(y, ~ quantile(.x, probs = .5)$quantile)

Expand All @@ -48,18 +54,24 @@ get_missings_survbagg <- function(object, new_data) {
}

#' A wrapper for survival probabilities with `survbagg` models
#' @param object A model from `ipred::bagging()`.
#' @param object A parsnip `model_fit` object resulting from `bag_tree()` with `engine = "rpart"`.
#' @param new_data Data for prediction.
#' @param eval_time A vector of prediction times.
#' @param time Deprecated in favor of `eval_time`. A vector of prediction times.
#' @return A vctrs list of tibbles.
#' @keywords internal
#' @export
#' @examples
#' library(ipred)
#' bagged_tree <- bagging(Surv(time, status) ~ age + ph.ecog, data = lung)
#' bagged_tree <- bag_tree() %>%
#' set_engine("rpart") %>%
#' set_mode("censored regression") %>%
#' fit(Surv(time, status) ~ age + ph.ecog, data = lung)
#' survival_prob_survbagg(bagged_tree, lung[1:3, ], eval_time = 100)
survival_prob_survbagg <- function(object, new_data, eval_time, time = deprecated()) {
if (inherits(object, "survbagg")) {
cli::cli_abort("{.arg object} needs to be a parsnip {.cls model_fit} object, not a {.cls survbagg} object.")
}

if (lifecycle::is_present(time)) {
lifecycle::deprecate_warn(
"0.2.0",
Expand All @@ -76,7 +88,7 @@ survival_prob_survbagg <- function(object, new_data, eval_time, time = deprecate
output <- "surv"

n_obs <- nrow(new_data)
missings_in_new_data <- get_missings_survbagg(object, new_data)
missings_in_new_data <- get_missings_survbagg(object$fit, new_data)

if (!is.null(missings_in_new_data)) {
n_missing <- length(missings_in_new_data)
Expand All @@ -89,7 +101,7 @@ survival_prob_survbagg <- function(object, new_data, eval_time, time = deprecate
new_data <- new_data[-missings_in_new_data, , drop = FALSE]
}

y <- predict(object, newdata = new_data)
y <- predict(object$fit, newdata = new_data)

survfit_summary_list <- purrr::map(y, summary, times = eval_time, extend = TRUE)
survfit_summary_combined <- combine_list_of_survfit_summary(
Expand Down
4 changes: 2 additions & 2 deletions R/boost_tree-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ make_boost_tree_mboost <- function() {
post = NULL,
func = c(pkg = "censored", fun = "survival_prob_mboost"),
args = list(
object = rlang::expr(object$fit),
object = rlang::expr(object),
new_data = rlang::expr(new_data),
eval_time = rlang::expr(eval_time)
)
Expand Down Expand Up @@ -138,7 +138,7 @@ make_boost_tree_mboost <- function() {
func = c(pkg = "censored", fun = "survival_time_mboost"),
args =
list(
object = quote(object$fit),
object = quote(object),
new_data = quote(new_data)
)
)
Expand Down
30 changes: 20 additions & 10 deletions R/boost_tree-mboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,24 @@ predict_linear_pred._blackboost <- function(object,
}

#' A wrapper for survival probabilities with mboost models
#' @param object A model from `blackboost()`.
#' @param object A parsnip `model_fit` object resulting from `boost_tree()` with `engine = "mboost"`.
#' @param new_data Data for prediction.
#' @param eval_time A vector of integers for prediction times.
#' @param time Deprecated in favor of `eval_time`. A vector of integers for prediction times.
#' @return A tibble with a list column of nested tibbles.
#' @keywords internal
#' @export
#' @examples
#' library(mboost)
#' mod <- blackboost(Surv(time, status) ~ ., data = lung, family = CoxPH())
#' mod <- boost_tree() %>%
#' set_engine("mboost") %>%
#' set_mode("censored regression") %>%
#' fit(Surv(time, status) ~ ., data = lung)
#' survival_prob_mboost(mod, new_data = lung[1:3, ], eval_time = 300)
survival_prob_mboost <- function(object, new_data, eval_time, time = deprecated()) {
if (inherits(object, "mboost")) {
cli::cli_abort("{.arg object} needs to be a parsnip {.cls model_fit} object, not a {.cls mboost} object.")
}

if (lifecycle::is_present(time)) {
lifecycle::deprecate_warn(
"0.2.0",
Expand All @@ -125,7 +131,7 @@ survival_prob_mboost <- function(object, new_data, eval_time, time = deprecated(
eval_time <- time
}

survival_curve <- mboost::survFit(object, newdata = new_data)
survival_curve <- mboost::survFit(object$fit, newdata = new_data)

survival_prob <- survival_curve_to_prob(
eval_time,
Expand Down Expand Up @@ -165,19 +171,23 @@ survival_curve_to_prob <- function(eval_time, event_times, survival_prob) {


#' A wrapper for mean survival times with `mboost` models
#' @param object A model from `blackboost()`.
#' @param object A parsnip `model_fit` object resulting from `boost_tree()` with `engine = "mboost"`.
#' @param new_data Data for prediction
#' @return A tibble.
#' @keywords internal
#' @export
#' @examples
#' library(mboost)
#' boosted_tree <- blackboost(Surv(time, status) ~ age + ph.ecog,
#' data = lung[-14, ], family = CoxPH()
#' )
#' boosted_tree <- boost_tree() %>%
#' set_engine("mboost") %>%
#' set_mode("censored regression") %>%
#' fit(Surv(time, status) ~ age + ph.ecog, data = lung[-14, ])
#' survival_time_mboost(boosted_tree, new_data = lung[1:3, ])
survival_time_mboost <- function(object, new_data) {
y <- mboost::survFit(object, new_data)
if (inherits(object, "mboost")) {
cli::cli_abort("{.arg object} needs to be a parsnip {.cls model_fit} object, not a {.cls mboost} object.")
}

y <- mboost::survFit(object$fit, new_data)

stacked_survfit <- stack_survfit(y, n = nrow(new_data))

Expand Down
2 changes: 1 addition & 1 deletion R/decision_tree-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ make_decision_tree_partykit <- function() {
post = NULL,
func = c(pkg = "censored", fun = "survival_prob_partykit"),
args = list(
object = rlang::expr(object$fit),
object = rlang::expr(object),
new_data = rlang::expr(new_data),
eval_time = rlang::expr(eval_time)
)
Expand Down
2 changes: 1 addition & 1 deletion R/decision_tree-rpart.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' A wrapper for survival probabilities with pecRpart models
#' @param object A fitted `_pecRpart` object.
#' @param object A parsnip `model_fit` object resulting from `decision_tree()` with `engine = "rpart"`.
#' @param new_data Data for prediction.
#' @param eval_time A vector of integers for prediction times.
#' @return A tibble with a list column of nested tibbles.
Expand Down
24 changes: 17 additions & 7 deletions R/partykit.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#' A wrapper for survival probabilities with partykit models
#' @param object A model object from `partykit::ctree()` or `partykit::cforest()`.
#' @param object A parsnip `model_fit` object resulting from `decision_tree()` with
#' `engine = "partykit"` or from `rand_forest()` with `engine = "partykit"`.
#' @param new_data A data frame to be predicted.
#' @param eval_time A vector of times to predict the survival probability.
#' @param time Deprecated in favor of `eval_time`. A vector of times to predict the survival probability.
Expand All @@ -8,16 +9,25 @@
#' @export
#' @keywords internal
#' @examples
#' library(partykit)
#' c_tree <- ctree(Surv(time, status) ~ age + ph.ecog, data = lung)
#' survival_prob_partykit(c_tree, lung[1:3, ], eval_time = 100)
#' c_forest <- cforest(Surv(time, status) ~ age + ph.ecog, data = lung, ntree = 10)
#' survival_prob_partykit(c_forest, lung[1:3, ], eval_time = 100)
#' tree <- decision_tree() %>%
#' set_mode("censored regression") %>%
#' set_engine("partykit") %>%
#' fit(Surv(time, status) ~ age + ph.ecog, data = lung)
#' survival_prob_partykit(tree, lung[1:3, ], eval_time = 100)
#' forest <- rand_forest() %>%
#' set_mode("censored regression") %>%
#' set_engine("partykit") %>%
#' fit(Surv(time, status) ~ age + ph.ecog, data = lung)
#' survival_prob_partykit(forest, lung[1:3, ], eval_time = 100)
survival_prob_partykit <- function(object,
new_data,
eval_time,
time = deprecated(),
output = "surv") {
if (inherits(object, "party")) {
cli::cli_abort("{.arg object} needs to be a parsnip {.cls model_fit} object, not a {.cls party} object.")
}

if (lifecycle::is_present(time)) {
lifecycle::deprecate_warn(
"0.2.0",
Expand All @@ -36,7 +46,7 @@ survival_prob_partykit <- function(object,
# partykit handles missing values
missings_in_new_data <- NULL

y <- predict(object, newdata = new_data, type = "prob")
y <- predict(object$fit, newdata = new_data, type = "prob")

survfit_summary_list <- purrr::map(y, summary, times = eval_time, extend = TRUE)
survfit_summary_combined <- combine_list_of_survfit_summary(
Expand Down
4 changes: 2 additions & 2 deletions R/proportional_hazards-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ make_proportional_hazards_survival <- function() {
func = c(pkg = "censored", fun = "survival_time_coxph"),
args =
list(
object = quote(object$fit),
object = quote(object),
new_data = quote(new_data)
)
)
Expand All @@ -80,7 +80,7 @@ make_proportional_hazards_survival <- function() {
func = c(pkg = "censored", fun = "survival_prob_coxph"),
args =
list(
x = quote(object$fit),
object = quote(object),
new_data = quote(new_data),
eval_time = rlang::expr(eval_time),
output = "surv",
Expand Down
4 changes: 2 additions & 2 deletions R/proportional_hazards-glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ multi_predict_coxnet_linear_pred <- function(object, new_data, opts, penalty) {
# prediction: time --------------------------------------------------------

#' A wrapper for survival times with coxnet models
#' @param object A fitted `_coxnet` object.
#' @param object A parsnip `model_fit` object resulting from `proportional_hazards()` with `engine = "glmnet"`.
#' @param new_data Data for prediction.
#' @param penalty Penalty value(s).
#' @param multi Allow multiple penalty values?
Expand Down Expand Up @@ -583,7 +583,7 @@ get_missings_coxnet <- function(new_x, new_strata) {


#' A wrapper for survival probabilities with coxnet models
#' @param object A fitted `_coxnet` object.
#' @param object A parsnip `model_fit` object resulting from `proportional_hazards()` with `engine = "glmnet"`.
#' @param new_data Data for prediction.
#' @param eval_time A vector of integers for prediction times.
#' @param time Deprecated in favor of `eval_time`. A vector of integers for prediction times.
Expand Down
39 changes: 29 additions & 10 deletions R/proportional_hazards-survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,22 @@ cph_survival_pre <- function(new_data, object, ..., call = caller_env()) {
# prediction: time --------------------------------------------------------

#' A wrapper for survival times with `coxph` models
#' @param object A model from `coxph()`.
#' @param object A parsnip `model_fit` object resulting from `proportional_hazards()` with `engine = "survival"`.
#' @param new_data Data for prediction
#' @return A vector.
#' @keywords internal
#' @export
#' @examples
#' cox_mod <- coxph(Surv(time, status) ~ ., data = lung)
#' cox_mod <- proportional_hazards() %>%
#' set_engine("survival") %>%
#' fit(Surv(time, status) ~ ., data = lung)
#' survival_time_coxph(cox_mod, new_data = lung[1:3, ])
survival_time_coxph <- function(object, new_data) {
missings_in_new_data <- get_missings_coxph(object, new_data)
if (inherits(object, "coxph")) {
cli::cli_abort("{.arg object} needs to be a parsnip {.cls model_fit} object, not a {.cls coxph} object.")
}

missings_in_new_data <- get_missings_coxph(object$fit, new_data)
if (!is.null(missings_in_new_data)) {
n_total <- nrow(new_data)
n_missing <- length(missings_in_new_data)
Expand All @@ -67,8 +73,7 @@ survival_time_coxph <- function(object, new_data) {
new_data <- new_data[-missings_in_new_data, ]
}


y <- survival::survfit(object, new_data, na.action = stats::na.exclude)
y <- survival::survfit(object$fit, new_data, na.action = stats::na.exclude)

tabs <- summary(y)$table
if (is.matrix(tabs)) {
Expand Down Expand Up @@ -106,7 +111,8 @@ get_missings_coxph <- function(object, new_data) {
# prediction: survival ----------------------------------------------------

#' A wrapper for survival probabilities with coxph models
#' @param x A model from `coxph()`.
#' @param object A parsnip `model_fit` object resulting from `proportional_hazards()` with `engine = "survival"`.
#' @param x Deprecated in favor of `object`; supplying it now raises an error. Previously a model from `coxph()`.
#' @param new_data Data for prediction
#' @param eval_time A vector of integers for prediction times.
#' @param time Deprecated in favor of `eval_time`. A vector of integers for prediction times.
Expand All @@ -119,16 +125,29 @@ get_missings_coxph <- function(object, new_data) {
#' @keywords internal
#' @export
#' @examples
#' cox_mod <- coxph(Surv(time, status) ~ ., data = lung)
#' cox_mod <- proportional_hazards() %>%
#' set_engine("survival") %>%
#' fit(Surv(time, status) ~ ., data = lung)
#' survival_prob_coxph(cox_mod, new_data = lung[1:3, ], eval_time = 300)
survival_prob_coxph <- function(x,
survival_prob_coxph <- function(object,
x = deprecated(),
new_data,
eval_time,
time = deprecated(),
output = "surv",
interval = "none",
conf.int = .95,
...) {
if (inherits(object, "coxph")) {
cli::cli_abort("{.arg object} needs to be a parsnip {.cls model_fit} object, not a {.cls coxph} object.")
}
if (lifecycle::is_present(x)) {
lifecycle::deprecate_stop(
"0.3.0",
"survival_prob_coxph(x)",
"survival_prob_coxph(object)"
)
}
if (lifecycle::is_present(time)) {
lifecycle::deprecate_warn(
"0.2.0",
Expand All @@ -145,7 +164,7 @@ survival_prob_coxph <- function(x,
}

n_obs <- nrow(new_data)
missings_in_new_data <- get_missings_coxph(x, new_data)
missings_in_new_data <- get_missings_coxph(object$fit, new_data)

if (!is.null(missings_in_new_data)) {
n_missing <- length(missings_in_new_data)
Expand All @@ -159,7 +178,7 @@ survival_prob_coxph <- function(x,
}

surv_fit <- survival::survfit(
x,
object$fit,
newdata = new_data,
conf.int = conf.int,
na.action = na.exclude,
Expand Down
Loading
Loading