diff --git a/R/PipeOpPredRegrSurvPEM.R b/R/PipeOpPredRegrSurvPEM.R
new file mode 100644
index 000000000..0667bfb8c
--- /dev/null
+++ b/R/PipeOpPredRegrSurvPEM.R
@@ -0,0 +1,97 @@
+#' @title PipeOpPredRegrSurvPEM
+#' @name mlr_pipeops_trafopred_regrsurv_PEM
+#'
+#' @description
+#' Transform [PredictionRegr] to [PredictionSurv].
+#'
+#' @section Dictionary:
+#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
+#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
+#' or with the associated sugar function [mlr3pipelines::po()]:
+#' ```
+#' PipeOpPredRegrSurvPEM$new()
+#' mlr_pipeops$get("trafopred_regrsurv_PEM")
+#' po("trafopred_regrsurv_PEM")
+#' ```
+#'
+#' @section Input and Output Channels:
+#' The input is a [PredictionRegr] and a [data.table][data.table::data.table]
+#' with the transformed data, both generated by [PipeOpTaskSurvRegrPEM].
+#' The output is the input [PredictionRegr] transformed to a [PredictionSurv].
+#' It only works during the prediction phase.
+#'
+#' @family PipeOps
+#' @family Transformation PipeOps
+#' @export
+PipeOpPredRegrSurvPEM = R6Class(
+  "PipeOpPredRegrSurvPEM",
+  inherit = mlr3pipelines::PipeOp,
+
+  public = list(
+    #' @description
+    #' Creates a new instance of this [R6][R6::R6Class] class.
+    #' @param id (character(1))\cr
+    #'   Identifier of the resulting object.
+    initialize = function(id = "trafopred_regrsurv_PEM") {
+      super$initialize(
+        id = id,
+        input = data.table(
+          name = c("input", "transformed_data"),
+          train = c("NULL", "data.table"),
+          predict = c("PredictionRegr", "data.table")
+        ),
+        output = data.table(
+          name = "output",
+          train = "NULL",
+          predict = "PredictionSurv"
+        )
+      )
+    }
+  ),

+  private = list(
+    .predict = function(input) {
+      pred = input[[1]]
+      data = input[[2]]
+      assert_true(!is.null(pred$response))
+      # the predicted response is the piecewise-constant hazard in each interval
+      data = cbind(data, dt_hazard = pred$response)
+
+      # convert hazards to survival probabilities: S(t) = exp(-cumsum(h(t) * exp(offset))),
+      # where exp(offset) is the time at risk in each interval
+      rows_per_id = nrow(data) / length(unique(data$id))
+
+      # currently assumes the single-event setting; competing risks ('cr') and
+      # multi-state models ('msm') are not supported yet
+      surv = t(vapply(unique(data$id), function(unique_id) {
+        exp(-cumsum(data[data$id == unique_id, ][["dt_hazard"]] * exp(data[data$id == unique_id, ][["offset"]])))
+      }, numeric(rows_per_id)))
+
+      unique_end_times = sort(unique(data$tend))
+      # coerce to distribution and crank
+      pred_list = .surv_return(times = unique_end_times, surv = surv)
+
+      # select the observed time per id by keeping only the last row of each id
+      # (basically a slightly more complex unique())
+      real_tend = data$obs_times[seq_len(nrow(data)) %% rows_per_id == 0]
+
+      ids = unique(data$id)
+      # select the last row for every id => event status at the observed time
+      id = PEM_status = NULL # to fix note
+      data = data[, .SD[.N, list(PEM_status)], by = id]
+
+      # create prediction object
+      p = PredictionSurv$new(
+        row_ids = ids,
+        crank = pred_list$crank, distr = pred_list$distr,
+        truth = Surv(real_tend, as.integer(as.character(data$PEM_status))))
+
+      list(p)
+    },
+
+    .train = function(input) {
+      self$state = list()
+      list(input)
+    }
+  )
+)
+register_pipeop("trafopred_regrsurv_PEM", PipeOpPredRegrSurvPEM)
diff --git a/R/PipeOpTaskSurvRegrPEM.R b/R/PipeOpTaskSurvRegrPEM.R
new file mode 100644
index 000000000..58011c051
--- /dev/null
+++ b/R/PipeOpTaskSurvRegrPEM.R
@@ -0,0 +1,220 @@
+#' @title PipeOpTaskSurvRegrPEM
+#' @name mlr_pipeops_trafotask_survregr_PEM
+#' @template param_pipelines
+#'
+#' @description
+#' Transform [TaskSurv] to [TaskRegr][mlr3::TaskRegr] by dividing continuous
+#' time into multiple time intervals for each observation.
+#' This transformation creates a new target variable `PEM_status` that indicates
+#' whether an event occurred within each time interval.
+#'
+#' @section Dictionary:
+#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
+#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
+#' or with the associated sugar function [mlr3pipelines::po()]:
+#' ```
+#' PipeOpTaskSurvRegrPEM$new()
+#' mlr_pipeops$get("trafotask_survregr_PEM")
+#' po("trafotask_survregr_PEM")
+#' ```
+#'
+#' @section Input and Output Channels:
+#' [PipeOpTaskSurvRegrPEM] has one input channel named "input", and two
+#' output channels, one named "output" and the other "transformed_data".
+#'
+#' During training, the "output" is the "input" [TaskSurv] transformed to a
+#' [TaskRegr][mlr3::TaskRegr].
+#' The target column is named `"PEM_status"` and indicates whether an event occurred
+#' in each time interval.
+#' An additional feature named `"tend"` contains the end time point of each interval.
+#' Lastly, the "output" task has an offset column `"offset"` (the log of the time
+#' at risk in each interval).
+#' The "transformed_data" is an empty [data.table][data.table::data.table].
+#'
+#' During prediction, the "input" [TaskSurv] is transformed to the "output"
+#' [TaskRegr][mlr3::TaskRegr] with `"PEM_status"` as target and the `"tend"`
+#' as well as `"offset"` features included.
+#' The "transformed_data" is a [data.table][data.table::data.table] with the following
+#' columns: `"PEM_status"` (the target of the "output" task), `"id"` (original
+#' observation ids), `"obs_times"` (observed times per `"id"`), `"tend"` (end time of
+#' each interval) and `"offset"`.
+#' This "transformed_data" is only meant to be used with [PipeOpPredRegrSurvPEM].
+#'
+#' @section State:
+#' The `$state` contains information about the `cut` parameter used.
+#'
+#' @section Parameters:
+#' The parameters are:
+#'
+#' * `cut :: numeric()`\cr
+#'   Split points, used to partition the data into intervals based on the `time` column.
+#'   If unspecified, all unique event times will be used.
+#'   If `cut` is a single integer, it will be interpreted as the number of equidistant
+#'   intervals from 0 until the maximum event time.
+#' * `max_time :: numeric(1)`\cr
+#'   If `cut` is unspecified, this will be the last possible event time.
+#'   All event times after `max_time` will be administratively censored at `max_time`.
+#'   Needs to be greater than the minimum event time in the given task.
+#' * `form :: formula`\cr
+#'   Formula passed as `formula` to [pammtools::as_ped()] to transform the data into
+#'   PED format, e.g. `Surv(time, status) ~ .`.
+#'
+#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3learners"), quietly = TRUE)
+#' \dontrun{
+#' library(mlr3)
+#' library(mlr3learners)
+#' library(mlr3pipelines)
+#'
+#' task = tsk("lung")
+#'
+#' # transform the survival task to a Poisson regression task
+#' # all unique event times are used as cutpoints
+#' po_PEM = po("trafotask_survregr_PEM")
+#' task_regr = po_PEM$train(list(task))[[1L]]
+#'
+#' # the end time points of the time intervals
+#' unique(task_regr$data(cols = "tend"))[[1L]]
+#'
+#' # train a regression learner (ideally one supporting a Poisson objective and offsets)
+#' learner = lrn("regr.xgboost", nrounds = 100L)
+#' learner$train(task_regr)
+#' }
+#'
+#' @family PipeOps
+#' @family Transformation PipeOps
+#' @export
+PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
+  inherit = mlr3pipelines::PipeOp,
+
+  public = list(
+    #' @description
+    #' Creates a new instance of this [R6][R6::R6Class] class.
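+    #' @param id (character(1))\cr
+    #'   Identifier of the resulting object.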
+    initialize = function(id = "trafotask_survregr_PEM") {
+      param_set = ps(
+        cut = p_uty(default = NULL),
+        max_time = p_dbl(0, default = NULL, special_vals = list(NULL)),
+        censor_code = p_int(0L),
+        min_events = p_int(1L),
+        form = p_uty(tags = "train")
+        # further pammtools::as_ped() arguments (e.g. transitions) may be added here
+      )
+      super$initialize(
+        id = id,
+        param_set = param_set,
+        input = data.table(
+          name = "input",
+          train = "TaskSurv",
+          predict = "TaskSurv"
+        ),
+        output = data.table(
+          name = c("output", "transformed_data"),
+          train = c("TaskRegr", "data.table"),
+          predict = c("TaskRegr", "data.table")
+        )
+      )
+    }
+  ),
+
+  private = list(
+    .train = function(input) {
+      task = input[[1L]]
+      assert_true(task$censtype == "right")
+      data = task$data()
+
+      if ("PEM_status" %in% colnames(data)) {
+        stop("\"PEM_status\" cannot be a column in the input data.")
+      }
+
+      cut = assert_numeric(self$param_set$values$cut, null.ok = TRUE, lower = 0)
+      max_time = self$param_set$values$max_time
+
+      time_var = task$target_names[1]
+      event_var = task$target_names[2]
+      if (testInt(cut, lower = 1)) {
+        # a single integer => number of equidistant intervals up to the maximum event time
+        cut = seq(0, data[get(event_var) == 1, max(get(time_var))], length.out = cut + 1)
+      }
+
+      if (!is.null(max_time)) {
+        assert(max_time > data[get(event_var) == 1, min(get(time_var))],
+               "max_time must be greater than the minimum event time in the given task.")
+      }
+
+      # TODO: extend to a more general formulation for competing risks and multi-state models.
+      # `form` (e.g. Surv(time, status) ~ .) is currently only used to transform the data
+      # into PED format; for ML learners such as xgboost the covariate structure is passed
+      # to the pipeline via `rhs`, not `form`.
+      long_data = pammtools::as_ped(data = data, formula = self$param_set$values$form, cut = cut, max_time = max_time)
+      self$state$cut = attributes(long_data)$trafo_args$cut
+
+      long_data = as.data.table(long_data)
+      setnames(long_data, old = "ped_status", new = "PEM_status") # rename to PEM
+
+      # remove some columns from `long_data`
+      long_data[, c("tstart", "interval") := NULL]
+      # keep id mapping
+      reps = table(long_data$id)
+      ids = rep(task$row_ids, times = reps)
+      id = NULL
+      long_data[, id := ids]
+
+      task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data, target = "PEM_status")
+      task_PEM$set_col_roles("id", roles = "original_ids")
+
+      list(task_PEM, data.table())
+    },
+
+    .predict = function(input) {
+      task = input[[1]]
+      data = task$data()
+
+      # extract `cut` from `state`
+      cut = self$state$cut
+
+      time_var = task$target_names[1]
+      event_var = task$target_names[2]
+
+      # pretend that every observation has an event at `max_time`, so that
+      # as_ped() expands each observation over all intervals
+      max_time = max(cut)
+      time = data[[time_var]]
+      data[[time_var]] = max_time
+
+      status = data[[event_var]]
+      data[[event_var]] = 1
+
+      long_data = as.data.table(pammtools::as_ped(data, formula = self$param_set$values$form, cut = cut))
+      setnames(long_data, old = "ped_status", new = "PEM_status")
+
+      PEM_status = id = tend = obs_times = NULL # fixing global binding notes of data.table
+      long_data[, PEM_status := 0]
+      # set correct id
+      rows_per_id = nrow(long_data) / length(unique(long_data$id))
+      long_data$obs_times = rep(time, each = rows_per_id)
+      ids = rep(task$row_ids, each = rows_per_id)
+      long_data[, id := ids]
+
+      # set correct PEM_status
+      reps = long_data[, data.table(count = sum(tend >= obs_times)), by = id]$count
+      status = rep(status, times = reps)
+      long_data[long_data[, .I[tend >= obs_times], by = id]$V1, PEM_status := status]
+
+      # remove some columns from `long_data`
+      long_data[, c("tstart", "interval", "obs_times") := NULL]
+      task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data, target = "PEM_status")
+      task_PEM$set_col_roles("id", roles = "original_ids")
+
+      # map observed times back
+      long_data$obs_times = rep(time, each = rows_per_id)
+      # subset transformed data
+      columns_to_keep = c("id", "obs_times", "tend", "PEM_status", "offset")
+      long_data = long_data[, columns_to_keep, with = FALSE]
+
+      list(task_PEM, long_data)
+    }
+  )
+)
+
+register_pipeop("trafotask_survregr_PEM", PipeOpTaskSurvRegrPEM)
diff --git a/R/aaa.R b/R/aaa.R
index 20925bec5..c874badaf 100644
--- a/R/aaa.R
+++ b/R/aaa.R
@@ -51,7 +51,8 @@ register_reflections = function() {
   x$task_col_roles$surv = x$task_col_roles$regr
   x$task_col_roles$dens = c("feature", "target", "label", "order", "group", "weight", "stratum")
   x$task_col_roles$classif = unique(c(x$task_col_roles$classif, "original_ids")) # for discrete time
+  x$task_col_roles$regr = unique(c(x$task_col_roles$regr, "original_ids")) # for PEM
 
   x$task_properties$surv = x$task_properties$regr
   x$task_properties$dens = x$task_properties$regr
diff --git a/R/pipelines.R b/R/pipelines.R
index 8587f6d44..5fff52c3c 100644
--- a/R/pipelines.R
+++ b/R/pipelines.R
@@ -659,6 +659,89 @@ pipeline_survtoclassif_disctime = function(learner, cut = NULL, max_time = NULL,
   gr
 }
 
+#' @name mlr_graphs_survtoregr_PEM
+#' @title Survival to Poisson Regression Reduction Pipeline
+#' @description Wrapper around multiple [PipeOp][mlr3pipelines::PipeOp]s to help in the creation
+#' of complex survival reduction methods.
+#'
+#' @param learner [LearnerRegr][mlr3::LearnerRegr]\cr
+#' Regression learner to fit the transformed [TaskRegr][mlr3::TaskRegr].
+#' `learner` must be able to handle the `offset` column.
+#' @param cut `numeric()`\cr
+#' Split points, used to partition the data into intervals.
+#' If unspecified, all unique event times will be used.
+#' If `cut` is a single integer, it will be interpreted as the number of equidistant
+#' intervals from 0 until the maximum event time.
+#' @param max_time `numeric(1)`\cr
+#' If `cut` is unspecified, this will be the last possible event time.
+#' All event times after `max_time` will be administratively censored at `max_time`.
+#' @param rhs `character(1)`\cr
+#' Optional right-hand side of a formula. If provided, a
+#' [PipeOpModelMatrix][mlr3pipelines::PipeOpModelMatrix] ("modelmatrix") is inserted
+#' between the task transformation and the learner to build the design matrix.
+#' @param form `formula`\cr
+#' Formula passed to [PipeOpTaskSurvRegrPEM] (parameter `form`) and used by
+#' [pammtools::as_ped()] to transform the data into PED format, e.g. `Surv(time, status) ~ .`.
+#' @param graph_learner `logical(1)`\cr
+#' If `TRUE`, wraps the [Graph][mlr3pipelines::Graph] in a
+#' [GraphLearner][mlr3pipelines::GraphLearner], otherwise (default) the `Graph` is returned.
+#'
+#' @details
+#' The pipeline consists of the following steps:
+#' \enumerate{
+#' \item [PipeOpTaskSurvRegrPEM] converts [TaskSurv] to a [TaskRegr][mlr3::TaskRegr].
+#' \item A [LearnerRegr][mlr3::LearnerRegr] is fit and predicted on the new `TaskRegr`.
+#' \item [PipeOpPredRegrSurvPEM] transforms the resulting [PredictionRegr][mlr3::PredictionRegr]
+#'   to [PredictionSurv], combining the predicted hazards with the interval-specific
+#'   `offset` to obtain survival probabilities.
+#' }
+#'
+#' @return [mlr3pipelines::Graph] or [mlr3pipelines::GraphLearner]
+#' @family pipelines
+#'
+#' @examples
+#' \dontrun{
+#' if (requireNamespace("mlr3pipelines", quietly = TRUE) &&
+#'     requireNamespace("mlr3learners", quietly = TRUE)) {
+#'
+#'   library(mlr3)
+#'   library(mlr3learners)
+#'   library(mlr3pipelines)
+#'
+#'   task = tsk("lung")
+#'   part = partition(task)
+#'
+#'   grlrn = ppl(
+#'     "survtoregr_PEM",
+#'     learner = lrn("regr.xgboost")
+#'   )
+#'   grlrn$train(task, row_ids = part$train)
+#'   grlrn$predict(task, row_ids = part$test)
+#' }
+#' }
+#' @export
+pipeline_survtoregr_PEM = function(learner, cut = NULL, max_time = NULL,
+                                   rhs = NULL, graph_learner = FALSE, form = NULL) {
+  assert_class(learner, "LearnerRegr")
+  assert_flag(graph_learner)
+
+  gr = mlr3pipelines::Graph$new()
+  gr$add_pipeop(mlr3pipelines::po("trafotask_survregr_PEM", cut = cut, max_time = max_time, form = form))
+  gr$add_pipeop(mlr3pipelines::po("learner", learner))
+  gr$add_pipeop(mlr3pipelines::po("nop"))
+  gr$add_pipeop(mlr3pipelines::po("trafopred_regrsurv_PEM"))
+
+  gr$add_edge(src_id = "trafotask_survregr_PEM", dst_id = learner$id, src_channel = "output", dst_channel = "input")
+  gr$add_edge(src_id = "trafotask_survregr_PEM", dst_id = "nop", src_channel = "transformed_data", dst_channel = "input")
+  gr$add_edge(src_id = learner$id, dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "input")
+  gr$add_edge(src_id = "nop", dst_id = "trafopred_regrsurv_PEM", src_channel = "output", dst_channel = "transformed_data")
+
+  if (!is.null(rhs)) {
+    # drop the direct edge from the task transformation to the learner and
+    # route the "output" channel through the modelmatrix pipeop instead
+    gr$edges = gr$edges[-1, ]
+    gr$add_pipeop(mlr3pipelines::po("modelmatrix", formula = formulate(rhs = rhs, quote = "left")))
+    gr$add_edge(src_id = "trafotask_survregr_PEM", dst_id = "modelmatrix", src_channel = "output")
+    gr$add_edge(src_id = "modelmatrix", dst_id = learner$id, src_channel = "output", dst_channel = "input")
+  }
+
+  if (graph_learner) {
+    gr = mlr3pipelines::GraphLearner$new(gr)
+  }
+
+  gr
+}
+
 register_graph("survaverager", pipeline_survaverager)
 register_graph("survbagging", pipeline_survbagging)
 register_graph("crankcompositor", pipeline_crankcompositor)
@@ -667,3 +750,4 @@ register_graph("responsecompositor", pipeline_responsecompositor)
 register_graph("probregr", pipeline_probregr)
 register_graph("survtoregr", pipeline_survtoregr)
 register_graph("survtoclassif_disctime", pipeline_survtoclassif_disctime)
+register_graph("survtoregr_PEM", pipeline_survtoregr_PEM)
diff --git a/R/zzz.R b/R/zzz.R
index bad8118b7..ec6cba062 100644
--- a/R/zzz.R
+++ b/R/zzz.R
@@ -85,6 +85,7 @@ unregister_reflections = function() {
   x$task_col_roles$surv = NULL
   x$task_col_roles$dens = NULL
   x$task_col_roles$classif = setdiff(x$task_col_roles$classif, "original_ids")
+  x$task_col_roles$regr = setdiff(x$task_col_roles$regr, "original_ids")
   x$task_properties$surv = NULL
   x$task_properties$dens = NULL
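A minimal end-to-end sketch of how the new `survtoregr_PEM` pipeline is intended to be used once this patch is applied, mirroring the roxygen examples above. The `objective = "count:poisson"`, `nrounds = 100L` and `cut = 20` settings are illustrative assumptions rather than part of the patch, and whether the `offset` column is consumed as a true Poisson offset depends on the chosen learner.

```r
library(mlr3)
library(mlr3proba)       # provides ppl("survtoregr_PEM") with this patch
library(mlr3learners)    # regr.xgboost
library(mlr3pipelines)

task = tsk("lung")
part = partition(task)

# discretise follow-up time into 20 equidistant intervals, fit xgboost with a
# Poisson objective on the PED-format data, then map the predicted hazards back
# to a survival distribution
grlrn = ppl(
  "survtoregr_PEM",
  learner = lrn("regr.xgboost", objective = "count:poisson", nrounds = 100L),
  cut = 20,
  graph_learner = TRUE
)

grlrn$train(task, row_ids = part$train)
pred = grlrn$predict(task, row_ids = part$test)
pred$score(msr("surv.cindex"))
```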