Skip to content

Commit

Permalink
Remove return_weights as an option from predict.caretStack. Use `…
Browse files Browse the repository at this point in the history
…varImp` instead. (#302)

* remove return_weights.  use varImp to get model weights

* comment

* fix stack

* test on more releases

* rebuild

* breaking tests

* Revert "breaking tests"

This reverts commit 949ea70.

* dont use xml either

* dont use xml either

* require 4.1.0

* rebuild
  • Loading branch information
zachmayer authored Aug 7, 2024
1 parent b3c088c commit bc8d5de
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 1,530 deletions.
18 changes: 13 additions & 5 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,21 @@ jobs:

strategy:
fail-fast: false
matrix:
config:
matrix: # https://en.wikipedia.org/wiki/R_(programming_language)#Version_names
config: # https://github.com/r-hub/rversions rversions::r_versions()

# macOS / Widnows on latest
- {os: macos-latest, r: 'release'}
- {os: windows-latest, r: 'release'}
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
- {os: ubuntu-latest, r: 'release'}
- {os: ubuntu-latest, r: 'oldrel-1'}

# Ubuntu on latest, devel, release, oldrel as well as version in DESCRIPTION
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
- {os: ubuntu-latest, r: 'release'} # rversions::r_release()
- {os: ubuntu-latest, r: 'oldrel-1'} # rversions::r_oldrel()
- {os: ubuntu-latest, r: 'oldrel-2'}
- {os: ubuntu-latest, r: 'oldrel-3'}
# - {os: ubuntu-latest, r: 'oldrel-4'} After the next major release add this back in
- {os: ubuntu-latest, r: '4.1.0'} # Oldest supported release. 2021-05-18: 3+ years old

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# huge HTML file messes up git stats lol
# huge XML/HTML file messes up git stats lol
coverage-report.html
cobertura.xml

# weirdness from covr::package_coverage
lib/
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Description: Functions for creating ensembles of caret models: caretList()
non-linear combinations of these models, using a caret::train() model as a
meta-model.
Depends:
R (>= 3.2.0)
R (>= 4.1.0)
Suggests:
MASS,
caTools,
Expand Down
50 changes: 14 additions & 36 deletions R/caretStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ caretStack <- function(
#' @param newdata a new dataframe to make predictions on
#' @param se logical, should prediction errors be produced? Default is false.
#' @param level tolerance/confidence level
#' @param return_weights a logical indicating whether prediction weights for each model
#' should be returned
#' @param excluded_class_id Which class to exclude from predictions. Note that if the caretStack
#' was trained with an excluded_class_id, that class is ALWAYS excluded from the predictions from the
Expand Down Expand Up @@ -136,7 +135,6 @@ predict.caretStack <- function(
newdata = NULL,
se = FALSE,
level = 0.95,
return_weights = FALSE,
excluded_class_id = 0L,
return_class_only = FALSE,
verbose = FALSE,
Expand All @@ -153,59 +151,43 @@ predict.caretStack <- function(

# Check return_class_only
if (return_class_only) {
stopifnot(
is_class,
!se
)
stopifnot(is_class, !se)
excluded_class_id <- 0L
}

# Calculate variable importance if needed
if (se || return_weights) {
imp <- caret::varImp(object, newdata = newdata, normalize = TRUE)
}

# Get predictions from the submodels on the new data
# If there's no new data, we just use stacked predictions from the ensemble model
if (!is.null(newdata)) {
# These will be regular predictions
newdata <- data.table::as.data.table(newdata)
sub_model_preds <- stats::predict(
# We need theres if there's newdata, for passing the base model predictions to the stack model
# We also need these if we're calculting standard errors for the predictions
sub_model_preds <- if (!is.null(newdata) || se) {
stats::predict(
object$models,
newdata = newdata,
verbose = verbose,
excluded_class_id = object[["excluded_class_id"]]
)
newdata <- sub_model_preds
} else if (se) {
# These will be stacked predictions
sub_model_preds <- stats::predict(
object$models,
newdata = newdata,
verbose = verbose,
excluded_class_id = object[["excluded_class_id"]]
)
} else {
sub_model_preds <- NULL
}

# Now predict on the stack
# If newdata is NULL, this will be stacked predictions from caret::train
# If newdata is present, this will be regular predictions on top
# of the sub_model_preds.
meta_preds <- caretPredict(object$ens_model, newdata = newdata, excluded_class_id = excluded_class_id, ...)
meta_preds <- caretPredict(
object$ens_model,
newdata = if (!is.null(newdata)) sub_model_preds,
excluded_class_id = excluded_class_id,
...
)

# Decide output:
# IF SE, data.table of predictins, lower, and upper bounds
# IF return_class_only, factor of class levels
# ELSE, data.table of predictions
if (se) {
imp <- caret::varImp(object, newdata = newdata, normalize = TRUE)
std_error <- as.matrix(sub_model_preds[, names(imp), with = FALSE])
std_error <- apply(std_error, 1L, wtd.sd, w = imp, na.rm = TRUE)
std_error <- stats::qnorm(level) * std_error
if (ncol(meta_preds) == 1L) {
meta_preds <- meta_preds[[1L]]
}
meta_preds <- if (ncol(meta_preds) == 1L) meta_preds[[1L]] else meta_preds
out <- data.table::data.table(
pred = meta_preds,
lwr = meta_preds - std_error,
Expand All @@ -221,11 +203,6 @@ predict.caretStack <- function(
out <- meta_preds
}

# Add weights to output if needed
if (return_weights) {
attr(out, "weights") <- imp
}

# Return
out
}
Expand Down Expand Up @@ -465,6 +442,7 @@ plot.caretStack <- function(x, metric = NULL, ...) {
#' )
#' autoplot(ens)
#' }
# https://github.com/thomasp85/patchwork/issues/226 — why we need importFrom patchwork plot_layout
autoplot.caretStack <- function(object, xvars = NULL, show_class_id = 2L, ...) {
stopifnot(methods::is(object, "caretStack"))
ensemble_data <- extractPredObsResid(object$ens_model, show_class_id = show_class_id)
Expand Down
Loading

0 comments on commit bc8d5de

Please sign in to comment.