Skip to content

Commit

Permalink
Adapt for upcoming flexsurv version 2.3 (#317)
Browse files Browse the repository at this point in the history
* adapt for upcoming flexsurv version

* add NEWS bullet

* namespace `packageVersion()
  • Loading branch information
hfrick authored Apr 19, 2024
1 parent ebda76c commit 66a9873
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 72 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# censored (development version)

* Internal changes to the `predict()` methods for flexsurv models, in preparation for the upcoming flexsurv release (#317).


# censored 0.3.0

## New features
Expand Down
64 changes: 4 additions & 60 deletions R/survival_reg-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -257,21 +257,7 @@ make_survival_reg_flexsurv <- function() {
type = "hazard",
value = list(
pre = NULL,
post = function(pred, object) {
if (".pred" %in% names(pred)) {
pred %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred = list(dplyr::rename(.pred, .eval_time = .time))
) %>%
dplyr::ungroup()
} else {
dplyr::rename(pred, .eval_time = .time) %>%
dplyr::mutate(.row = seq_len(nrow(pred))) %>%
tidyr::nest(.by = .row) %>%
dplyr::select(-.row)
}
},
post = flexsurv_post,
func = c(fun = "predict"),
args =
list(
Expand All @@ -290,21 +276,7 @@ make_survival_reg_flexsurv <- function() {
type = "survival",
value = list(
pre = NULL,
post = function(pred, object) {
if (".pred" %in% names(pred)) {
pred %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred = list(dplyr::rename(.pred, .eval_time = .time))
) %>%
dplyr::ungroup()
} else {
dplyr::rename(pred, .eval_time = .time) %>%
dplyr::mutate(.row = seq_len(nrow(pred))) %>%
tidyr::nest(.by = .row) %>%
dplyr::select(-.row)
}
},
post = flexsurv_post,
func = c(fun = "predict"),
args =
list(
Expand Down Expand Up @@ -442,21 +414,7 @@ make_survival_reg_flexsurvspline <- function() {
type = "hazard",
value = list(
pre = NULL,
post = function(pred, object) {
if (".pred" %in% names(pred)) {
pred %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred = list(dplyr::rename(.pred, .eval_time = .time))
) %>%
dplyr::ungroup()
} else {
dplyr::rename(pred, .eval_time = .time) %>%
dplyr::mutate(.row = seq_len(nrow(pred))) %>%
tidyr::nest(.by = .row) %>%
dplyr::select(-.row)
}
},
post = flexsurv_post,
func = c(fun = "predict"),
args =
list(
Expand All @@ -475,21 +433,7 @@ make_survival_reg_flexsurvspline <- function() {
type = "survival",
value = list(
pre = NULL,
post = function(pred, object) {
if (".pred" %in% names(pred)) {
pred %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred = list(dplyr::rename(.pred, .eval_time = .time))
) %>%
dplyr::ungroup()
} else {
dplyr::rename(pred, .eval_time = .time) %>%
dplyr::mutate(.row = seq_len(nrow(pred))) %>%
tidyr::nest(.by = .row) %>%
dplyr::select(-.row)
}
},
post = flexsurv_post,
func = c(fun = "predict"),
args =
list(
Expand Down
29 changes: 29 additions & 0 deletions R/survival_reg-flexsurv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
flexsurv_post <- function(pred, object) {
if (utils::packageVersion("flexsurv") < "2.3") {
pred <- flexsurv_rename_time(pred)
}

# if there's only one observation in new_data,
# flexsurv output isn't nested
if (!(".pred" %in% names(pred))) {
pred <- pred %>%
dplyr::mutate(.row = seq_len(nrow(pred))) %>%
tidyr::nest(.by = .row) %>%
dplyr::select(-.row)
}
pred
}

flexsurv_rename_time <- function(pred){
if (".pred" %in% names(pred)) {
pred %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred = list(dplyr::rename(.pred, .eval_time = .time))
) %>%
dplyr::ungroup()
} else {
pred %>%
dplyr::rename(.eval_time = .time)
}
}
30 changes: 18 additions & 12 deletions tests/testthat/test-survival_reg-flexsurvspline.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,15 @@ test_that("survival probability prediction", {
head(lung),
type = "survival",
times = c(0, 500, 1000)
) %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred = list(dplyr::rename(.pred, .eval_time = .time))
) %>%
dplyr::ungroup()
)
if (packageVersion("flexsurv") < "2.3") {
exp_pred <- exp_pred %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred = list(dplyr::rename(.pred, .eval_time = .time))
) %>%
dplyr::ungroup()
}

f_fit <- survival_reg() %>%
set_engine("flexsurvspline", k = 1) %>%
Expand Down Expand Up @@ -281,12 +284,15 @@ test_that("hazard prediction", {
head(lung),
type = "hazard",
times = c(0, 500, 1000)
) %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred = list(dplyr::rename(.pred, .eval_time = .time))
) %>%
dplyr::ungroup()
)
if (packageVersion("flexsurv") < "2.3") {
exp_pred <- exp_pred %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred = list(dplyr::rename(.pred, .eval_time = .time))
) %>%
dplyr::ungroup()
}

f_fit <- survival_reg() %>%
set_engine("flexsurvspline", k = 1) %>%
Expand Down

0 comments on commit 66a9873

Please sign in to comment.