new vignette and fix bug it IDd
zachmayer committed Aug 13, 2024
1 parent 474ddfd commit 01a1e6d
Showing 6 changed files with 152 additions and 6 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -4,7 +4,7 @@ Title: Ensembles of Caret Models
Version: 4.0.0
Date: 2024-08-12
Authors@R: c(person(c("Zachary", "A."), "Deane-Mayer", role = c("aut", "cre", "cph"), email = "[email protected]"),
- person(c("Jared", "E.", "Knowles", role="ctb", email="[email protected]"),
+ person(c("Jared", "E.", "Knowles"), role="ctb", email="[email protected]"),
person("Antón", "López", role="ctb", email="[email protected]")
)
URL: https://github.com/zachmayer/caretEnsemble
2 changes: 1 addition & 1 deletion Makefile
@@ -90,7 +90,7 @@ view-coverage: coverage-report.html
open coverage-report.html

.PHONY: coverage
- coverage: cobertura.xml coverage-report.html coverage-test view-coverage
+ coverage: cobertura.xml coverage-report.html view-coverage coverage-test

.PHONY: check
check:
10 changes: 7 additions & 3 deletions R/caretList.R
@@ -182,15 +182,19 @@ predict.caretList <- function(object, newdata = NULL, verbose = FALSE, excluded_
#' @export
defaultControl <- function(
target,
method = 'cv',

number = 5L,
savePredictions = "final",
index = caret::createFolds(target, k = number, list = TRUE, returnTrain = TRUE),
is_class = is.factor(target) || is.character(target),
is_binary = length(unique(target)) == 2L,
...) {
stopifnot(savePredictions %in% c("final", "all"))
caret::trainControl(
- method = "cv",
+ method = method,
number = number,
- index = caret::createFolds(target, k = number, list = TRUE, returnTrain = TRUE),
- savePredictions = "final",
+ index = index,
+ savePredictions = savePredictions,
classProbs = is_class,
summaryFunction = ifelse(is_class && is_binary, caret::twoClassSummary, caret::defaultSummary),
returnData = FALSE,
5 changes: 5 additions & 0 deletions R/permutationImportance.R
@@ -54,6 +54,10 @@ shuffled_mae <- function(model, original_data, target, pred_type, shuffle_idx) {
new_preds <- as.matrix(stats::predict(model, original_data, type = pred_type))
data.table::set(original_data, j = var, value = old_var)

if(anyNA(new_preds)) { # This shouldn't happen, but it does with rpart.

new_preds[is.na(new_preds)] <- 0.0
}

mae(new_preds, target)
}, numeric(1L))

@@ -105,6 +109,7 @@ permutationImportance <- function(
is.numeric(preds_orig),
is.finite(preds_orig)
)

# Error of shuffled variables
mae_vars <- shuffled_mae(model, newdata, preds_orig, pred_type, shuffle_idx)

16 changes: 16 additions & 0 deletions tests/testthat/test-permutationImportance.R
@@ -212,3 +212,19 @@ testthat::test_that("permutationImportance handles various edge cases", {
check_importance_scores(imp_identical, names(x_identical))
testthat::expect_equal(imp_identical[["x1"]], imp_identical[["x3"]], tol = 1e-1)
})
######################################################################
testthat::context("NAN predictions from rpart")
######################################################################


testthat::test_that("permutationImportance handles NAN predictions from rpart", {
set.seed(42)

model_list <- caretEnsemble::caretList(
x = iris[, 1:4],

y = iris[, 5],

methodList = "rpart"
)
ens <- caretEnsemble(model_list)
imp <- caret::varImp(ens)
testthat::expect_true(all(is.finite(imp)))
})

123 changes: 122 additions & 1 deletion vignettes/Version-4.0-New-Features.Rmd
@@ -1,5 +1,7 @@
---
title: "Version-4.0-New-Features"
author: "Zach Deane-Mayer"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{Version-4.0-New-Features}
@@ -10,6 +12,125 @@ vignette: >
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
- comment = "#>"
+ comment = "#>",
echo = TRUE,
warning = FALSE,
message = FALSE
)
```

caretEnsemble 4.0.0 introduces many new features! Let's quickly go over them.

# Multiclass support
caretEnsemble now fully supports multiclass problems:
```{r}
set.seed(42)

model_list <- caretEnsemble::caretList(
x = iris[, 1:4],

y = iris[, 5],
methodList = c("rpart", "rf")
)
```

# Greedy Optimizer in caretEnsemble
The new version uses a greedy optimizer by default, ensuring the ensemble is never worse than the worst single model:
```{r}
ens <- caretEnsemble::caretEnsemble(model_list)
summary(ens)
```
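
The following is a minimal base-R sketch of how a greedy weight search can work, assuming Caruana-style forward selection with replacement. The helper `greedy_weights` and the simulated predictions are hypothetical, and this is not necessarily the optimizer caretEnsemble uses internally. Because the resulting weights are non-negative and sum to one, the in-sample MSE of the blend cannot exceed that of the worst single model (squared error is convex in the predictions).

```r
# Hypothetical helper, not part of caretEnsemble: greedy forward selection
# with replacement over the columns of a prediction matrix.
greedy_weights <- function(preds, y, n_iter = 50L) {
  counts <- integer(ncol(preds))
  for (i in seq_len(n_iter)) {
    # MSE of the blend if one more "vote" went to each candidate model
    mse_if_added <- vapply(seq_len(ncol(preds)), function(j) {
      trial <- counts
      trial[j] <- trial[j] + 1L
      blend <- drop(preds %*% (trial / sum(trial)))
      mean((blend - y)^2)
    }, numeric(1L))
    best <- which.min(mse_if_added)
    counts[best] <- counts[best] + 1L
  }
  # Vote counts become convex weights
  counts / sum(counts)
}

# Simulated stand-ins for out-of-fold predictions from three models
set.seed(42L)
y <- rnorm(200L)
preds <- cbind(
  good = y + rnorm(200L, sd = 0.5),
  ok = y + rnorm(200L, sd = 1.0),
  noise = rnorm(200L)
)

w <- greedy_weights(preds, y)
ensemble_mse <- mean((drop(preds %*% w) - y)^2)
single_mse <- apply(preds, 2L, function(p) mean((p - y)^2))
```

Comparing `ensemble_mse` with `max(single_mse)` checks the bound described above on this toy data.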

# Enhanced S3 Methods
caretStack (and by extension, caretEnsemble) now supports various S3 methods:
```{r}
print(ens)
summary(ens)
plot(ens)
ggplot2::autoplot(ens)
```
# Improved Default trainControl
A new default trainControl constructor makes it easier to build appropriate controls for caretLists. These controls include explicit indexes based on the target, return stacked predictions, and use probability estimates for classification models.
```{r}
class_control <- caretEnsemble::defaultControl(iris$Species)
print(ls(class_control))
```

```{r}
reg_control <- caretEnsemble::defaultControl(iris$Sepal.Length)
print(ls(reg_control))
```
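
To illustrate what "explicit indexes" means here, the base-R sketch below builds k sets of training-row indices, each omitting one held-out fold. The helper `make_train_index` is hypothetical. Unlike `caret::createFolds`, it shuffles rows without balancing the folds on the target.

```r
# Hypothetical helper: k training-index sets, each omitting one held-out fold.
make_train_index <- function(n, k = 5L) {
  fold_id <- sample(rep_len(seq_len(k), n)) # random fold label per row
  lapply(seq_len(k), function(f) which(fold_id != f))
}

set.seed(42L)
n <- nrow(iris)
train_index <- make_train_index(n, k = 5L)

# Rows left out of each training set form that fold's hold-out sample
held_out <- lapply(train_index, function(idx) setdiff(seq_len(n), idx))

# Count how often each row is held out across the folds
times_held_out <- tabulate(unlist(held_out), nbins = n)
```

Each row is held out exactly once, so every model trained with the same index yields one out-of-fold prediction per row. That alignment is what lets predictions from different models be stacked column by column.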

# Mixed Resampling Strategies
Models with different resampling strategies can now be ensembled:
```{r}
set.seed(42)
y <- iris[, 1]
x <- iris[, 2:3]
flex_list <- caretEnsemble::caretList(
x = x,
y = y,
methodList = c("rpart", "rf"),
trControl = caretEnsemble::defaultControl(y, number = 3)
)
flex_list$glm_boot <- caret::train(
x = x,
y = y,
method = "glm",
trControl = caretEnsemble::defaultControl(y, method = "boot", number=50)
)
flex_ens <- caretEnsemble::caretEnsemble(flex_list)
print(flex_ens)
```

# Mixed Model Types
caretEnsemble now allows ensembling of mixed lists of classification and regression models:
```{r}
set.seed(42)
X <- iris[,1:4]
target_class <- iris[, 5]
target_reg <- as.integer(iris[, 5] == 'virginica')
model_class <- caret::train(iris[, 1:4], target_class, method = "rf", trControl = caretEnsemble::defaultControl(target_class))
model_reg <- caret::train(iris[, 1:4], target_reg, method = "rf", trControl = caretEnsemble::defaultControl(target_reg))
mixed_list <- caretEnsemble::as.caretList(list(class=model_class, reg=model_reg))
mixed_ens <- caretEnsemble::caretEnsemble(mixed_list)
print(mixed_ens)
```

# Transfer Learning
caretStack now supports transfer learning for ensembling models trained on different datasets:
```{r}
set.seed(42)
train_idx <- sample(1:nrow(iris), 100)
train_data <- iris[train_idx, ]
new_data <- iris[-train_idx, ]
model_list <- caretEnsemble::caretList(
x = train_data[, 1:4],
y = train_data[, 5],
methodList = c("rpart", "rf")
)
transfer_ens <- caretEnsemble::caretEnsemble(
model_list,
new_X = new_data[, 1:4],
new_y = new_data[, 5]
)
print(transfer_ens)
```

We can also predict on new data:
```{r}
preds <- predict(transfer_ens, newdata = head(new_data))
print(preds)
```

# Permutation Importance
Permutation importance is now the default method for variable importance in caretLists and caretStacks:
```{r}
importance <- caret::varImp(transfer_ens)
print(importance)
```
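
As a rough sketch of the idea, the base-R example below shuffles one feature at a time and measures how far the predictions move from the original predictions, using mean absolute error. The helper `perm_importance` is hypothetical and is not the package's implementation. Scaling the scores to sum to one is an assumption of this sketch, and the NA replacement mirrors the rpart guard added to `shuffled_mae` in this commit.

```r
# Hypothetical helper, not the package's implementation.
perm_importance <- function(model, data, features) {
  preds_orig <- predict(model, data)
  shift <- vapply(features, function(var) {
    shuffled <- data
    shuffled[[var]] <- sample(shuffled[[var]]) # break the feature's link to the rows
    new_preds <- predict(model, shuffled)
    new_preds[is.na(new_preds)] <- 0.0 # same guard as the rpart fix
    mean(abs(new_preds - preds_orig)) # MAE against the original predictions
  }, numeric(1L))
  shift / sum(shift) # assumption: scale scores to sum to one
}

set.seed(42L)
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data = iris)
imp <- perm_importance(fit, iris, c("Sepal.Width", "Petal.Length", "Petal.Width"))
```

A feature whose shuffling barely changes the predictions receives a score near zero, regardless of the model type.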

This completes our demonstration of the key new features in caretEnsemble 4.0. These enhancements provide greater flexibility, improved performance, and easier usage for ensemble modeling in R.
