Skip to content

Commit

Permalink
Refactor inner functions (#303)
Browse files Browse the repository at this point in the history
* rebuild

* rebuild

* fix tests and rebuild
  • Loading branch information
zachmayer authored Aug 7, 2024
1 parent bc8d5de commit d3af788
Show file tree
Hide file tree
Showing 17 changed files with 241 additions and 185 deletions.
9 changes: 5 additions & 4 deletions R/caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,7 @@ caretList <- function(
if (is.null(metric)) {
metric <- "RMSE"
if (is_class) {
metric <- "Accuracy"
if (is_binary) {
metric <- "ROC"
}
metric <- if (is_binary) "ROC" else "Accuracy"
}
}

Expand All @@ -94,6 +91,10 @@ caretList <- function(
)
}

# ALWAYS save class probs
trControl[["classProbs"]] <- is_class
trControl["savePredictions"] <- "final"

# Capture global arguments for train as a list
# Squish trControl back onto the global arguments list
global_args <- list(...)
Expand Down
202 changes: 106 additions & 96 deletions R/caretPredict.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,22 @@
#' @param excluded_class_id an integer indicating the class to exclude. If 0L, no class is excluded
#' @param ... additional arguments to pass to \code{\link[caret]{predict.train}}, if newdata is not NULL
#' @return a data.table
#' @keywords internal
caretPredict <- function(object, newdata = NULL, excluded_class_id = 1L, ...) {
stopifnot(methods::is(object, "train"))

# Extract the model type
model_type <- extractModelType(object, validate_for_stacking = is.null(newdata))
is_class <- isClassifierAndValidate(object, validate_for_stacking = is.null(newdata))

# If newdata is NULL, return the stacked predictions
if (is.null(newdata)) {
# Extract the best tune
a <- data.table::data.table(object$bestTune, key = names(object$bestTune))

# Extract the best predictions
b <- data.table::data.table(object$pred, key = names(object$bestTune))

# Subset pred data to the best tune only
pred <- b[a, ]

# Keep only the predictions
keep_cols <- "pred"
if (model_type == "Classification") {
keep_cols <- levels(object)
}
pred <- pred[, c("rowIndex", keep_cols), drop = FALSE, with = FALSE]

# If we have multiple resamples per row
# e.g. for repeated CV, we need to average the predictions
data.table::setkeyv(pred, "rowIndex")
pred <- pred[, lapply(.SD, mean), by = "rowIndex"]
data.table::setorderv(pred, "rowIndex")

# Remove the rowIndex
data.table::set(pred, j = "rowIndex", value = NULL)
pred <- extractBestPreds(object)
keep_cols <- if (is_class) levels(object) else "pred"
pred <- pred[, keep_cols, with = FALSE]

# Otherwise, predict on newdata
} else {
if (model_type == "Classification") {
if (is_class) {
pred <- caret::predict.train(object, type = "prob", newdata = newdata, ...)
} else {
pred <- caret::predict.train(object, type = "raw", newdata = newdata, ...)
Expand All @@ -62,7 +42,7 @@ caretPredict <- function(object, newdata = NULL, excluded_class_id = 1L, ...) {
# Make sure in both cases we have consitent column names and column order
# Drop the excluded class for classificaiton
stopifnot(nrow(pred) == nrow(newdata))
if (model_type == "Classification") {
if (is_class) {
stopifnot(
ncol(pred) == nlevels(object),
names(pred) == levels(object)
Expand Down Expand Up @@ -110,9 +90,9 @@ caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim =
model <- do.call(caret::train, model_args)
}

# Use data.table for stacked predictions
# Only save stacked predictions for the best model
if ("pred" %in% names(model)) {
model[["pred"]] <- data.table::data.table(model[["pred"]])
model[["pred"]] <- extractBestPreds(model)
}

if (trim) {
Expand Down Expand Up @@ -143,6 +123,55 @@ caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim =
model
}

#' @title Aggregate mean or first
#' @description For numeric data take the mean. For character data take the first value.
#' @param x a train object
#' @return a data.table::data.table with predictions
#' @keywords internal
aggregate_mean_or_first <- function(x) {
if (is.numeric(x)) {
mean(x)
} else {
x[1L]
}
}

#' @title Extract the best predictions from a train object
#' @description Extract the best predictions from a train object.
#' @param x a train object
#' @return a data.table::data.table with predictions
#' @keywords internal
extractBestPreds <- function(x) {
stopifnot(methods::is(x, "train"))
if (is.null(x$pred)) {
stop("No predictions saved during training. Please set savePredictions = 'final' in trainControl", call. = FALSE)
}
stopifnot(methods::is(x$pred, "data.frame"))

# Extract the best tune
keys <- names(x$bestTune)
best_tune <- data.table::data.table(x$bestTune, key = keys)

# Extract the best predictions
pred <- data.table::data.table(x$pred, key = keys)

# Subset pred data to the best tune only
# Drop rows for other tunes
pred <- pred[best_tune, ]

# If we have multiple resamples per row
# e.g. for repeated CV, we need to average the predictions
keys <- "rowIndex"
data.table::setkeyv(pred, keys)
pred <- pred[, lapply(.SD, aggregate_mean_or_first), by = keys]

# Order results consistently
data.table::setorderv(pred, keys)

# Return
pred
}

#' @title Validate the excluded class
#' @description Helper function to ensure that the excluded level for classification is an integer.
#' Set to 0L to exclude no class.
Expand Down Expand Up @@ -202,45 +231,6 @@ dropExcludedClass <- function(x, all_classes, excluded_class_id) {
x
}

#' @title Extract the model type from a \code{\link[caret]{train}} object
#' @description Extract the model type from a \code{\link[caret]{train}} object.
#' For classification, validates that the model can predict probabilities, and,
#' if stacked predictions are requested, that classProbs = TRUE.
#' @param object a \code{\link[caret]{train}} object
#' @param validate_for_stacking a logical indicating whether to validate the object for stacked predictions
#' @return a character string
#' @keywords internal
extractModelType <- function(object, validate_for_stacking = TRUE) {
stopifnot(methods::is(object, "train"))

# Extract type
model_type <- object$modelType

# Class or reg?
is_class <- model_type == "Classification"

# Validate for predictions
if (is_class && !is.function(object$modelInfo$prob)) {
stop("No probability function found. Re-fit with a method that supports prob.", call. = FALSE)
}
# Validate for stacked predictions
if (validate_for_stacking) {
err <- "Must have savePredictions = 'all', 'final', or TRUE in trainControl to do stacked predictions."
if (is.null(object$control$savePredictions)) {
stop(err, call. = FALSE)
}
if (!object$control$savePredictions %in% c("all", "final", TRUE)) {
stop(err, call. = FALSE)
}
if (is_class && !object$control$classProbs) {
stop("classProbs = FALSE. Re-fit with classProbs = TRUE in trainControl.", call. = FALSE)
}
}

# Return
model_type
}

#' @title S3 definition for concatenating train objects
#'
#' @description take N objects of class train and concatenate into an object of class caretList for future ensembling
Expand Down Expand Up @@ -308,8 +298,8 @@ extractMetric <- function(x, ...) {
#' @param ... ignored
#' If NULL, uses the metric that was used to train the model.
#' @return A numeric representing the metric desired metric.
#' @export
#' @method extractMetric train
#' @export
extractMetric.train <- function(x, metric = NULL, ...) {
if (is.null(metric) || !metric %in% names(x$results)) {
metric <- x$metric
Expand Down Expand Up @@ -340,6 +330,7 @@ extractMetric.train <- function(x, metric = NULL, ...) {
#' used instead.
#' @param x a single caret train object
#' @return Name associated with model
#' @keywords internal
extractModelName <- function(x) {
if (is.list(x$method)) {
checkCustomModel(x$method)$method
Expand All @@ -350,33 +341,52 @@ extractModelName <- function(x) {
}
}

#' @title Extract the best predictions and observations from a train object
#' @description This function extracts the best predictions and observations from a train object
#' and then calculates residuals. It only uses one class for classification models, by default class 2.
#' @param object a \code{train} object
#' @param show_class_id For classification only: which class level to use for residuals
#' @return a data.table::data.table with predictions, observeds, and residuals
extractPredObsResid <- function(object, show_class_id = 2L) {
if (is.null(object$pred)) {
stop("No predictions saved during training. Please set savePredictions = 'final' in trainControl", call. = FALSE)
#' @title Is Classifier
#' @description Check if a model is a classifier.
#' @param model A train object from the caret package.
#' @return A logical indicating whether the model is a classifier.
#' @keywords internal
isClassifier <- function(model) {
stopifnot(methods::is(model, "train") || methods::is(model, "caretStack"))
if (methods::is(model, "train")) {
out <- model$modelType == "Classification"
} else {
out <- model$ens_model$modelType == "Classification"
}
stopifnot(
methods::is(object, "train"),
is.data.frame(object$pred)
)
keep_cols <- c("pred", "obs", "rowIndex")
type <- object$modelType
predobs <- data.table::data.table(object$pred)
if (type == "Classification") {
show_class <- levels(object)[show_class_id]
data.table::set(predobs, j = "pred", value = predobs[[show_class]])
data.table::set(predobs, j = "obs", value = as.integer(predobs[["obs"]] == show_class))
out
}

#' @title Validate a model type
#' @description Validate the model type from a \code{\link[caret]{train}} object.
#' For classification, validates that the model can predict probabilities, and,
#' if stacked predictions are requested, that classProbs = TRUE.
#' @param object a \code{\link[caret]{train}} object
#' @param validate_for_stacking a logical indicating whether to validate the object for stacked predictions
#' @return a logical. TRUE if classifier, otherwise FALSE.
#' @keywords internal
isClassifierAndValidate <- function(object, validate_for_stacking = TRUE) {
stopifnot(methods::is(object, "train"))

is_class <- isClassifier(object)

# Validate for predictions
if (is_class && !is.function(object$modelInfo$prob)) {
stop("No probability function found. Re-fit with a method that supports prob.", call. = FALSE)
}
predobs <- predobs[, keep_cols, with = FALSE]
data.table::setkeyv(predobs, "rowIndex")
predobs <- predobs[, lapply(.SD, mean), by = "rowIndex"]
r <- predobs[["obs"]] - predobs[["pred"]]
data.table::set(predobs, j = "resid", value = r)
data.table::setorderv(predobs, "rowIndex")
predobs
# Validate for stacked predictions
if (validate_for_stacking) {
err <- "Must have savePredictions = 'all', 'final', or TRUE in trainControl to do stacked predictions."
if (is.null(object$control$savePredictions)) {
stop(err, call. = FALSE)
}
if (!object$control$savePredictions %in% c("all", "final", TRUE)) {
stop(err, call. = FALSE)
}
if (is_class && !object$control$classProbs) {
stop("classProbs = FALSE. Re-fit with classProbs = TRUE in trainControl.", call. = FALSE)
}
}

# Return
is_class
}
43 changes: 35 additions & 8 deletions R/caretStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,10 @@ predict.caretStack <- function(
check_caretStack(object)

# Extract model types
model_type <- object$ens_model$modelType
is_class <- model_type == "Classification"
is_class <- isClassifier(object)

# If the excluded class wasn't set at train time, set it
object <- set_excluded_class_id(object, model_type)
object <- set_excluded_class_id(object, is_class)

# Check return_class_only
if (return_class_only) {
Expand Down Expand Up @@ -224,10 +223,10 @@ check_caretStack <- function(object) {
#' @description Set the excluded class id for a caretStack object
#'
#' @param object a caretStack object
#' @param model_type the model type as a character vector with length 1
#' @param is_class the model type as a logical vector with length 1
#' @keywords internal
set_excluded_class_id <- function(object, model_type) {
if (model_type == "Classification" && is.null(object[["excluded_class_id"]])) {
set_excluded_class_id <- function(object, is_class) {
if (is_class && is.null(object[["excluded_class_id"]])) {
object[["excluded_class_id"]] <- 1L
warning("No excluded_class_id set. Setting to 1L.", call. = FALSE)
}
Expand Down Expand Up @@ -412,6 +411,34 @@ plot.caretStack <- function(x, metric = NULL, ...) {
plt
}

#' @title Extracted stacked residuals for the autoplot
#' @description This function extracts the predictions, observeds, and residuals from a \code{train} object.
#' It uses the object's stacked predictions from cross-validation.
#' @param object a \code{train} object
#' @param show_class_id For classification only: which class level to use for residuals
#' @return a data.table::data.table with predictions, observeds, and residuals
#' @keywords internal
stackedTrainResiduals <- function(object, show_class_id = 2L) {
stopifnot(methods::is(object, "train"))
is_class <- isClassifier(object)
predobs <- extractBestPreds(object)
rowIndex <- predobs[["rowIndex"]]
pred <- predobs[["pred"]]
obs <- predobs[["obs"]]
if (is_class) {
show_class <- levels(object)[show_class_id]
pred <- predobs[[show_class]]
obs <- as.integer(obs == show_class)
}
predobs <- data.table::data.table(
rowIndex = rowIndex,
pred = pred,
obs = obs,
resid = obs - pred
)
predobs
}

#' @title Convenience function for more in-depth diagnostic plots of caretStack objects
#' @description This function provides a more robust series of diagnostic plots
#' for a caretEnsemble object.
Expand Down Expand Up @@ -445,7 +472,7 @@ plot.caretStack <- function(x, metric = NULL, ...) {
# https://github.com/thomasp85/patchwork/issues/226 — why we need importFrom patchwork plot_layout
autoplot.caretStack <- function(object, xvars = NULL, show_class_id = 2L, ...) {
stopifnot(methods::is(object, "caretStack"))
ensemble_data <- extractPredObsResid(object$ens_model, show_class_id = show_class_id)
ensemble_data <- stackedTrainResiduals(object$ens_model, show_class_id = show_class_id)

# Performance metrics by model
g1 <- plot(object) + ggplot2::labs(title = "Metric and SD For Component Models")
Expand All @@ -470,7 +497,7 @@ autoplot.caretStack <- function(object, xvars = NULL, show_class_id = 2L, ...) {
ggplot2::theme_bw()

# Disagreement in sub-model residuals
sub_model_data <- lapply(object$models, extractPredObsResid, show_class_id = show_class_id)
sub_model_data <- lapply(object$models, stackedTrainResiduals, show_class_id = show_class_id)
for (model_name in names(sub_model_data)) {
data.table::set(sub_model_data[[model_name]], j = "model", value = model_name)
}
Expand Down
16 changes: 0 additions & 16 deletions R/permutationImportance.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,6 @@ mae <- function(a, b) {
mean(abs(a - b))
}

#' @title Is Classifier
#' @description Check if a model is a classifier.
#' @param model A train object from the caret package.
#' @return A logical indicating whether the model is a classifier.
#' @keywords internal
isClassifier <- function(model) {
stopifnot(methods::is(model, "train") || methods::is(model, "caretStack"))
if (methods::is(model, "train")) {
out <- model$modelType == "Classification"
} else {
out <- model$ens_model$modelType == "Classification"
}
out
}


#' @title Shuffled MAE
#' @description Compute the mean absolute error of a model's predictions when a variable is shuffled.
#' @param original_data A data.table of the original data.
Expand Down
Binary file modified coverage.rds
Binary file not shown.
Loading

0 comments on commit d3af788

Please sign in to comment.