Skip to content

Commit

Permalink
Merge pull request #70 from tidymodels/coxnet-multi-predict
Browse files Browse the repository at this point in the history
`multi_predict()` for coxnet models
  • Loading branch information
hfrick authored Jul 7, 2021
2 parents a7e537d + cf982f0 commit 253533f
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 70 deletions.
6 changes: 4 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Generated by roxygen2: do not edit by hand

S3method(fit,proportional_hazards)
S3method(multi_predict,"_coxnet")
S3method(predict,"_coxnet")
S3method(predict_raw,"_coxnet")
S3method(predict_survival,"_coxnet")
S3method(print,"_coxnet")
S3method(translate,proportional_hazards)
Expand All @@ -11,7 +13,6 @@ export(cond_inference_surv_cforest)
export(cond_inference_surv_ctree)
export(flexsurv_probs)
export(glmnet_fit_wrapper)
export(linear_pred_coxnet)
export(survival_prob_cforest)
export(survival_prob_coxnet)
export(survival_prob_cph)
Expand All @@ -29,11 +30,12 @@ importFrom(generics,fit)
importFrom(parsnip,check_final_param)
importFrom(parsnip,eval_args)
importFrom(parsnip,model_printer)
importFrom(parsnip,multi_predict)
importFrom(parsnip,new_model_spec)
importFrom(parsnip,null_value)
importFrom(parsnip,predict.model_fit)
importFrom(parsnip,predict_raw)
importFrom(parsnip,predict_survival)
importFrom(parsnip,predict_survival.model_fit)
importFrom(parsnip,set_encoding)
importFrom(parsnip,set_model_arg)
importFrom(parsnip,set_new_model)
Expand Down
9 changes: 4 additions & 5 deletions R/0_imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
#' @importFrom tibble is_tibble as_tibble tibble
#' @importFrom parsnip set_new_model new_model_spec update_dot_check null_value
#' @importFrom parsnip set_encoding set_model_arg eval_args
#' @importFrom parsnip predict.model_fit predict_survival predict_survival.model_fit
#' @importFrom parsnip predict.model_fit predict_survival
#' @importFrom parsnip translate model_printer translate.default
#' @importFrom parsnip update_engine_parameters check_final_param
#' @importFrom parsnip update_main_parameters show_call
#' @importFrom parsnip multi_predict predict_raw
#' @importFrom withr with_options
#' @importFrom stats predict approx quantile na.pass na.exclude
#' @importFrom survival strata untangle.specials
Expand All @@ -19,10 +20,8 @@
#' @importFrom generics fit

utils::globalVariables(
c("time", ".time", "object", "new_data", ".label", ".pred",
".cuts", ".id", ".pred_hazard_cumulative", ".tmp", ".pred_survival",
".pred_survival_lower", ".pred_survival_upper", "engine",
"predictor_indicators", ".strata")
c("time", ".time", "object", "new_data", ".label", ".pred", ".cuts",
".id", ".tmp", "engine", "predictor_indicators", ".strata", "group")
)

# ------------------------------------------------------------------------------
Expand Down
43 changes: 31 additions & 12 deletions R/aaa_survival_prop.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,18 @@ survival_prob_cph <- function(x, new_data, times, output = "surv", conf.int = .9
dplyr::select(-.row)
}

keep_cols <- function(x, output) {
if (output == "surv") {
x <- dplyr::select(x, .time, .pred_survival, .row)
} else if (output == "conf") {
x <- dplyr::select(x, .time, .pred_survival_lower, .pred_survival_upper,
.row)
keep_cols <- function(x, output, keep_penalty = FALSE) {
if (keep_penalty) {
cols_to_keep <- c(".row", "penalty", ".time")
} else {
x <- dplyr::select(x, .time, .pred_hazard_cumulative, .row)
cols_to_keep <- c(".row", ".time")
}
x
output_cols <- switch(output,
surv = ".pred_survival",
conf = c(".pred_survival_lower", ".pred_survival_upper"),
haz = ".pred_hazard_cumulative")
cols_to_keep <- c(cols_to_keep, output_cols)
dplyr::select(x, cols_to_keep)
}

stack_survfit <- function(x, n) {
Expand Down Expand Up @@ -167,8 +169,10 @@ cph_survival_pre <- function(new_data, object) {
#' @return A nested tibble.
#' @keywords internal
#' @export
survival_prob_coxnet <- function(object, new_data, times, output = "surv", ...) {
survival_prob_coxnet <- function(object, new_data, times, output = "surv", penalty = NULL, ...) {

output <- match.arg(output, c("surv", "haz"))
multi <- length(penalty) > 1

new_x <- parsnip::.convert_form_to_xy_new(
object$preproc$coxnet,
Expand All @@ -185,21 +189,36 @@ survival_prob_coxnet <- function(object, new_data, times, output = "surv", ...)
object$fit,
newx = new_x,
newstrata = new_strata,
s = penalty,
x = object$training_data$x,
y = object$training_data$y,
na.action = na.exclude,
...
)

res <- stack_survfit(y, nrow(new_data)) %>%
dplyr::group_nest(.row, .key = ".pred") %>%
if (multi) {
names(y) <- penalty
keep_penalty <- TRUE
stacked_survfit <-
purrr::map_dfr(y, stack_survfit, n = nrow(new_data), .id = "penalty") %>%
dplyr::mutate(penalty = as.numeric(penalty)) %>%
dplyr::group_nest(.row, penalty, .key = ".pred")
} else {
keep_penalty <- FALSE
stacked_survfit <-
stack_survfit(y, nrow(new_data)) %>%
dplyr::group_nest(.row, .key = ".pred")
}
res <-
stacked_survfit %>%
mutate(
.pred = purrr::map(.pred, ~ dplyr::bind_rows(prob_template, .x))
) %>%
tidyr::unnest(cols = c(.pred)) %>%
interpolate_km_values(times, new_strata) %>%
keep_cols(output) %>%
keep_cols(output, keep_penalty) %>%
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)

res
}
90 changes: 89 additions & 1 deletion R/proportional_hazards.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,30 @@ check_glmnet_penalty <- function(x) {
# predict_survival.model_fit()
# survival_prob_coxnet()

# glmnet call stack for censored regression using `multi_predict(type = "linear_pred")` when object has
# classes "_coxnet" and "model_fit":
#
# multi_predict()
# multi_predict._coxnet(penalty = NULL)
# predict._coxnet(multi = TRUE) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_raw()
# predict_raw._coxnet()
# predict_raw.model_fit(opts = list(s = penalty))
# predict.coxnet()

# glmnet call stack for censored regression using `multi_predict(type = "survival")` when object has
# classes "_coxnet" and "model_fit":
#
# multi_predict()
# multi_predict._coxnet(penalty = NULL)
# predict._coxnet(multi = TRUE) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_survival()
# predict_survival._coxnet()
# predict_survival.model_fit()
# survival_prob_coxnet()

#' @export
predict._coxnet <-
function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
Expand All @@ -203,7 +227,71 @@ predict_survival._coxnet <- function(object, new_data, ...) {
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")

object$spec <- eval_args(object$spec)
predict_survival.model_fit(object, new_data = new_data, ...)
NextMethod()
}

#' @export
predict_raw._coxnet <- function(object, new_data, opts = list(), ...) {
  # Raw-prediction method for fitted models of class "_coxnet".
  #
  # Arguments:
  #   object    Fitted model object; its `spec` is evaluated with eval_args()
  #             before the penalty is read from it.
  #   new_data  Data to predict on; handed on unchanged to the next method.
  #   opts      List of extra prediction options. Its `s` element is always
  #             overwritten with the penalty stored in the model spec.
  #   ...       Handed on to the next method. Supplying `newdata` is an error.
  #
  # Returns whatever the next `predict_raw()` method returns.
  if (any(names(enquos(...)) == "newdata")) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  object$spec <- eval_args(object$spec)
  opts$s <- object$spec$args$penalty

  # NextMethod() only forwards arguments that were present in the call to
  # this method. When `opts` is left at its default it would not be passed
  # on at all, and the penalty stored in `opts$s` would be lost. Passing it
  # by name replaces the existing argument or appends it as needed.
  NextMethod(opts = opts)
}

#' @export
multi_predict._coxnet <- function(object,
                                  new_data,
                                  type = NULL,
                                  penalty = NULL,
                                  ...) {
  # Predictions for several penalty values from one fitted "_coxnet" model.
  #
  # Arguments:
  #   object    Fitted model object; its `spec` is evaluated with eval_args().
  #   new_data  Data to predict on.
  #   type      Prediction type. "linear_pred" is handled by
  #             multi_predict_coxnet_linear_pred(); any other value,
  #             including NULL, is passed on to predict().
  #   penalty   Penalty values to predict at. When NULL, the penalty from the
  #             model spec is used if present, otherwise `object$fit$lambda`.
  #   ...       For "linear_pred" these are collected into a list and passed
  #             as `opts`; otherwise they are forwarded to predict().
  #             Supplying `newdata` is an error.
  #
  # Returns the value produced by the selected prediction path.
  if (any(names(enquos(...)) == "newdata")) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  dots <- list(...)

  object$spec <- eval_args(object$spec)

  if (is.null(penalty)) {
    # See discussion in https://github.com/tidymodels/parsnip/issues/195
    if (!is.null(object$spec$args$penalty)) {
      penalty <- object$spec$args$penalty
    } else {
      penalty <- object$fit$lambda
    }
  }

  # `type` defaults to NULL, and `NULL == "linear_pred"` is logical(0), which
  # makes `if ()` fail with "argument is of length zero". identical() always
  # yields a single TRUE/FALSE, so a NULL type reaches predict() instead.
  if (identical(type, "linear_pred")) {
    pred <- multi_predict_coxnet_linear_pred(object, new_data = new_data,
                                             opts = dots, penalty = penalty)
  } else {
    pred <- predict(object, new_data = new_data, type = type, ...,
                    penalty = penalty, multi = TRUE)
  }

  pred
}

multi_predict_coxnet_linear_pred <- function(object, new_data, opts, penalty) {
  # Linear-predictor predictions at several penalty values, as a nested tibble.
  #
  # Arguments:
  #   object    Fitted model object passed to predict().
  #   new_data  Data to predict on.
  #   opts      List of options passed to predict() as `opts`.
  #   penalty   Penalty values; paired position-wise with the columns of the
  #             raw prediction.
  #
  # Returns a tibble with a single list-column `.pred`. Each element is a
  # tibble with the columns `penalty` and `.pred_linear_pred` for one row of
  # the raw prediction, ordered by penalty.
  raw_pred <- predict(object, new_data = new_data, type = "raw",
                      opts = opts, penalty = penalty, multi = TRUE)

  # NOTE(review): assumes the raw prediction is matrix-like with one named
  # column per penalty value, in the same order as `penalty` - confirm.
  param_key <- tibble(group = colnames(raw_pred), penalty = penalty)

  # One row per (prediction row, penalty column) pair.
  long_pred <- raw_pred %>%
    as_tibble() %>%
    dplyr::mutate(.row = seq_len(nrow(raw_pred))) %>%
    tidyr::pivot_longer(
      -.row,
      names_to = "group",
      values_to = ".pred_linear_pred"
    )

  # The pipeline is the last expression so the result is returned visibly;
  # ending on an assignment would return it invisibly.
  dplyr::inner_join(param_key, long_pred, by = "group") %>%
    dplyr::select(-group) %>%
    dplyr::arrange(.row, penalty) %>%
    tidyr::nest(.pred = c(-.row)) %>%
    dplyr::select(-.row)
}

#' @export
Expand Down
40 changes: 24 additions & 16 deletions R/proportional_hazards_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ make_proportional_hazards_glmnet <- function() {
mode = "censored regression",
type = "linear_pred",
value = list(
pre = NULL,
pre = coxnet_predict_pre,
post = organize_glmnet_pred,
func = c(fun = "linear_pred_coxnet"),
func = c(fun = "predict"),
args =
list(
object = expr(object),
new_data = expr(new_data),
object = expr(object$fit),
newx = expr(new_data),
type = "link",
s = expr(object$spec$args$penalty)
)
Expand All @@ -181,10 +181,27 @@ make_proportional_hazards_glmnet <- function() {
object = expr(object),
new_data = expr(new_data),
times = expr(time),
s = expr(object$spec$args$penalty)
penalty = expr(object$spec$args$penalty)
)
)
)

# Register the "raw" prediction type for proportional_hazards models fitted
# with the glmnet engine in "censored regression" mode.
# `coxnet_predict_pre` is applied as the pre-processing step, there is no
# post-processing (`post = NULL`), and the prediction itself is a call to
# `predict()` on the underlying fit (`object$fit`) with the data passed as
# `newx`.
parsnip::set_pred(
model = "proportional_hazards",
eng = "glmnet",
mode = "censored regression",
type = "raw",
value = list(
pre = coxnet_predict_pre,
post = NULL,
func = c(fun = "predict"),
args =
list(object = expr(object$fit),
newx = expr(new_data)
)
)
)
)

}


Expand Down Expand Up @@ -349,18 +366,9 @@ check_dots_coxnet <- function(x) {
invisible(NULL)
}

#' A wrapper for predict() with coxnet models
#' @param object A fitted `_coxnet` object.
#' @param new_data Data for prediction.
#' @param ... Options to pass to [glmnet::predict.coxnet()].
#' @return A matrix.
#' @keywords internal
#' @export
linear_pred_coxnet <- function(object, new_data, ...) {
new_x <- parsnip::.convert_form_to_xy_new(
coxnet_predict_pre <- function(new_data, object) {
parsnip::.convert_form_to_xy_new(
object$preproc$coxnet,
new_data,
composition = "matrix")$x

predict(object$fit, newx = new_x, ...)
}
22 changes: 0 additions & 22 deletions man/linear_pred_coxnet.Rd

This file was deleted.

9 changes: 8 additions & 1 deletion man/survival_prob_coxnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 253533f

Please sign in to comment.