Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi_predict() for coxnet models #70

Merged
merged 15 commits into from
Jul 7, 2021
Merged
6 changes: 5 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Generated by roxygen2: do not edit by hand

S3method(fit,proportional_hazards)
S3method(multi_predict,"_coxnet")
S3method(predict,"_coxnet")
S3method(predict_raw,"_coxnet")
S3method(predict_survival,"_coxnet")
S3method(print,"_coxnet")
S3method(translate,proportional_hazards)
Expand All @@ -11,7 +13,6 @@ export(cond_inference_surv_cforest)
export(cond_inference_surv_ctree)
export(flexsurv_probs)
export(glmnet_fit_wrapper)
export(linear_pred_coxnet)
export(survival_prob_cforest)
export(survival_prob_coxnet)
export(survival_prob_cph)
Expand All @@ -29,9 +30,12 @@ importFrom(generics,fit)
importFrom(parsnip,check_final_param)
importFrom(parsnip,eval_args)
importFrom(parsnip,model_printer)
importFrom(parsnip,multi_predict)
importFrom(parsnip,new_model_spec)
importFrom(parsnip,null_value)
importFrom(parsnip,predict.model_fit)
importFrom(parsnip,predict_raw)
importFrom(parsnip,predict_raw.model_fit)
importFrom(parsnip,predict_survival)
importFrom(parsnip,predict_survival.model_fit)
importFrom(parsnip,set_encoding)
Expand Down
7 changes: 3 additions & 4 deletions R/0_imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#' @importFrom parsnip translate model_printer translate.default
#' @importFrom parsnip update_engine_parameters check_final_param
#' @importFrom parsnip update_main_parameters show_call
#' @importFrom parsnip multi_predict predict_raw predict_raw.model_fit
#' @importFrom withr with_options
#' @importFrom stats predict approx quantile na.pass na.exclude
#' @importFrom survival strata untangle.specials
Expand All @@ -19,10 +20,8 @@
#' @importFrom generics fit

utils::globalVariables(
c("time", ".time", "object", "new_data", ".label", ".pred",
".cuts", ".id", ".pred_hazard_cumulative", ".tmp", ".pred_survival",
".pred_survival_lower", ".pred_survival_upper", "engine",
"predictor_indicators", ".strata")
c("time", ".time", "object", "new_data", ".label", ".pred", ".cuts",
".id", ".tmp", "engine", "predictor_indicators", ".strata", "group")
)

# ------------------------------------------------------------------------------
Expand Down
43 changes: 31 additions & 12 deletions R/aaa_survival_prop.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,18 @@ survival_prob_cph <- function(x, new_data, times, output = "surv", conf.int = .9
dplyr::select(-.row)
}

keep_cols <- function(x, output) {
if (output == "surv") {
x <- dplyr::select(x, .time, .pred_survival, .row)
} else if (output == "conf") {
x <- dplyr::select(x, .time, .pred_survival_lower, .pred_survival_upper,
.row)
keep_cols <- function(x, output, keep_penalty = FALSE) {
if (keep_penalty) {
cols_to_keep <- c(".row", "penalty", ".time")
} else {
x <- dplyr::select(x, .time, .pred_hazard_cumulative, .row)
cols_to_keep <- c(".row", ".time")
}
x
output_cols <- switch(output,
surv = ".pred_survival",
conf = c("pred_survival_lower", ".pred_survival_upper"),
haz = ".pred_hazard_cumulative")
cols_to_keep <- c(cols_to_keep, output_cols)
dplyr::select(x, cols_to_keep)
}

stack_survfit <- function(x, n) {
Expand Down Expand Up @@ -167,8 +169,10 @@ cph_survival_pre <- function(new_data, object) {
#' @return A nested tibble.
#' @keywords internal
#' @export
survival_prob_coxnet <- function(object, new_data, times, output = "surv", ...) {
survival_prob_coxnet <- function(object, new_data, times, output = "surv", penalty = NULL, ...) {

output <- match.arg(output, c("surv", "haz"))
multi <- length(penalty) > 1

new_x <- parsnip::.convert_form_to_xy_new(
object$preproc$coxnet,
Expand All @@ -185,21 +189,36 @@ survival_prob_coxnet <- function(object, new_data, times, output = "surv", ...)
object$fit,
newx = new_x,
newstrata = new_strata,
s = penalty,
x = object$training_data$x,
y = object$training_data$y,
na.action = na.exclude,
...
)

res <- stack_survfit(y, nrow(new_data)) %>%
dplyr::group_nest(.row, .key = ".pred") %>%
if (multi) {
names(y) <- penalty
keep_penalty <- TRUE
stacked_survfit <-
purrr::map_dfr(y, stack_survfit, n = nrow(new_data), .id = "penalty") %>%
dplyr::mutate(penalty = as.numeric(penalty)) %>%
dplyr::group_nest(.row, penalty, .key = ".pred")
} else {
keep_penalty <- FALSE
stacked_survfit <-
stack_survfit(y, nrow(new_data)) %>%
dplyr::group_nest(.row, .key = ".pred")
}
res <-
stacked_survfit %>%
mutate(
.pred = purrr::map(.pred, ~ dplyr::bind_rows(prob_template, .x))
) %>%
tidyr::unnest(cols = c(.pred)) %>%
interpolate_km_values(times, new_strata) %>%
keep_cols(output) %>%
keep_cols(output, keep_penalty) %>%
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)

res
}
83 changes: 83 additions & 0 deletions R/proportional_hazards.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,30 @@ check_glmnet_penalty <- function(x) {
# predict_survival.model_fit()
# survival_prob_coxnet()

# glmnet call stack for censored regression using `multi_predict(type = "linear_pred")` when object has
# classes "_coxnet" and "model_fit":
#
# multi_predict()
# multi_predict._coxnet(penalty = NULL)
# predict._coxnet(multi = TRUE) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_raw()
# predict_raw._coxnet()
# predict_raw.model_fit(opts = list(s = penalty))
# predict.coxnet()

# glmnet call stack for censored regression using `multi_predict(type = "survival")` when object has
# classes "_coxnet" and "model_fit":
#
# multi_predict()
# multi_predict._coxnet(penalty = NULL)
# predict._coxnet(multi = TRUE) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_survival()
# predict_survival._coxnet()
# predict_survival.model_fit()
# survival_prob_coxnet()

#' @export
predict._coxnet <-
function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
Expand All @@ -206,6 +230,65 @@ predict_survival._coxnet <- function(object, new_data, ...) {
predict_survival.model_fit(object, new_data = new_data, ...)
}

#' @export
predict_raw._coxnet <- function(object, new_data, opts = list(), ...) {
if (any(names(enquos(...)) == "newdata"))
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")

object$spec <- eval_args(object$spec)
opts$s <- object$spec$args$penalty
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
}

#' @export
multi_predict._coxnet <- function(object,
new_data,
type = NULL,
penalty = NULL,
...) {
if (any(names(enquos(...)) == "newdata"))
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")

dots <- list(...)

object$spec <- eval_args(object$spec)

if (is.null(penalty)) {
# See discussion in https://github.com/tidymodels/parsnip/issues/195
if (!is.null(object$spec$args$penalty)) {
penalty <- object$spec$args$penalty
} else {
penalty <- object$fit$lambda
}
}

if (type != "linear_pred"){
pred <- predict(object, new_data = new_data, type = type, ...,
penalty = penalty, multi = TRUE)
} else {
pred <- predict(object, new_data = new_data, type = "raw",
opts = dots, penalty = penalty, multi = TRUE)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this else statement is pretty long, you might consider making two helpers for the specific predict types and doing:

switch(
  type,
  linear_pred = multi_predict_coxnet_linear_pred(...),
  survival = multi_predict_coxnet_survival(...),
  abort("Internal error: Unknown `type`.")
)

That would also make it easier to extend if we get more types

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved the code for linear_pred into its helper function 👍 Regarding the switch statement: I think the other types may require a similar structure as the survival probabilities so hopefully the rest is just predict(type = type).


# post-processing into nested tibble
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are currently predict methods for types "linear_pred" and "survival". For the "linear_pred"-type predictions, this follows what parsnip does for linear_reg() with a glmnet engine. For the survival probabilities, we use the survival curves from survfit() and have the wrapper survival_prob_coxnet() already so I extended that one to be able to deal with a vector of penalties. This also allows for convenient minimal nesting where we only group according to strata, see also #47 and #63.

param_key <- tibble(group = colnames(pred), penalty = penalty)
pred <- pred %>%
as_tibble() %>%
dplyr::mutate(.row = seq_len(nrow(pred))) %>%
tidyr::pivot_longer(
- .row,
names_to = "group",
values_to = ".pred_linear_pred"
)
pred <- dplyr::full_join(param_key, pred, by = "group") %>%
dplyr::select(-group) %>%
dplyr::arrange(.row, penalty) %>%
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)
}

pred
}

#' @export
print._coxnet <- function(x, ...) {
cat("parsnip model object\n\n")
Expand Down
40 changes: 24 additions & 16 deletions R/proportional_hazards_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ make_proportional_hazards_glmnet <- function() {
mode = "censored regression",
type = "linear_pred",
value = list(
pre = NULL,
pre = coxnet_predict_pre,
post = organize_glmnet_pred,
func = c(fun = "linear_pred_coxnet"),
func = c(fun = "predict"),
args =
list(
object = expr(object),
new_data = expr(new_data),
object = expr(object$fit),
newx = expr(new_data),
type = "link",
s = expr(object$spec$args$penalty)
)
Expand All @@ -181,10 +181,27 @@ make_proportional_hazards_glmnet <- function() {
object = expr(object),
new_data = expr(new_data),
times = expr(time),
s = expr(object$spec$args$penalty)
penalty = expr(object$spec$args$penalty)
)
)
)

parsnip::set_pred(
model = "proportional_hazards",
eng = "glmnet",
mode = "censored regression",
type = "raw",
value = list(
pre = coxnet_predict_pre,
post = NULL,
func = c(fun = "predict"),
args =
list(object = expr(object$fit),
newx = expr(new_data)
)
)
)

}


Expand Down Expand Up @@ -349,18 +366,9 @@ check_dots_coxnet <- function(x) {
invisible(NULL)
}

#' A wrapper for predict() with coxnet models
#' @param object A fitted `_coxnet` object.
#' @param new_data Data for prediction.
#' @param ... Options to pass to [glmnet::predict.coxnet()].
#' @return A matrix.
#' @keywords internal
#' @export
linear_pred_coxnet <- function(object, new_data, ...) {
new_x <- parsnip::.convert_form_to_xy_new(
coxnet_predict_pre <- function(new_data, object) {
parsnip::.convert_form_to_xy_new(
object$preproc$coxnet,
new_data,
composition = "matrix")$x

predict(object$fit, newx = new_x, ...)
}
22 changes: 0 additions & 22 deletions man/linear_pred_coxnet.Rd

This file was deleted.

9 changes: 8 additions & 1 deletion man/survival_prob_coxnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading