Skip to content

Commit

Permalink
draft task conversion pipeop
Browse files Browse the repository at this point in the history
  • Loading branch information
studener committed Sep 24, 2024
1 parent a3318b7 commit 6eb5a6e
Showing 1 changed file with 213 additions and 0 deletions.
213 changes: 213 additions & 0 deletions R/PipeOpTaskSurvRegrPEM.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
#' @title PipeOpTaskSurvRegrPEM
#' @name mlr_pipeops_trafotask_survregr_PEM
#' @template param_pipelines
#'
#' @description
#' Transform [TaskSurv] to [TaskRegr][mlr3::TaskRegr] by dividing continuous
#' time into multiple time intervals for each observation.
#' This transformation creates a new target variable `disc_status` that indicates
#' whether an event occurred within each time interval.
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpTaskSurvRegrPEM$new()
#' mlr_pipeops$get("trafotask_survregr_PEM")
#' po("trafotask_survregr_PEM")
#' ```
#'
#' @section Input and Output Channels:
#' [PipeOpTaskSurvClassifDiscTime] has one input channel named "input", and two
#' output channels, one named "output" and the other "transformed_data".
#'
#' During training, the "output" is the "input" [TaskSurv] transformed to a
#' [TaskRegr][mlr3::TaskRegr].
#' The target column is named `"disc_status"` and indicates whether an event occurred
#' in each time interval.
#' An additional feature named `"tend"` contains the end time point of each interval.
#' Lastly, the "output" task has an offset column `"offset"`.
#' The "transformed_data" is an empty [data.table][data.table::data.table].
#'
#' During prediction, the "input" [TaskSurv] is transformed to the "output"
#' [TaskRegr][mlr3::TaskRegr] with `"disc_status"` as target and the `"tend"`
#' as well as `"offset"` feature included.
#' The "transformed_data" is a [data.table] with columns the `"disc_status"`
#' target of the "output" task, the `"id"` (original observation ids),
#' `"obs_times"` (observed times per `"id"`) and `"tend"` (end time of each interval).
#' This "transformed_data" is only meant to be used with the [PipeOpPredRegrSurvPEM].
#'
#' @section State:
#' The `$state` contains information about the `cut` parameter used.
#'
#' @section Parameters:
#' The parameters are
#'
#' * `cut :: numeric()`\cr
#' Split points, used to partition the data into intervals based on the `time` column.
#' If unspecified, all unique event times will be used.
#' If `cut` is a single integer, it will be interpreted as the number of equidistant
#' intervals from 0 until the maximum event time.
#' * `max_time :: numeric(1)`\cr
#' If `cut` is unspecified, this will be the last possible event time.
#' All event times after `max_time` will be administratively censored at `max_time.`
#' Needs to be greater than the minimum event time in the given task.
#'
#' @examples
#'
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3learners"), quietly = TRUE)
#' \dontrun{
#' library(mlr3)
#' library(mlr3learners)
#' library(mlr3pipelines)
#'
#' task = tsk("lung")
#'
#' # transform the survival task to a poisson regression task
#' # all unique event times are used as cutpoints
#' po_disc = po("trafotask_survregr_PEM")
#' task_regr = po_disc$train(list(task))[[1L]]
#'
#' # the end time points of the discrete time intervals
#' unique(task_regr$data(cols = "tend"))[[1L]]
#'
#' # train a classification learner
#' learner = lrn("classif.log_reg", predict_type = "prob")
#' learner$train(task_regr)
#' }
#' }
#'
#'
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
inherit = mlr3pipelines::PipeOp,

public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "trafotask_survregr_PEM") {
param_set = ps(
cut = p_uty(default = NULL),
max_time = p_dbl(0, default = NULL, special_vals = list(NULL))
)
super$initialize(
id = id,
param_set = param_set,
input = data.table(
name = "input",
train = "TaskSurv",
predict = "TaskSurv"
),
output = data.table(
name = c("output", "transformed_data"),
train = c("TaskRegr", "data.table"),
predict = c("TaskRegr", "data.table")
)
)
}
),

private = list(
.train = function(input) {
task = input[[1L]]
assert_true(task$censtype == "right")
data = task$data()

if ("disc_status" %in% colnames(task$data())) {
stop("\"disc_status\" can not be a column in the input data.")
}

cut = assert_numeric(self$param_set$values$cut, null.ok = TRUE, lower = 0)
max_time = self$param_set$values$max_time

time_var = task$target_names[1]
event_var = task$target_names[2]
if (testInt(cut, lower = 1)) {
cut = seq(0, data[get(event_var) == 1, max(get(time_var))], length.out = cut + 1)
}

if (!is.null(max_time)) {
assert(max_time > data[get(event_var) == 1, min(get(time_var))],
"max_time must be greater than the minimum event time.")
}

form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".")

long_data = pammtools::as_ped(data = data, formula = form, cut = cut, max_time = max_time)
self$state$cut = attributes(long_data)$trafo_args$cut
long_data = as.data.table(long_data)
setnames(long_data, old = "ped_status", new = "disc_status")

# remove some columns from `long_data`
long_data[, c("tstart", "interval") := NULL]
# keep id mapping
reps = table(long_data$id)
ids = rep(task$row_ids, times = reps)
id = NULL
long_data[, id := ids]

task_disc = TaskRegr$new(paste0(task$id, "_disc"), long_data,
target = "disc_status")
task_disc$set_col_roles("id", roles = "name")

list(task_disc, data.table())
},

.predict = function(input) {
task = input[[1]]
data = task$data()

# extract `cut` from `state`
cut = self$state$cut

time_var = task$target_names[1]
event_var = task$target_names[2]

max_time = max(cut)
time = data[[time_var]]
data[[time_var]] = max_time

status = data[[event_var]]
data[[event_var]] = 1

# update form
form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".")

long_data = as.data.table(pammtools::as_ped(data, formula = form, cut = cut))
setnames(long_data, old = "ped_status", new = "disc_status")

disc_status = id = tend = obs_times = NULL # fixing global binding notes of data.table
long_data[, disc_status := 0]
# set correct id
rows_per_id = nrow(long_data) / length(unique(long_data$id))
long_data$obs_times = rep(time, each = rows_per_id)
ids = rep(task$row_ids, each = rows_per_id)
long_data[, id := ids]

# set correct disc_status
reps = long_data[, data.table(count = sum(tend >= obs_times)), by = id]$count
status = rep(status, times = reps)
long_data[long_data[, .I[tend >= obs_times], by = id]$V1, disc_status := status]

# remove some columns from `long_data`
long_data[, c("tstart", "interval", "obs_times") := NULL]
task_disc = TaskRegr$new(paste0(task$id, "_disc"), long_data,
target = "disc_status")
task_disc$set_col_roles("id", roles = "name")

# map observed times back
reps = table(long_data$id)
long_data$obs_times = rep(time, each = rows_per_id)
# subset transformed data
columns_to_keep = c("id", "obs_times", "tend", "disc_status", "offset")
long_data = long_data[, columns_to_keep, with = FALSE]

list(task_disc, long_data)
}
)
)

register_pipeop("trafotask_survregr_PEM", PipeOpTaskSurvRegrPEM)

0 comments on commit 6eb5a6e

Please sign in to comment.