Skip to content

Commit

Permalink
Use greedyMSE for caretEnsemble (#306)
Browse files Browse the repository at this point in the history
* rebuild

* add manual files and update docs
  • Loading branch information
zachmayer authored Aug 8, 2024
1 parent 5df7e14 commit 28a6a9b
Show file tree
Hide file tree
Showing 14 changed files with 126 additions and 115 deletions.
57 changes: 27 additions & 30 deletions R/caretEnsemble.R
Original file line number Diff line number Diff line change
@@ -1,26 +1,15 @@
#' @title Reject multiclass models before ensembling
#' @description Raises an error when any \code{train} object in the supplied
#' list stores its observed outcome as a factor with more than two levels.
#' Elements that are not \code{train} objects, and models whose observed
#' outcome is not a factor (e.g. regression models), are ignored.
#'
#' @param list_of_models a list of caret models to check. Inputs that are not
#' a list, or that are a list with fewer than two elements, are not inspected.
#' @return \code{NULL}, invisibly.
#' @keywords internal
check_binary_classification <- function(list_of_models) {
  needs_inspection <- is.list(list_of_models) && length(list_of_models) > 1L
  if (!needs_inspection) {
    return(invisible(NULL))
  }

  for (candidate in as.list(list_of_models)) {
    if (!methods::is(candidate, "train")) {
      next
    }
    observed <- candidate$pred$obs
    # a missing (NULL) or numeric outcome is not a factor, so it passes through
    if (is.factor(observed) && nlevels(observed) > 2L) {
      stop("caretEnsemble only supports binary classification problems", call. = FALSE)
    }
  }

  invisible(NULL)
}

#' @title Combine several predictive models via weights
#'
#' @description Find a good linear combination of several classification or regression models,
#' using linear regression.
#' @description Find a greedy, positive only linear combination of several \code{\link[caret]{train}} objects
#'
#' @details greedyMSE works well when you want an ensemble that will never be worse than any
#' single model in the dataset. In the worst case scenario, it will select the single
#' best model, if none of them can be ensembled to improve the overall score. It will
#' also never assign any model a negative coefficient, which can help avoid
#' unintuitive cases at prediction time (e.g. if the correlations between
#' predictors break down on new data, negative coefficients can lead to bad results).
#'
#' @details Every model in the "library" must be a separate \code{train} object. For
#' @note Every model in the "library" must be a separate \code{train} object. For
#' example, if you wish to combine random forests with several different
#' values of mtry, you must build a model for each value of mtry. If you
#' use several values of mtry in one train model, (e.g. tuneGrid =
Expand All @@ -29,14 +18,13 @@ check_binary_classification <- function(list_of_models) {
#' RMSE is used to ensemble regression models, and AUC is used to ensemble
#' classification models. This function does not currently support multi-class
#' problems.
#' @note Currently when missing values are present in the training data, weights
#' are calculated using only observations which are complete across all models
#' in the library. The optimizer ignores missing values and calculates the weights with the
#' observations and predictions available for each model separately. If each of the
#' models has a different pattern of missingness in the predictors, then the resulting
#' ensemble weights may be biased and the function issues a message.
#' @param all.models an object of class caretList
#' @param ... additional arguments to pass to the optimization function
#' @param excluded_class_id The integer level to exclude from binary classification or multiclass problems.
#' By default no classes are excluded, as the greedy optimizer requires all classes because it cannot
#' use negative coefficients.
#' @param tuneLength The size of the grid to search for tuning the model. Defaults to 1, as
#' the only parameter to optimize is the number of iterations, and the default of 100 works well.
#' @param ... additional arguments to pass to caret::train
#' @return a \code{\link{caretEnsemble}} object
#' @export
#' @examples
Expand All @@ -46,9 +34,18 @@ check_binary_classification <- function(list_of_models) {
#' ens <- caretEnsemble(models)
#' summary(ens)
#' }
caretEnsemble <- function(all.models, ...) {
check_binary_classification(all.models)
out <- caretStack(all.models, method = "glm", ...)
caretEnsemble <- function(
all.models,
excluded_class_id = 0L,
tuneLength = 1L,
...) {
out <- caretStack(
all.models,
excluded_class_id = excluded_class_id,
tuneLength = tuneLength,
method = greedyMSE_caret(),
...
)
class(out) <- c("caretEnsemble", "caretStack")
out
}
7 changes: 4 additions & 3 deletions R/caretStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,9 @@ autoplot.caretStack <- function(object, xvars = NULL, show_class_id = 2L, ...) {
ggplot2::theme_bw()

# Model Weights
wghtFrame <- data.table::as.data.table(stats::coef(object[["ens_model"]][["finalModel"]]))
data.table::set(wghtFrame, j = "method", value = row.names(wghtFrame))
imp <- caret::varImp(object[["ens_model"]][["finalModel"]])
wghtFrame <- data.table::as.data.table(imp)
data.table::set(wghtFrame, j = "method", value = row.names(imp))
names(wghtFrame) <- c("weights", "method")
g3 <- ggplot2::ggplot(wghtFrame, ggplot2::aes(.data[["method"]], .data[["weights"]])) +
ggplot2::geom_bar(stat = "identity", fill = I("gray50"), color = I("black")) +
Expand Down Expand Up @@ -550,6 +551,6 @@ autoplot.caretStack <- function(object, xvars = NULL, show_class_id = 2L, ...) {
ggplot2::scale_y_continuous("Residuals") +
ggplot2::labs(title = paste0("Residuals Against ", xvars[2L])) +
ggplot2::theme_bw()
out <- g1 + g2 / (g3 + g4) / (g5 + g6)
out <- (g1 + g2) / (g3 + g4) / (g5 + g6)
out
}
19 changes: 10 additions & 9 deletions R/greedyOpt.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,15 @@ predict.greedyMSE <- function(object, newdata, return_labels = FALSE, ...) {

pred <- newdata %*% object$model_weights
if (ncol(pred) > 1L) {
pred <- pred / rowSums(pred)
}

if (return_labels) {
stopifnot(ncol(pred) > 1L)
lev <- colnames(object$model_weights)
pred <- lev[apply(pred, 1L, which.max)]
pred <- factor(pred, levels = lev)
if (return_labels) {
lev <- colnames(object$model_weights)
pred <- lev[apply(pred, 1L, which.max)]
pred <- factor(pred, levels = lev)
} else {
pred <- pred / rowSums(pred)
}
} else {
pred <- pred[, 1L]
}

pred
Expand All @@ -141,7 +142,7 @@ predict.greedyMSE <- function(object, newdata, return_labels = FALSE, ...) {
#' @title caret interface for greedyMSE
#' @description caret interface for greedyMSE. greedyMSE works
#' well when you want an ensemble that will never be worse than any single predictor
#' in the in dataset. It does not use an intercept and it does not allow for
#' in the dataset. It does not use an intercept and it does not allow for
#' negative coefficients. This makes it highly constrained and in general
#' does not work well on standard classification and regression problems.
#' However, it does work well in the case of:
Expand Down
Binary file modified coverage.rds
Binary file not shown.
1 change: 1 addition & 0 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ trainControl
travis
tuneGrid
tuneList
unintuitive
varImp
vecstack
yhat
30 changes: 18 additions & 12 deletions man/caretEnsemble.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 0 additions & 15 deletions man/check_binary_classification.Rd

This file was deleted.

26 changes: 26 additions & 0 deletions man/greedyMSE_caret.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions man/varImp.greedyMSE.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 11 additions & 6 deletions tests/testthat/test-caretEnsemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,14 @@ testthat::test_that("We can ensemble models of different predictors", {
testthat::expect_s3_class(pred.nest, "data.table")
testthat::expect_identical(nrow(pred.nest), 150L)

# Ensemble errors on NAs
X_reg_new <- X.reg
X_reg_new[2L, 3L] <- NA
X_reg_new[25L, 3L] <- NA
p_with_nas <- predict(ensNest, newdata = X_reg_new)
expect_error(
predict(ensNest, newdata = X_reg_new),
"is.finite(newdata) are not all TRUE",
fixed = TRUE
)
})

testthat::context("Does ensemble prediction work with new data")
Expand Down Expand Up @@ -130,12 +134,13 @@ testthat::test_that("caretEnsemble works for classification models", {
models.class,
trControl = caret::trainControl(
method = "cv",
number = 2L,
number = 10L,
savePredictions = "final",
classProbs = TRUE
)
)
testthat::expect_s3_class(ens.class, "caretEnsemble")
ens.class$ens_model$finalModel

# Predictions
pred_stacked <- predict(ens.class) # stacked predictions
Expand All @@ -158,11 +163,11 @@ testthat::test_that("caretEnsemble works for classification models", {
testthat::expect_identical(ncol(pred_one), 2L)

# stacked predictions should be similar to in sample predictions
testthat::expect_equal(pred_stacked, pred_in_sample, tol = 0.2)
testthat::expect_equal(pred_stacked, pred_in_sample, tol = 0.1)

# One row predictions
testthat::expect_equivalent(pred_one$Yes, 0.03833661, tol = 0.05)
testthat::expect_equivalent(pred_one$No, 0.9616634, tol = 0.05)
testthat::expect_equivalent(pred_one$Yes, 0.02, tol = 0.05)
testthat::expect_equivalent(pred_one$No, 0.98, tol = 0.05)
})

testthat::context("Do ensembles of custom models work?")
Expand Down
13 changes: 11 additions & 2 deletions tests/testthat/test-ensembleMethods.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,17 @@ testthat::test_that("caret::varImp.caretEnsemble", {
for (s in c(TRUE, FALSE)) {
i <- caret::varImp(m, normalize = s)
testthat::expect_is(i, "numeric")
testthat::expect_length(i, length(m$models))
testthat::expect_named(i, names(m$models))
if (isClassifier(m)) {
len <- length(m$models) * 2L
n <- c(outer(c("rf", "glm", "rpart", "treebag"), c("No", "Yes"), paste, sep = "_"))
n <- matrix(n, ncol = 2L)
n <- c(t(n))
} else {
len <- length(m$models)
n <- names(m$models)
}
testthat::expect_length(i, len)
testthat::expect_named(i, n)
if (s) {
testthat::expect_true(all(i >= 0.0))
testthat::expect_true(all(i <= 1.0))
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-greedyMSE.R
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ testthat::test_that("predict works with data.frame input", {
newdata_df <- as.data.frame(regression_data$X)
pred <- predict(model, newdata_df)

testthat::expect_is(pred, "matrix")
testthat::expect_identical(nrow(pred), nrow(newdata_df))
testthat::expect_is(pred, "numeric")
testthat::expect_length(pred, nrow(newdata_df))
})

# Test predict with label return for classification
Expand Down
20 changes: 0 additions & 20 deletions tests/testthat/test-helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,26 +188,6 @@ testthat::test_that("isClassifierAndValidate stops when a classification model d
testthat::expect_error(lapply(model_list, isClassifierAndValidate, validate_for_stacking = TRUE), err)
testthat::context("Test helper functions for multiclass classification")

testthat::test_that("Check errors in caretEnsemble for multiclass classification work", {
data(iris)
model_list <- caretList(
x = iris[, -5L],
y = iris[, 5L],
methodList = c("rpart", "glmnet")
)

err <- "caretEnsemble only supports binary classification problems"
testthat::expect_error(check_binary_classification(model_list), err)
testthat::expect_null(check_binary_classification(models.class))
testthat::expect_null(check_binary_classification(models.reg))

# Do not produce errors when another object is passed
testthat::expect_null(check_binary_classification(NULL))
testthat::expect_null(check_binary_classification(2L))
testthat::expect_null(check_binary_classification(list("string")))
testthat::expect_null(check_binary_classification(iris))
})

testthat::test_that("Configuration function for excluded level work", {
# Integers work
testthat::expect_identical(validateExcludedClass(0L), 0L)
Expand Down
16 changes: 0 additions & 16 deletions tests/testthat/test-multiclass.R
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,6 @@ testthat::test_that("We can make a confusion matrix", {
testthat::expect_gt(cm$overall["Accuracy"], 0.9)
})

testthat::test_that("Multiclass is not supported for caretEnsemble", {
data(iris)
data(models.class)
data(models.reg)
model_list <- caretList(
x = iris[, -5L],
y = iris[, 5L],
methodList = c("glmnet", "rpart"),
tuneList = list(
nnet = caretModelSpec(method = "nnet", trace = FALSE)
)
)

testthat::expect_error(caretEnsemble(model_list), "caretEnsemble only supports binary classification problems")
})

testthat::test_that("caretList and caretStack handle imbalanced multiclass data", {
set.seed(123L)
n <- 1000L
Expand Down

0 comments on commit 28a6a9b

Please sign in to comment.