Default train control (#317)
* readme changes from other pr

* update template

* use defaultControl

* default control and metrics

* rebuild

* use defaultControl

* rebuild

* dont track coverage stats
zachmayer authored Aug 12, 2024
1 parent dcb8d37 commit e1e5e2d
Showing 21 changed files with 207 additions and 227 deletions.
53 changes: 25 additions & 28 deletions .github/ISSUE_TEMPLATE.md
@@ -1,48 +1,45 @@
If you are making a feature request or starting a discussion, you can ignore everything below and go wild =D
# FEATURE REQUEST

If you are filing a bug, make sure these boxes are checked before submitting your issue— thank you!

- [ ] Start a new R session
- [ ] Install the latest version of caretEnsemble: `devtools::install_github("zachmayer/caretEnsemble")`
- [ ] Install the latest version of caret: `update.packages(oldPkgs="caret", ask=FALSE)`
# BUG
Please use this checklist for bug reports:
- [ ] [Write a minimal reproducible example](http://stackoverflow.com/a/5963610)
- [ ] Include example data in the minimal reproducible example
- [ ] run `sessionInfo()`

### Minimal, reproducible example:
Text and example code modified from [the R FAQ on stackoverflow](http://stackoverflow.com/a/5963610)
## Minimal, reproducible example:
start a NEW R session!

#### Minimal dataset:
```{R}
set.seed(1)
dat <- caret::twoClassSim(100)
X <- dat[,1:5]
y <- dat[["Class"]]
```
If you have some data that would be too difficult to construct using `caret::twoClassSim` or `caret::SLC14_1`, then you can always make a subset of your original data, using e.g. `head()`, `subset()` or the indices. Then use e.g. `dput()` to give us something that can be put in R immediately, e.g. `dput(head(iris, 4))`
rm(list=ls(all=T))
gc(reset=T)
set.seed(1L)
If you must use `dput(head())`, please first remove any columns from your dataset that are not necessary to reproduce the error.

If your data frame has a factor with many levels, the `dput` output can be unwieldy because it will still list all the possible factor levels even if they aren't present in the subset of your data. To solve this issue, you can use the `droplevels()` function. Notice below how species is a factor with only one level: `dput(droplevels(head(iris, 4)))`
dat <- caret::twoClassSim(100L)
X <- dat[,1L:5L]
y <- dat[["Class"]]
#### Minimal, runnable code:
```{R}
models <- caretEnsemble::caretList(
X, y,
methodList=c('glm', 'rpart'),
trControl=caret::trainControl(
method="cv",
number=5,
classProbs=TRUE,
savePredictions="final")
methodList=c('glm', 'rpart')
)
ens <- caretEnsemble::caretStack(models)
print(ens)
```

### Session Info:
If you have some data that would be too difficult to construct using `caret::twoClassSim` or `caret::SLC14_1`, then you can always make a subset of your original data, using e.g. `head()`, `subset()` or the indices. Then use e.g. `dput()` to give us something that can be put in R immediately, e.g. `dput(head(iris, 4))`

If you must use `dput(head())`, please first remove any columns from your dataset that are not necessary to reproduce the error.

If your data frame has a factor with many levels, the `dput` output can be unwieldy because it will still list all the possible factor levels even if they aren't present in the subset of your data. To solve this issue, you can use the `droplevels()` function: `dput(droplevels(head(iris, 4)))`

## Session Info:
```{R}
utils::sessionInfo()
```

You can delete the text in each section that explains how to do it correctly.
Be sure to test your 2 chunks of code in an empty R session before submitting your issue!
Please cut/paste the output. If your version of caret or caretEnsemble is old, upgrade them with:
```R
update.packages(oldPkgs="caret", ask=FALSE)
devtools::install_github("zachmayer/caretEnsemble")
```
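
The template's advice about `droplevels()` can be checked in a fresh session. The sketch below uses only base R and the built-in `iris` data; the variable names `small` and `trimmed` are illustrative and do not come from the template.

```r
# Take a small subset, as the template suggests.
small <- head(iris, 4)

# The subset only contains one species, but the factor still carries
# every level from the full dataset.
nlevels(small$Species)

# droplevels() discards the unused levels, which shortens the dput() output.
trimmed <- droplevels(small)
nlevels(trimmed$Species)

# Paste this output into the issue.
dput(trimmed)
```

Dropping unused levels changes the factor's metadata only, so the rows being reported stay the same.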
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
# huge XML/HTML file messes up git stats lol
coverage.rds
coverage-report.html
cobertura.xml

2 changes: 2 additions & 0 deletions NAMESPACE
@@ -28,6 +28,8 @@ export(caretEnsemble)
export(caretList)
export(caretModelSpec)
export(caretStack)
export(defaultControl)
export(defaultMetric)
export(extractMetric)
export(greedyMSE)
export(greedyMSE_caret)
71 changes: 53 additions & 18 deletions R/caretList.R
@@ -1,4 +1,4 @@
#' Create a list of several train models from the caret package
table #' Create a list of several train models from the caret package
#'
#' Build a list of train objects suitable for ensembling using the \code{\link{caretStack}}
#' function.
@@ -8,7 +8,7 @@
#' Particularly if you have a large dataset and/or many models, using a data.table will
#' avoid unnecessary copies of your data and can save a lot of time and RAM.
#' These arguments will determine which train method gets dispatched.
#' @param trControl a \code{\link[caret]{trainControl}} object. If null, we will construct a good one.
#' @param trControl a \code{\link[caret]{trainControl}} object. If NULL, will use defaultControl.
#' @param methodList optional, a character vector of caret models to ensemble.
#' One of methodList or tuneList must be specified.
#' @param tuneList optional, a NAMED list of caretModelSpec objects.
@@ -60,25 +60,12 @@ caretList <- function(
is_class <- is.factor(target) || is.character(target)
is_binary <- length(unique(target)) == 2L

# Determine metric
# Determine metric and trControl
if (is.null(metric)) {
metric <- "RMSE"
if (is_class) {
metric <- if (is_binary) "ROC" else "Accuracy"
}
metric <- defaultMetric(is_class = is_class, is_binary = is_binary)
}

# Make a trainControl if it is missing
if (is.null(trControl)) {
trControl <- caret::trainControl(
method = "cv",
number = 5L,
index = caret::createFolds(target, k = 5L, list = TRUE, returnTrain = TRUE),
savePredictions = "final",
classProbs = is_class,
summaryFunction = ifelse(is_class && is_binary, caret::twoClassSummary, caret::defaultSummary),
returnData = FALSE
)
trControl <- defaultControl(target, is_class = is_class, is_binary = is_binary)
}

# ALWAYS save class probs
@@ -181,6 +168,54 @@ predict.caretList <- function(object, newdata = NULL, verbose = FALSE, excluded_
preds
}

#' @title Construct a default train control for use with caretList
#' @description Unlike caret::trainControl, this function defaults to 5 fold CV.
#' CV is good for stacking, as every observation is in the test set exactly once.
#' We use 5 instead of 10 to save compute time, as caretList is for fitting many
#' models. We also construct explicit fold indexes and return the stacked predictions,
#' which are needed for stacking. For classification models we return class probabilities.
#' @param target the target variable.
#' @param number the number of folds to use.
#' @param is_class logical, is this a classification or regression problem.
#' @param is_binary logical, is this binary classification.
#' @param ... other arguments to pass to \code{\link[caret]{trainControl}}
#' @export
defaultControl <- function(
target,
number = 5L,
is_class = is.factor(target) || is.character(target),
is_binary = length(unique(target)) == 2L,
...) {
caret::trainControl(
method = "cv",
number = number,
index = caret::createFolds(target, k = number, list = TRUE, returnTrain = TRUE),
savePredictions = "final",
classProbs = is_class,
summaryFunction = ifelse(is_class && is_binary, caret::twoClassSummary, caret::defaultSummary),
returnData = FALSE,
...
)
}
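
The roxygen text above says cross-validation suits stacking because every observation lands in the test set exactly once. A minimal base R sketch of that property follows. `make_train_folds` is a hypothetical stand-in written for this note, not part of caret or caretEnsemble; unlike `caret::createFolds`, it does not stratify on the target.

```r
# Hypothetical stand-in for caret::createFolds(..., returnTrain = TRUE):
# assign each row to one of k folds at random, then return the training
# rows (everything outside the fold) for each resample.
make_train_folds <- function(n, k = 5L) {
  fold_id <- sample(rep_len(seq_len(k), n))
  train_idx <- lapply(seq_len(k), function(i) which(fold_id != i))
  names(train_idx) <- paste0("Fold", seq_len(k))
  train_idx
}

set.seed(1L)
n <- 100L
train_idx <- make_train_folds(n, k = 5L)

# The held-out rows of each resample are the complement of its training rows.
test_idx <- lapply(train_idx, function(idx) setdiff(seq_len(n), idx))

# Count how often each row is held out across the 5 resamples.
times_held_out <- tabulate(unlist(test_idx), nbins = n)
all(times_held_out == 1L)
```

Because each row is held out once, the saved out-of-fold predictions give one prediction per row per base model, which is the shape the stacking model trains on.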

#' @title Construct a default metric
#' @description Caret defaults to Accuracy for classification and RMSE for regression.
#' For classification, I would rather use ROC.
#' @param is_class logical, is this a classification or regression problem.
#' @param is_binary logical, is this binary classification.
#' @export
defaultMetric <- function(is_class, is_binary) {
if (is_class) {
if (is_binary) {
"ROC"
} else {
"Accuracy"
}
} else {
"RMSE"
}
}
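
To see how the two flags map to a metric, here is a small self-contained sketch. It repeats a condensed copy of `defaultMetric` from the diff so it runs on its own; `metric_for` is a hypothetical helper added for illustration that derives the flags the same way `caretList` does.

```r
# Condensed copy of defaultMetric from the diff, repeated so this runs alone.
defaultMetric <- function(is_class, is_binary) {
  if (is_class) {
    if (is_binary) "ROC" else "Accuracy"
  } else {
    "RMSE"
  }
}

# Hypothetical helper: compute is_class and is_binary from the target,
# mirroring the checks in caretList, then pick the metric.
metric_for <- function(target) {
  is_class <- is.factor(target) || is.character(target)
  is_binary <- length(unique(target)) == 2L
  defaultMetric(is_class = is_class, is_binary = is_binary)
}

metric_for(factor(c("No", "Yes", "No")))  # two-level factor
metric_for(iris$Species)                  # three-level factor
metric_for(iris$Sepal.Length)             # numeric target
metric_for(c(0, 1, 0, 1))                 # numeric with two unique values
```

The last call is worth noting: a numeric 0/1 target is treated as regression, since `is_class` only checks for factor or character targets.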

#' @title Convert object to caretList object
#' @description Converts object into a caretList
#' @param object R Object
28 changes: 20 additions & 8 deletions R/caretStack.R
@@ -20,6 +20,8 @@
#' (for transfer learning).
#' If NULL, will use the observed levels from the first model in the caret stack
#' If 0, will include all levels.
#' @param metric the metric to use for grid search on the stacking model.
#' @param trControl a trainControl object to use for training the ensemble model. If NULL, will use defaultControl.
#' @param excluded_class_id The integer level to exclude from binary classification or multiclass problems.
#' @param ... additional arguments to pass to the stacking model
#' @return S3 caretStack object
@@ -38,8 +40,11 @@ caretStack <- function(
all.models,
new_X = NULL,
new_y = NULL,
metric = NULL,
trControl = NULL,
excluded_class_id = 1L,
...) {
# Check all.models
if (!methods::is(all.models, "caretList")) {
warning("Attempting to coerce all.models to a caretList.", call. = FALSE)
all.models <- as.caretList(all.models)
@@ -67,7 +72,7 @@ caretStack <- function(
stopifnot(nrow(preds) == nrow(new_X))
}

# Build a caret model
# Choose the target
obs <- new_y
if (is.null(obs)) {
obs <- data.table::data.table(all.models[[1L]]$pred)
@@ -76,7 +81,19 @@ caretStack <- function(
obs <- obs[["obs"]]
}
stopifnot(nrow(preds) == length(obs))
model <- caret::train(preds, obs, ...)

# Make a trainControl
is_class <- is.factor(obs) || is.character(obs)
is_binary <- length(unique(obs)) == 2L
if (is.null(metric)) {
metric <- defaultMetric(is_class = is_class, is_binary = is_binary)
}
if (is.null(trControl)) {
trControl <- defaultControl(obs, is_class = is_class, is_binary = is_binary)
}

# Train the model
model <- caret::train(preds, obs, metric = metric, trControl = trControl, ...)

# Return final model
out <- list(
@@ -438,12 +455,7 @@ stackedTrainResiduals <- function(object, show_class_id = 2L) {
#' @examples
#' set.seed(42)
#' data(models.reg)
#' ens <- caretStack(
#' models.reg,
#' trControl = caret::trainControl(
#' method = "cv", savePredictions = "final"
#' )
#' )
#' ens <- caretStack(models.reg)
#' autoplot(ens)
# https://github.com/thomasp85/patchwork/issues/226 — why we need importFrom patchwork plot_layout
autoplot.caretStack <- function(object, xvars = NULL, show_class_id = 2L, ...) {
Binary file removed coverage.rds
Binary file not shown.
27 changes: 14 additions & 13 deletions inst/WORDLIST
@@ -1,28 +1,23 @@
CMD
Caruana
CodeFactor
classProbs
Deane
Ksikes
Makefile
Mizil
Multiclass
Niculescu
SDs
al
autoplot
caretList
caretModelSpec
caretStack
caretTrain
Caruana
classProbs
CMD
CodeFactor
coercible
Deane
defaultControl
dev
devtools
dotplot
ensembled
ensembling
extractMetric
et
extractMetric
ggplot
ggplot2
github
@@ -31,13 +26,18 @@ greedyMSE
importances
kable
knitr
Ksikes
linters
Makefile
methodList
Mizil
modelInfo
modelLookup
mtry
Multiclass
multiclass
newdata
Niculescu
nnet
observeds
optimizers
@@ -50,11 +50,12 @@ roxygen
rpart
savePredictions
scikit
SDs
trainControl
travis
tuneGrid
tuneList
unintuitive
varImp
vecstack
yhat
yhat
28 changes: 2 additions & 26 deletions inst/data-raw/build_test_data.R
@@ -5,44 +5,20 @@ X.reg <- model.matrix(~., iris[, -1L])
X.class <- X.reg
Y.class <- factor(ifelse(iris$Sepal.Length <= 6.2, "No", "Yes"))

# Reusable control
myControl_reg <- caret::trainControl(
method = "cv",
number = 10L,
p = 0.75,
savePrediction = TRUE,
classProbs = FALSE,
returnResamp = "final",
returnData = TRUE
)

myControl_class <- caret::trainControl(
method = "cv",
number = 10L,
p = 0.75,
savePrediction = TRUE,
summaryFunction = caret::twoClassSummary,
classProbs = TRUE,
returnResamp = "final",
returnData = TRUE
)

# Regression
set.seed(482L)
models.reg <- caretEnsemble::caretList(
x = X.reg,
y = Y.reg,
methodList = c("rf", "glm", "rpart", "treebag"),
trControl = myControl_reg
methodList = c("rf", "glm", "rpart", "treebag")
)

# Classification
set.seed(482L)
models.class <- caretEnsemble::caretList(
x = X.class,
y = Y.class,
methodList = c("rf", "glm", "rpart", "treebag"),
trControl = myControl_class
methodList = c("rf", "glm", "rpart", "treebag")
)

# Save
7 changes: 1 addition & 6 deletions man/autoplot.caretStack.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/caretList.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
