Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi_predict() for coxnet models #70

Merged
merged 15 commits into from
Jul 7, 2021
Merged
6 changes: 4 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Generated by roxygen2: do not edit by hand

S3method(fit,proportional_hazards)
S3method(multi_predict,"_coxnet")
S3method(predict,"_coxnet")
S3method(predict_raw,"_coxnet")
S3method(predict_survival,"_coxnet")
S3method(print,"_coxnet")
S3method(translate,proportional_hazards)
Expand All @@ -11,7 +13,6 @@ export(cond_inference_surv_cforest)
export(cond_inference_surv_ctree)
export(flexsurv_probs)
export(glmnet_fit_wrapper)
export(linear_pred_coxnet)
export(survival_prob_cforest)
export(survival_prob_coxnet)
export(survival_prob_cph)
Expand All @@ -29,11 +30,12 @@ importFrom(generics,fit)
importFrom(parsnip,check_final_param)
importFrom(parsnip,eval_args)
importFrom(parsnip,model_printer)
importFrom(parsnip,multi_predict)
importFrom(parsnip,new_model_spec)
importFrom(parsnip,null_value)
importFrom(parsnip,predict.model_fit)
importFrom(parsnip,predict_raw)
importFrom(parsnip,predict_survival)
importFrom(parsnip,predict_survival.model_fit)
importFrom(parsnip,set_encoding)
importFrom(parsnip,set_model_arg)
importFrom(parsnip,set_new_model)
Expand Down
9 changes: 4 additions & 5 deletions R/0_imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
#' @importFrom tibble is_tibble as_tibble tibble
#' @importFrom parsnip set_new_model new_model_spec update_dot_check null_value
#' @importFrom parsnip set_encoding set_model_arg eval_args
#' @importFrom parsnip predict.model_fit predict_survival predict_survival.model_fit
#' @importFrom parsnip predict.model_fit predict_survival
#' @importFrom parsnip translate model_printer translate.default
#' @importFrom parsnip update_engine_parameters check_final_param
#' @importFrom parsnip update_main_parameters show_call
#' @importFrom parsnip multi_predict predict_raw
#' @importFrom withr with_options
#' @importFrom stats predict approx quantile na.pass na.exclude
#' @importFrom survival strata untangle.specials
Expand All @@ -19,10 +20,8 @@
#' @importFrom generics fit

utils::globalVariables(
c("time", ".time", "object", "new_data", ".label", ".pred",
".cuts", ".id", ".pred_hazard_cumulative", ".tmp", ".pred_survival",
".pred_survival_lower", ".pred_survival_upper", "engine",
"predictor_indicators", ".strata")
c("time", ".time", "object", "new_data", ".label", ".pred", ".cuts",
".id", ".tmp", "engine", "predictor_indicators", ".strata", "group")
)

# ------------------------------------------------------------------------------
Expand Down
43 changes: 31 additions & 12 deletions R/aaa_survival_prop.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,18 @@ survival_prob_cph <- function(x, new_data, times, output = "surv", conf.int = .9
dplyr::select(-.row)
}

keep_cols <- function(x, output) {
if (output == "surv") {
x <- dplyr::select(x, .time, .pred_survival, .row)
} else if (output == "conf") {
x <- dplyr::select(x, .time, .pred_survival_lower, .pred_survival_upper,
.row)
keep_cols <- function(x, output, keep_penalty = FALSE) {
if (keep_penalty) {
cols_to_keep <- c(".row", "penalty", ".time")
} else {
x <- dplyr::select(x, .time, .pred_hazard_cumulative, .row)
cols_to_keep <- c(".row", ".time")
}
x
output_cols <- switch(output,
surv = ".pred_survival",
conf = c(".pred_survival_lower", ".pred_survival_upper"),
haz = ".pred_hazard_cumulative")
cols_to_keep <- c(cols_to_keep, output_cols)
dplyr::select(x, cols_to_keep)
}

stack_survfit <- function(x, n) {
Expand Down Expand Up @@ -167,8 +169,10 @@ cph_survival_pre <- function(new_data, object) {
#' @return A nested tibble.
#' @keywords internal
#' @export
survival_prob_coxnet <- function(object, new_data, times, output = "surv", ...) {
survival_prob_coxnet <- function(object, new_data, times, output = "surv", penalty = NULL, ...) {

output <- match.arg(output, c("surv", "haz"))
multi <- length(penalty) > 1

new_x <- parsnip::.convert_form_to_xy_new(
object$preproc$coxnet,
Expand All @@ -185,21 +189,36 @@ survival_prob_coxnet <- function(object, new_data, times, output = "surv", ...)
object$fit,
newx = new_x,
newstrata = new_strata,
s = penalty,
x = object$training_data$x,
y = object$training_data$y,
na.action = na.exclude,
...
)

res <- stack_survfit(y, nrow(new_data)) %>%
dplyr::group_nest(.row, .key = ".pred") %>%
if (multi) {
names(y) <- penalty
keep_penalty <- TRUE
stacked_survfit <-
purrr::map_dfr(y, stack_survfit, n = nrow(new_data), .id = "penalty") %>%
dplyr::mutate(penalty = as.numeric(penalty)) %>%
dplyr::group_nest(.row, penalty, .key = ".pred")
} else {
keep_penalty <- FALSE
stacked_survfit <-
stack_survfit(y, nrow(new_data)) %>%
dplyr::group_nest(.row, .key = ".pred")
}
res <-
stacked_survfit %>%
mutate(
.pred = purrr::map(.pred, ~ dplyr::bind_rows(prob_template, .x))
) %>%
tidyr::unnest(cols = c(.pred)) %>%
interpolate_km_values(times, new_strata) %>%
keep_cols(output) %>%
keep_cols(output, keep_penalty) %>%
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)

res
}
90 changes: 89 additions & 1 deletion R/proportional_hazards.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,30 @@ check_glmnet_penalty <- function(x) {
# predict_survival.model_fit()
# survival_prob_coxnet()

# glmnet call stack for censored regression using `multi_predict(type = "linear_pred")` when object has
# classes "_coxnet" and "model_fit":
#
# multi_predict()
# multi_predict._coxnet(penalty = NULL)
# predict._coxnet(multi = TRUE) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_raw()
# predict_raw._coxnet()
# predict_raw.model_fit(opts = list(s = penalty))
# predict.coxnet()

# glmnet call stack for censored regression using `multi_predict(type = "survival")` when object has
# classes "_coxnet" and "model_fit":
#
# multi_predict()
# multi_predict._coxnet(penalty = NULL)
# predict._coxnet(multi = TRUE) <-- checks and sets penalty
# predict.model_fit() <-- checks for extra vars in ...
# predict_survival()
# predict_survival._coxnet()
# predict_survival.model_fit()
# survival_prob_coxnet()

#' @export
predict._coxnet <-
function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
Expand All @@ -203,7 +227,71 @@ predict_survival._coxnet <- function(object, new_data, ...) {
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")

object$spec <- eval_args(object$spec)
predict_survival.model_fit(object, new_data = new_data, ...)
NextMethod()
}

#' @export
predict_raw._coxnet <- function(object, new_data, opts = list(), ...) {
  # Raw-prediction method for `_coxnet` fits.
  #
  # object:   the fitted model; its `spec` is evaluated via `eval_args()`.
  # new_data: data to predict on (forwarded untouched to the next method).
  # opts:     list of options for the next method; element `s` is overwritten
  #           with the penalty stored in the model spec.
  # ...:      forwarded to the next method; a `newdata` argument is rejected.

  # Guard against the base-R spelling `newdata` being passed through the dots.
  dot_names <- names(enquos(...))
  if ("newdata" %in% dot_names) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  # `object` and `opts` are modified in place (not copied to new names) so
  # that `NextMethod()` forwards the updated values.
  object$spec <- eval_args(object$spec)
  opts[["s"]] <- object$spec$args$penalty

  NextMethod()
}

#' @export
multi_predict._coxnet <- function(object,
                                  new_data,
                                  type = NULL,
                                  penalty = NULL,
                                  ...) {
  # Predict from a `_coxnet` fit at several penalty values at once.
  #
  # object:   the fitted model; its `spec` is evaluated via `eval_args()`.
  # new_data: data to predict on.
  # type:     prediction type. "linear_pred" is handled by
  #           `multi_predict_coxnet_linear_pred()`; anything else, including
  #           the default `NULL`, is passed on to `predict()`.
  # penalty:  penalty value(s). When `NULL`, the penalty from the model spec
  #           is used if present, otherwise `object$fit$lambda`.
  # ...:      further options; a `newdata` argument is rejected.
  #
  # Returns whatever the chosen prediction helper returns.

  # Guard against the base-R spelling `newdata` being passed through the dots.
  if (any(names(enquos(...)) == "newdata")) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  dots <- list(...)

  object$spec <- eval_args(object$spec)

  if (is.null(penalty)) {
    # See discussion in https://github.com/tidymodels/parsnip/issues/195
    if (!is.null(object$spec$args$penalty)) {
      penalty <- object$spec$args$penalty
    } else {
      penalty <- object$fit$lambda
    }
  }

  # `identical()` instead of `==`: `type` defaults to NULL, and
  # `if (NULL == "linear_pred")` fails with "argument is of length zero".
  # A NULL (or otherwise non-matching) type is handed to `predict()`.
  # NOTE(review): presumably `predict()` resolves/validates a NULL type -
  # confirm against `predict._coxnet()` / parsnip.
  if (identical(type, "linear_pred")) {
    pred <- multi_predict_coxnet_linear_pred(
      object,
      new_data = new_data,
      opts = dots,
      penalty = penalty
    )
  } else {
    pred <- predict(
      object,
      new_data = new_data,
      type = type,
      ...,
      penalty = penalty,
      multi = TRUE
    )
  }

  pred
}

multi_predict_coxnet_linear_pred <- function(object, new_data, opts, penalty) {
  # Linear-predictor predictions for several penalty values, reshaped into a
  # tibble with one row per observation and a nested `.pred` list-column
  # holding `penalty` and `.pred_linear_pred`.
  #
  # object:   the fitted `_coxnet` model.
  # new_data: data to predict on.
  # opts:     list of options forwarded to `predict()` as `opts`.
  # penalty:  penalty values; paired positionally with the columns of the
  #           raw prediction result.
  #
  # Returns the nested tibble visibly (the function previously ended in an
  # assignment, so its value was returned invisibly).

  # NOTE(review): assumes the raw prediction is a matrix-like object with one
  # named column per penalty value, in the same order as `penalty` - confirm.
  pred <- predict(object, new_data = new_data, type = "raw",
                  opts = opts, penalty = penalty, multi = TRUE)
  n_obs <- nrow(pred)

  # Lookup table mapping the prediction column names to penalty values.
  param_key <- tibble(group = colnames(pred), penalty = penalty)

  # Long format: one row per (observation, prediction column).
  pred_long <- pred %>%
    as_tibble() %>%
    dplyr::mutate(.row = seq_len(n_obs)) %>%
    tidyr::pivot_longer(
      -.row,
      names_to = "group",
      values_to = ".pred_linear_pred"
    )

  # Attach penalties, order within observation, and nest per observation.
  dplyr::inner_join(param_key, pred_long, by = "group") %>%
    dplyr::select(-group) %>%
    dplyr::arrange(.row, penalty) %>%
    tidyr::nest(.pred = c(-.row)) %>%
    dplyr::select(-.row)
}

#' @export
Expand Down
40 changes: 24 additions & 16 deletions R/proportional_hazards_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ make_proportional_hazards_glmnet <- function() {
mode = "censored regression",
type = "linear_pred",
value = list(
pre = NULL,
pre = coxnet_predict_pre,
post = organize_glmnet_pred,
func = c(fun = "linear_pred_coxnet"),
func = c(fun = "predict"),
args =
list(
object = expr(object),
new_data = expr(new_data),
object = expr(object$fit),
newx = expr(new_data),
type = "link",
s = expr(object$spec$args$penalty)
)
Expand All @@ -181,10 +181,27 @@ make_proportional_hazards_glmnet <- function() {
object = expr(object),
new_data = expr(new_data),
times = expr(time),
s = expr(object$spec$args$penalty)
penalty = expr(object$spec$args$penalty)
)
)
)

# Register the "raw" prediction type for `proportional_hazards` models using
# the "glmnet" engine in "censored regression" mode. The prediction call is
# `predict()` on the underlying fit (`object$fit`) with `newx = new_data`;
# `coxnet_predict_pre` is supplied as the `pre` hook and there is no `post`
# step.
# NOTE(review): presumably parsnip applies `pre` to `new_data` before the
# `args` expressions are evaluated - confirm against parsnip's `set_pred()`.
parsnip::set_pred(
model = "proportional_hazards",
eng = "glmnet",
mode = "censored regression",
type = "raw",
value = list(
pre = coxnet_predict_pre,
post = NULL,
func = c(fun = "predict"),
args =
list(object = expr(object$fit),
newx = expr(new_data)
)
)
)

}


Expand Down Expand Up @@ -349,18 +366,9 @@ check_dots_coxnet <- function(x) {
invisible(NULL)
}

#' A wrapper for predict() with coxnet models
#' @param object A fitted `_coxnet` object.
#' @param new_data Data for prediction.
#' @param ... Options to pass to [glmnet::predict.coxnet()].
#' @return A matrix.
#' @keywords internal
#' @export
linear_pred_coxnet <- function(object, new_data, ...) {
new_x <- parsnip::.convert_form_to_xy_new(
coxnet_predict_pre <- function(new_data, object) {
parsnip::.convert_form_to_xy_new(
object$preproc$coxnet,
new_data,
composition = "matrix")$x

predict(object$fit, newx = new_x, ...)
}
22 changes: 0 additions & 22 deletions man/linear_pred_coxnet.Rd

This file was deleted.

9 changes: 8 additions & 1 deletion man/survival_prob_coxnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading