Skip to content

Commit

Permalink
Enable predictions of survival time with aorsf (#308)
Browse files Browse the repository at this point in the history
* Enable time predictions for aorsf

* include new type in docs

* Update NEWS

* Update tests/testthat/test-rand_forest-aorsf.R

Co-authored-by: Emil Hvitfeldt <[email protected]>

---------

Co-authored-by: Emil Hvitfeldt <[email protected]>
  • Loading branch information
hfrick and EmilHvitfeldt authored Jan 26, 2024
1 parent 5343ca8 commit dc00710
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 2 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Imports:
tibble (>= 3.1.3),
tidyr (>= 1.0.0)
Suggests:
aorsf (>= 0.0.4),
aorsf (>= 0.1.2),
coin,
covr,
flexsurv (>= 2.2.1),
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

* Breaking change: The `survival_prob_*()`, `survival_time_*()`, and `hazard_*()` functions now all take a parsnip `model_fit` object as the main input, instead of an engine fit as was the case for some of them previously (#302).

* Random forests with the `"aorsf"` engine can now predict survival time, i.e., `predict(type = "time")` is now available (#308).


# censored 0.2.0

Expand Down
20 changes: 20 additions & 0 deletions R/rand_forest-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,26 @@ make_rand_forest_aorsf <- function() {
)
)

parsnip::set_pred(
model = "rand_forest",
eng = "aorsf",
mode = "censored regression",
type = "time",
value = list(
pre = NULL,
post = function(x, object) {
as.vector(x)
},
func = c(fun = "predict"),
args = list(
object = rlang::expr(object$fit),
new_data = rlang::expr(new_data),
pred_type = "time",
na_action = "pass"
)
)
)

parsnip::set_pred(
model = "rand_forest",
eng = "aorsf",
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ with `type = "quantile"`, and the hazard with `type = "hazard"`.
| proportional_hazards | survival |||||||
| proportional_hazards | glmnet |||||||
| rand_forest | partykit |||||||
| rand_forest | aorsf | ||||||
| rand_forest | aorsf | ||||||
| survival_reg | survival |||||||
| survival_reg | flexsurv |||||||
| survival_reg | flexsurvspline |||||||
Expand Down
36 changes: 36 additions & 0 deletions tests/testthat/test-rand_forest-aorsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,42 @@ test_that("model object", {
)
})

# prediction: time --------------------------------------------------------

test_that("time predictions", {
skip_if_not_installed("aorsf", "0.1.2")

lung_orsf <- na.omit(lung)

set.seed(1234)
exp_f_fit <- aorsf::orsf(
data = lung_orsf,
formula = Surv(time, status) ~ age + ph.ecog
)
exp_f_pred <- predict(
exp_f_fit,
new_data = lung,
pred_type = "time",
na_action = "pass"
)

mod_spec <- rand_forest() %>%
set_engine("aorsf") %>%
set_mode("censored regression")
set.seed(1234)
f_fit <- fit(mod_spec, Surv(time, status) ~ age + ph.ecog, data = lung_orsf)
f_pred <- predict(f_fit, lung, type = "time")

expect_s3_class(f_pred, "tbl_df")
expect_true(all(names(f_pred) == ".pred_time"))
expect_equal(f_pred$.pred_time, as.vector(exp_f_pred))
expect_equal(nrow(f_pred), nrow(lung))

# single observation
f_pred_1 <- predict(f_fit, lung[2,], type = "time")
expect_identical(nrow(f_pred_1), 1L)
})

# prediction: survival ----------------------------------------------------

test_that("survival predictions", {
Expand Down
1 change: 1 addition & 0 deletions vignettes/articles/examples.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ The following examples use the same data set throughout.
) %>%
slice(1) %>%
tidyr::unnest(col = .pred)
predict(rf_fit, lung_test, type = "time")
```
</details>

Expand Down

0 comments on commit dc00710

Please sign in to comment.