Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PEM Pipeline #417

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
97 changes: 97 additions & 0 deletions R/PipeOpPredRegrSurvPEM.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#' @title PipeOpPredRegrSurvPEM
#' @name mlr_pipeops_trafopred_regrsurv_PEM
#'
#' @description
#' Transform [PredictionRegr] to [PredictionSurv].
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpPredRegrSurvPEM$new()
#' mlr_pipeops$get("trafopred_regrsurv_PEM")
#' po("trafopred_regrsurv_PEM")
#' ```
#'
#' @section Input and Output Channels:
#' The input is a [PredictionRegr] and a [data.table][data.table::data.table]
#' with the transformed data both generated by [PipeOpTaskSurvRegrPEM].
#' The output is the input [PredictionRegr] transformed to a [PredictionSurv].
#' Only works during prediction phase.
#'
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredRegrSurvPEM = R6Class(
"PipeOpPredRegrSurvPEM",
inherit = mlr3pipelines::PipeOp,

public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' @param id (character(1))\cr
#' Identifier of the resulting object.
initialize = function(id = "trafopred_regrsurv_PEM") {
super$initialize(
id = id,
input = data.table(
name = c("input", "transformed_data"),
train = c("NULL", "data.table"),
predict = c("PredictionRegr", "data.table")
),
output = data.table(
name = "output",
train = "NULL",
predict = "PredictionSurv"
)
)
}
),

private = list(
.predict = function(input) {
pred = input[[1]]
data = input[[2]]
assert_true(!is.null(pred$response))
# probability of having the event (1) in each respective interval
# is the discrete-time hazard
data = cbind(data, dt_hazard = pred$response)

# From theory, convert hazards to surv as exp(-cumsum(h(t) * exp(offset)))
rows_per_id = nrow(data) / length(unique(data$id))

# If 'single_event', 'cr', 'msm')
surv = t(vapply(unique(data$id), function(unique_id) {
exp(-cumsum(data[data$id == unique_id, ][["dt_hazard"]] * exp(data[data$id == unique_id, ][["offset"]])))
}, numeric(rows_per_id)))

unique_end_times = sort(unique(data$tend))
# coerce to distribution and crank
pred_list = .surv_return(times = unique_end_times, surv = surv)

# select the real tend values by only selecting the last row of each id
# basically a slightly more complex unique()
real_tend = data$obs_times[seq_len(nrow(data)) %% rows_per_id == 0]

ids = unique(data$id)
# select last row for every id => observed times
id = PEM_status = NULL # to fix note
data = data[, .SD[.N, list(PEM_status)], by = id]

# create prediction object
p = PredictionSurv$new(
row_ids = ids,
crank = pred_list$crank, distr = pred_list$distr,
truth = Surv(real_tend, as.integer(as.character(data$PEM_status))))

list(p)
},

.train = function(input) {
self$state = list()
list(input)
}
)
)
register_pipeop("trafopred_regrsurv_PEM", PipeOpPredRegrSurvPEM)
220 changes: 220 additions & 0 deletions R/PipeOpTaskSurvRegrPEM.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
#' @title PipeOpTaskSurvRegrPEM
#' @name mlr_pipeops_trafotask_survregr_PEM
#' @template param_pipelines
#'
#' @description
#' Transform [TaskSurv] to [TaskRegr][mlr3::TaskRegr] by dividing continuous
#' time into multiple time intervals for each observation.
#' This transformation creates a new target variable `PEM_status` that indicates
#' whether an event occurred within each time interval.
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpTaskSurvRegrPEM$new()
#' mlr_pipeops$get("trafotask_survregr_PEM")
#' po("trafotask_survregr_PEM")
#' ```
#'
#' @section Input and Output Channels:
#' [PipeOpTaskSurvRegrPEM] has one input channel named "input", and two
#' output channels, one named "output" and the other "transformed_data".
#'
#' During training, the "output" is the "input" [TaskSurv] transformed to a
#' [TaskRegr][mlr3::TaskRegr].
#' The target column is named `"PEM_status"` and indicates whether an event occurred
#' in each time interval.
#' An additional feature named `"tend"` contains the end time point of each interval.
#' Lastly, the "output" task has an offset column `"offset"`.
#' The "transformed_data" is an empty [data.table][data.table::data.table].
#'
#' During prediction, the "input" [TaskSurv] is transformed to the "output"
#' [TaskRegr][mlr3::TaskRegr] with `"PEM_status"` as target and the `"tend"`
#' as well as `"offset"` feature included.
#' The "transformed_data" is a [data.table] with columns the `"PEM_status"`
#' target of the "output" task, the `"id"` (original observation ids),
#' `"obs_times"` (observed times per `"id"`) and `"tend"` (end time of each interval).
#' This "transformed_data" is only meant to be used with the [PipeOpPredRegrSurvPEM].
#'
#' @section State:
#' The `$state` contains information about the `cut` parameter used.
#'
#' @section Parameters:
#' The parameters are
#'
#' * `cut :: numeric()`\cr
#' Split points, used to partition the data into intervals based on the `time` column.
#' If unspecified, all unique event times will be used.
#' If `cut` is a single integer, it will be interpreted as the number of equidistant
#' intervals from 0 until the maximum event time.
#' * `max_time :: numeric(1)`\cr
#' If `cut` is unspecified, this will be the last possible event time.
#' All event times after `max_time` will be administratively censored at `max_time.`
#' Needs to be greater than the minimum event time in the given task.
#'
#' @examples
#'
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3learners"), quietly = TRUE)
#' \dontrun{
#' # Update documentation to match PEM
#' library(mlr3)
#' library(mlr3learners)
#' library(mlr3pipelines)
#'
#' task = tsk("lung")
#'
#' # transform the survival task to a poisson regression task
#' # all unique event times are used as cutpoints
#' po_PEM = po("trafotask_survregr_PEM")
#' task_regr = po_PEM$train(list(task))[[1L]]
#'
#' # the end time points of the discrete time intervals
#' unique(task_regr$data(cols = "tend"))[[1L]]
#'
#' # train a classification learner
#' learner = lrn("classif.log_reg", predict_type = "prob")
#' learner$train(task_regr)
#' }
#' }
#'
#'
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
inherit = mlr3pipelines::PipeOp,

public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "trafotask_survregr_PEM") {
param_set = ps(
cut = p_uty(default = NULL),
max_time = p_dbl(0, default = NULL, special_vals = list(NULL)),
censor_code = p_int(0L),
min_events = p_int(1L),
form = p_uty(tags = 'train')
#pammtools arguments: transitions etc.
)
super$initialize(
id = id,
param_set = param_set,
input = data.table(
name = "input",
train = "TaskSurv",
predict = "TaskSurv"
),
output = data.table(
name = c("output", "transformed_data"),
train = c("TaskRegr", "data.table"),
predict = c("TaskRegr", "data.table")
)
)
}
),

private = list(
.train = function(input) {
task = input[[1L]]
assert_true(task$censtype == "right")
data = task$data()

if ("PEM_status" %in% colnames(task$data())) {
stop("\"PEM_status\" can not be a column in the input data.")
}

cut = assert_numeric(self$param_set$values$cut, null.ok = TRUE, lower = 0)
max_time = self$param_set$values$max_time

time_var = task$target_names[1]
event_var = task$target_names[2]
if (testInt(cut, lower = 1)) {
cut = seq(0, data[get(event_var) == 1, max(get(time_var))], length.out = cut + 1)
}

if (!is.null(max_time)) {
assert(max_time > data[get(event_var) == 1, min(get(time_var))],
"max_time must be greater than the minimum event time.")
}

# To-Do: Extend to a more general formulation for competing risks and msm
# Issue: We pass form (e.g. Surv(time, status) ~ .) which currently serves to correctly transform the data into ped format
# but doesn't serve any other purpose yet. For ML learners, such as xgb, the covariate structure is passed to the pipeline via rhs not form.
long_data = pammtools::as_ped(data = data, formula = self$param_set$values$form, cut = cut, max_time = max_time)
self$state$cut = attributes(long_data)$trafo_args$cut



long_data = as.data.table(long_data)
setnames(long_data, old = "ped_status", new = "PEM_status") #change to PEM

# remove some columns from `long_data`
long_data[, c("tstart", "interval") := NULL]
# keep id mapping
reps = table(long_data$id)
ids = rep(task$row_ids, times = reps)
id = NULL
long_data[, id := ids]

task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data,
target = "PEM_status")
task_PEM$set_col_roles("id", roles = "original_ids")

list(task_PEM, data.table())
},

.predict = function(input) {
task = input[[1]]
data = task$data()

# extract `cut` from `state`
cut = self$state$cut

time_var = task$target_names[1]
event_var = task$target_names[2]

max_time = max(cut)
time = data[[time_var]]
data[[time_var]] = max_time

status = data[[event_var]]
data[[event_var]] = 1


long_data = as.data.table(pammtools::as_ped(data, formula = self$param_set$values$form, cut = cut))
setnames(long_data, old = "ped_status", new = "PEM_status")

PEM_status = id = tend = obs_times = NULL # fixing global binding notes of data.table
long_data[, PEM_status := 0]
# set correct id
rows_per_id = nrow(long_data) / length(unique(long_data$id))
long_data$obs_times = rep(time, each = rows_per_id)
ids = rep(task$row_ids, each = rows_per_id)
long_data[, id := ids]

# set correct PEM_status
reps = long_data[, data.table(count = sum(tend >= obs_times)), by = id]$count
status = rep(status, times = reps)
long_data[long_data[, .I[tend >= obs_times], by = id]$V1, PEM_status := status]

# remove some columns from `long_data`
long_data[, c("tstart", "interval", "obs_times") := NULL]
task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data,
target = "PEM_status")
task_PEM$set_col_roles("id", roles = "original_ids")

# map observed times back
reps = table(long_data$id)
long_data$obs_times = rep(time, each = rows_per_id)
# subset transformed data
columns_to_keep = c("id", "obs_times", "tend", "PEM_status", "offset")
long_data = long_data[, columns_to_keep, with = FALSE]

list(task_PEM, long_data)
}
)
)

register_pipeop("trafotask_survregr_PEM", PipeOpTaskSurvRegrPEM)
3 changes: 2 additions & 1 deletion R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ register_reflections = function() {

x$task_col_roles$surv = x$task_col_roles$regr
x$task_col_roles$dens = c("feature", "target", "label", "order", "group", "weight", "stratum")
x$task_col_roles$classif = unique(c(x$task_col_roles$classif, "original_ids")) # for discrete time
x$task_col_roles$classif = unique(c(x$task_col_roles$classif, "original_ids"))# for discrete time
x$task_col_roles$regr = unique(c(x$task_col_roles$regr, "original_ids"))
x$task_properties$surv = x$task_properties$regr
x$task_properties$dens = x$task_properties$regr

Expand Down
Loading