Skip to content

Commit

Permalink
Add option to keep resamples by fold
Browse files Browse the repository at this point in the history
- Add aggregate_resamples parameter to extractBestPreds
- Pass parameter through caretList chain
- Add unit tests for new functionality
- Maintain backward compatibility with default TRUE
  • Loading branch information
devloai[bot] committed Dec 10, 2024
1 parent 9eff498 commit 30074ed
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 10 deletions.
7 changes: 5 additions & 2 deletions R/caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ #' Create a list of several train models from the caret package
#' @param continue_on_fail logical, should a valid caretList be returned that
#' excludes models that fail, default is FALSE
#' @param trim logical should the train models be trimmed to save memory and speed up stacking
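#' @param aggregate_resamples logical, whether to collapse multiple resampled predictions for the same row (e.g. from repeated CV) into a single row by taking the mean/first value. If FALSE, every resampled prediction is kept. Default TRUE.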
#' @return A list of \code{\link[caret]{train}} objects. If the model fails to build,
#' it is dropped from the list.
#' @export
Expand All @@ -37,7 +38,8 @@ caretList <- function(
tuneList = NULL,
metric = NULL,
continue_on_fail = FALSE,
trim = TRUE) {
trim = TRUE,
aggregate_resamples = TRUE) {
# Checks
if (is.null(tuneList) && is.null(methodList)) {
stop("Please either define a methodList or tuneList", call. = FALSE)
Expand Down Expand Up @@ -79,7 +81,8 @@ caretList <- function(
global_args[["metric"]] <- metric

# Loop through the tuneLists and fit caret models with those specs
modelList <- lapply(tuneList, caretTrain, global_args = global_args, continue_on_fail = continue_on_fail, trim = trim)
modelList <- lapply(tuneList, caretTrain, global_args = global_args, continue_on_fail = continue_on_fail,

Check notice on line 84 in R/caretList.R

View check run for this annotation

codefactor.io / CodeFactor

R/caretList.R#L84

Remove trailing whitespace. (trailing_whitespace_linter)

Check warning on line 84 in R/caretList.R

View workflow job for this annotation

GitHub Actions / lint

file=R/caretList.R,line=84,col=108,[trailing_whitespace_linter] Remove trailing whitespace.
trim = trim, aggregate_resamples = aggregate_resamples)

Check notice on line 85 in R/caretList.R

View check run for this annotation

codefactor.io / CodeFactor

R/caretList.R#L85

Hanging indent should be 22 spaces but is 21 spaces. (indentation_linter)

Check warning on line 85 in R/caretList.R

View workflow job for this annotation

GitHub Actions / lint

file=R/caretList.R,line=85,col=21,[indentation_linter] Hanging indent should be 22 spaces but is 21 spaces.
names(modelList) <- names(tuneList)
nulls <- vapply(modelList, is.null, logical(1L))
modelList <- modelList[!nulls]
Expand Down
23 changes: 15 additions & 8 deletions R/caretPredict.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ caretPredict <- function(object, newdata = NULL, excluded_class_id = 1L, ...) {
#' If `TRUE`, the function will return `NULL` if the `train` function fails.
#' @param trim A logical indicating whether to trim the output model.
#' If `TRUE`, the function will remove some elements that are not needed from the output model.
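#' @param aggregate_resamples A logical indicating whether to aggregate resamples by taking mean/first value.
#' If `TRUE` (the default), the saved predictions are aggregated to one row per `rowIndex`.
#' If `FALSE`, all resampled predictions are kept.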
#' @return The output of the `train` function.
#' @keywords internal
caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim = TRUE) {
caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim = TRUE, aggregate_resamples = TRUE) {
# Combine args
# I think my handling here is correct (update globals with locals, which allows locals be partial)
# but it would be nice to have some tests
Expand All @@ -100,7 +101,7 @@ caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim =

# Only save stacked predictions for the best model
if ("pred" %in% names(model)) {
model[["pred"]] <- extractBestPreds(model)
model[["pred"]] <- extractBestPreds(model, aggregate_resamples = aggregate_resamples)
}

if (trim) {
Expand Down Expand Up @@ -147,9 +148,10 @@ aggregate_mean_or_first <- function(x) {
#' @title Extract the best predictions from a train object
#' @description Extract the best predictions from a train object.
#' @param x a train object
#' @param aggregate_resamples logical, whether to aggregate resamples to one row per rowIndex by taking the mean/first value. If FALSE, all resampled rows are kept and only sorted by rowIndex. Default TRUE.
#' @return a data.table::data.table with predictions
#' @keywords internal
extractBestPreds <- function(x) {
extractBestPreds <- function(x, aggregate_resamples = TRUE) {
stopifnot(methods::is(x, "train"))
if (is.null(x$pred)) {
stop("No predictions saved during training. Please set savePredictions = 'final' in trainControl", call. = FALSE)
Expand All @@ -167,14 +169,19 @@ extractBestPreds <- function(x) {
# Drop rows for other tunes
pred <- pred[best_tune, ]

# If we have multiple resamples per row
# If we have multiple resamples per row and aggregate_resamples is TRUE
# e.g. for repeated CV, we need to average the predictions
keys <- "rowIndex"
data.table::setkeyv(pred, keys)
pred <- pred[, lapply(.SD, aggregate_mean_or_first), by = keys]
if (aggregate_resamples) {
keys <- "rowIndex"
data.table::setkeyv(pred, keys)
pred <- pred[, lapply(.SD, aggregate_mean_or_first), by = keys]
} else {
# When not aggregating, keep all columns but ensure consistent order
data.table::setcolorder(pred, c("rowIndex", setdiff(names(pred), "rowIndex")))
}

# Order results consistently
data.table::setorderv(pred, keys)
data.table::setorderv(pred, "rowIndex")

# Return
pred
Expand Down
42 changes: 42 additions & 0 deletions tests/testthat/test-caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,48 @@ testthat::test_that("predict.caretList works for classification and regression",
testthat::context("caretList")
################################################################

testthat::test_that("caretList respects aggregate_resamples parameter", {
# Create a trainControl with repeated CV to ensure multiple resamples per row
ctrl <- caret::trainControl(
method = "repeatedcv",
number = 3,

Check notice on line 153 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L153

Use 3L or 3.0 to avoid implicit integers. (implicit_integer_linter)

Check warning on line 153 in tests/testthat/test-caretList.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-caretList.R,line=153,col=15,[implicit_integer_linter] Use 3L or 3.0 to avoid implicit integers.
repeats = 2,

Check notice on line 154 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L154

Use 2L or 2.0 to avoid implicit integers. (implicit_integer_linter)

Check warning on line 154 in tests/testthat/test-caretList.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-caretList.R,line=154,col=16,[implicit_integer_linter] Use 2L or 2.0 to avoid implicit integers.
savePredictions = "final"
)

Check notice on line 157 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L157

Remove trailing whitespace. (trailing_whitespace_linter)

Check warning on line 157 in tests/testthat/test-caretList.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-caretList.R,line=157,col=1,[trailing_whitespace_linter] Remove trailing whitespace.
# Test with aggregate_resamples = FALSE
test_no_agg <- caretList(
x = train[, -23L],
y = train[, "Class"],
methodList = c("knn"),

Check notice on line 162 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L162

Remove unnecessary c() of a constant. (unnecessary_concatenation_linter)

Check warning on line 162 in tests/testthat/test-caretList.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-caretList.R,line=162,col=18,[unnecessary_concatenation_linter] Remove unnecessary c() of a constant.
trControl = ctrl,
aggregate_resamples = FALSE
)

Check notice on line 166 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L166

Remove trailing whitespace. (trailing_whitespace_linter)

Check warning on line 166 in tests/testthat/test-caretList.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-caretList.R,line=166,col=1,[trailing_whitespace_linter] Remove trailing whitespace.
# Get predictions from the first model
preds_no_agg <- test_no_agg[[1]]$pred

Check notice on line 168 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L168

Use 1L or 1.0 to avoid implicit integers. (implicit_integer_linter)

Check warning on line 168 in tests/testthat/test-caretList.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-caretList.R,line=168,col=33,[implicit_integer_linter] Use 1L or 1.0 to avoid implicit integers.

Check notice on line 169 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L169

Remove trailing whitespace. (trailing_whitespace_linter)

Check warning on line 169 in tests/testthat/test-caretList.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-caretList.R,line=169,col=1,[trailing_whitespace_linter] Remove trailing whitespace.
# Test with aggregate_resamples = TRUE (default)
test_agg <- caretList(
x = train[, -23L],
y = train[, "Class"],
methodList = c("knn"),

Check notice on line 174 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L174

Remove unnecessary c() of a constant. (unnecessary_concatenation_linter)

Check warning on line 174 in tests/testthat/test-caretList.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-caretList.R,line=174,col=18,[unnecessary_concatenation_linter] Remove unnecessary c() of a constant.
trControl = ctrl
)

Check notice on line 177 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L177

Remove trailing whitespace. (trailing_whitespace_linter)
# Get predictions from the first model
preds_agg <- test_agg[[1]]$pred

Check notice on line 179 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L179

Use 1L or 1.0 to avoid implicit integers. (implicit_integer_linter)

Check notice on line 180 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L180

Remove trailing whitespace. (trailing_whitespace_linter)
# The non-aggregated predictions should have more rows than aggregated ones
# since they keep all resamples
testthat::expect_gt(nrow(preds_no_agg), nrow(preds_agg))

Check notice on line 184 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L184

Remove trailing whitespace. (trailing_whitespace_linter)
# Both should have rowIndex as first column and same set of columns
testthat::expect_identical(names(preds_no_agg)[1], "rowIndex")

Check notice on line 186 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L186

Use 1L or 1.0 to avoid implicit integers. (implicit_integer_linter)
testthat::expect_identical(names(preds_agg)[1], "rowIndex")

Check notice on line 187 in tests/testthat/test-caretList.R

View check run for this annotation

codefactor.io / CodeFactor

tests/testthat/test-caretList.R#L187

Use 1L or 1.0 to avoid implicit integers. (implicit_integer_linter)
testthat::expect_setequal(names(preds_no_agg), names(preds_agg))
})

testthat::test_that("caretList works for various scenarios", {
# Basic classification
test1 <- caretList(
Expand Down

0 comments on commit 30074ed

Please sign in to comment.