Skip to content

Commit

Permalink
Cleanup todos and spacing (#307)
Browse files Browse the repository at this point in the history
* todo

* remove double space after period

* reorg namespace and manual

* remove fake is methods

* refactor makefile

* remove all dontrun

* run all examples and fix some

* rebuild
  • Loading branch information
zachmayer authored Aug 8, 2024
1 parent 28a6a9b commit d59760b
Show file tree
Hide file tree
Showing 34 changed files with 181 additions and 229 deletions.
59 changes: 37 additions & 22 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,37 +1,52 @@
# Makefile for R project

.PHONY: all install-deps install document update-test-fixtures test coverage-test coverage check fix-style lint spell build-vignettes release clean build

# Default target
.PHONY: help
help:
@echo "Available targets:"
@echo " all - Run clean, fix-style, document, install, build-vignettes, lint, spell, test, check, coverage"
@echo " install-deps - Install dependencies"
@echo " install - Install the whole package, including dependencies"
@echo " document - Generate documentation"
@echo " update-test-fixtures - Update test fixtures"
@echo " test - Run unit tests"
@echo " coverage - Generate coverage reports"
@echo " check - Run R CMD check as CRAN"
@echo " fix-style - Auto style the code"
@echo " lint - Check the code for lint"
@echo " spell - Check spelling"
@echo " build - Build the package"
@echo " build-vignettes - Build vignettes"
@echo " release - Release to CRAN"
@echo " clean - Clean up generated files"

.PHONY: all
all: clean fix-style document install build-vignettes lint spell test check coverage

# Install dependencies
.PHONY: install-deps
install-deps:
Rscript -e "if (!requireNamespace('devtools', quietly = TRUE)) install.packages('devtools')"
Rscript -e "devtools::install_deps()"
Rscript -e "devtools::install_dev_deps()"
Rscript -e "devtools::update_packages()"
Rscript -e "devtools::install_github('r-lib/lintr')"

# Install the whole package
.PHONY: install
install: install-deps
Rscript -e "devtools::install()"

# Generate documentation
.PHONY: document
document:
Rscript -e "devtools::document()"

# Update test fixtures
.PHONY: update-test-fixtures
update-test-fixtures:
Rscript inst/data-raw/build_test_data.R

# Run unit tests
.PHONY: test
test:
Rscript -e "Sys.setenv(NOT_CRAN='true'); devtools::test(stop_on_failure=TRUE, stop_on_warning=TRUE)"
rm -f caretEnsemble_test_plots.png

# Run test coverage
# Dunno why package_coverage makes the dir 'lib/'
coverage.rds: $(wildcard R/*.R) $(wildcard tests/testthat/*.R)
Rscript -e "\
Sys.setenv(NOT_CRAN = 'true'); \
Expand All @@ -40,45 +55,44 @@ coverage.rds: $(wildcard R/*.R) $(wildcard tests/testthat/*.R)
"
rm -rf lib/

# xml coverage report in cobertura format for app.codecov.io/gh/zachmayer/caretEnsemble
cobertura.xml: coverage.rds
Rscript -e "\
cov = readRDS('coverage.rds'); \
covr::to_cobertura(cov, filename='cobertura.xml'); \
"

# html coverage report for local viewing
coverage-report.html: coverage.rds
Rscript -e "\
cov = readRDS('coverage.rds'); \
covr::report(cov, file='coverage-report.html', browse=interactive()); \
"

# Test that coverage is 100%
.PHONY: coverage-test
coverage-test: coverage.rds
Rscript -e "\
cov = readRDS('coverage.rds'); \
cov_num = as.numeric(covr::percent_coverage(cov)); \
testthat::expect_gte(cov_num, 100.0); \
"

.PHONY: coverage
coverage: cobertura.xml coverage-report.html coverage-test

# Run R CMD check as CRAN
.PHONY: check
check: document
Rscript -e "devtools::check(cran = FALSE, remote = TRUE, manual = TRUE, force_suggests = TRUE, error_on = 'note')"
Rscript -e "devtools::check(cran = TRUE , remote = TRUE, manual = TRUE, force_suggests = TRUE, error_on = 'note')"

# Auto style the code
.PHONY: fix-style
fix-style:
Rscript -e "styler::style_pkg()"
Rscript -e "styler::style_dir('inst/')"

# Check the code for lint
.PHONY: lint
lint:
Rscript -e "Sys.setenv(LINTR_ERROR_ON_LINT='true'); devtools::load_all(); lintr::lint_package(cache = FALSE)"

# Check spelling
.PHONY: spell
spell:
Rscript -e " \
results = spelling::spell_check_package(); \
Expand All @@ -89,18 +103,19 @@ spell:
}; \
"

.PHONY: build
build:
Rscript -e "devtools::build()"

# Build vignettes
.PHONY: build-vignettes
build-vignettes:
Rscript -e "devtools::build_vignettes()"

# Release to CRAN
.PHONY: release
release:
Rscript -e "devtools::release()"
# Clean up generated files

.PHONY: clean
clean:
rm -rf *.Rcheck
rm -f *.tar.gz
Expand All @@ -113,4 +128,4 @@ clean:
rm -f caretEnsemble_test_plots.png
rm -f vignettes/caretEnsemble-intro.R
Rscript -e "devtools::clean_vignettes()"
Rscript -e "devtools::clean_dll()"
Rscript -e "devtools::clean_dll()"
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ export(caretStack)
export(extractMetric)
export(greedyMSE)
export(greedyMSE_caret)
export(is.caretList)
export(is.caretStack)
export(permutationImportance)
export(tuneCheck)
export(wtd.sd)
Expand Down
4 changes: 1 addition & 3 deletions R/caretEnsemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@
#' @return a \code{\link{caretEnsemble}} object
#' @export
#' @examples
#' \dontrun{
#' set.seed(42)
#' models <- caretList(iris[1:50, 1:2], iris[1:50, 3], methodList = c("glm", "lm"))
#' models <- caretList(iris[1:50, 1:2], iris[1:50, 3], methodList = c("rpart", "rf"))
#' ens <- caretEnsemble(models)
#' summary(ens)
#' }
caretEnsemble <- function(
all.models,
excluded_class_id = 0L,
Expand Down
56 changes: 20 additions & 36 deletions R/caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#' Build a list of train objects suitable for ensembling using the \code{\link{caretStack}}
#' function.
#'
#' @param ... arguments to pass to \code{\link[caret]{train}}. Don't use the formula interface, its slower
#' and buggier compared to the X, y interface. Use a \code{\link[data.table]{data.table}} for X.
#' @param ... arguments to pass to \code{\link[caret]{train}}. Don't use the formula interface, its slower
#' and buggier compared to the X, y interface. Use a \code{\link[data.table]{data.table}} for X.
#' Particularly if you have a large dataset and/or many models, using a data.table will
#' avoid unnecessary copies of your data and can save a lot of time and RAM.
#' These arguments will determine which train method gets dispatched.
Expand All @@ -22,24 +22,14 @@
#' it is dropped from the list.
#' @export
#' @examples
#' \dontrun{
#' myControl <- trainControl(method = "cv", number = 5)
#' caretList(
#' Sepal.Length ~ Sepal.Width,
#' head(iris, 50),
#' methodList = c("glm", "lm"),
#' trControl = myControl
#' )
#' caretList(
#' Sepal.Length ~ Sepal.Width,
#' head(iris, 50),
#' methodList = c("lm"),
#' tuneList = list(
#' nnet = caretModelSpec(method = "nnet", trace = FALSE, tuneLength = 1)
#' ),
#' trControl = myControl
#' )
#' )
#' }
caretList <- function(
...,
trControl = NULL,
Expand Down Expand Up @@ -127,7 +117,7 @@ caretList <- function(
#' @method predict caretList
#' @export
predict.caretList <- function(object, newdata = NULL, verbose = FALSE, excluded_class_id = 1L, ...) {
stopifnot(is.caretList(object))
stopifnot(methods::is(object, "caretList"))

# Decided whether to be verbose or quiet
apply_fun <- lapply
Expand All @@ -152,8 +142,10 @@ predict.caretList <- function(object, newdata = NULL, verbose = FALSE, excluded_
# E.g. you could mix classification and regression models
# caretPredict will aggregate multiple predictions for the same row (e.g. repeated CV)
# caretPredict will make sure the rows are sorted by the original row order
# If you want to ensemble models that were trained on different rows of data, use
# newdata to predict on a common dataset so they can be ensembles.
pred_rows <- vapply(preds, nrow, integer(1L))
stopifnot(pred_rows == pred_rows[1L]) # TODO: informative error message
stopifnot(pred_rows == pred_rows[1L])

# Name the predictions
for (i in seq_along(preds)) {
Expand Down Expand Up @@ -189,14 +181,6 @@ predict.caretList <- function(object, newdata = NULL, verbose = FALSE, excluded_
preds
}

#' @title Check if an object is a caretList object
#' @description Check if an object is a caretList object
#' @param object an R object
#' @export
is.caretList <- function(object) {
methods::is(object, "caretList")
}

#' @title Convert object to caretList object
#' @description Converts object into a caretList
#' @param object R Object
Expand Down Expand Up @@ -260,26 +244,22 @@ as.caretList.list <- function(object) {
#' @return a \code{\link{caretList}} object
#' @export
#' @examples
#' \dontrun{
#' model_list1 <- caretList(Class ~ .,
#' data = Sonar,
#' data(iris)
#' model_list1 <- caretList(Sepal.Width ~ .,
#' data = iris,
#' tuneList = list(
#' glm = caretModelSpec(method = "glm", family = "binomial"),
#' rpart = caretModelSpec(method = "rpart")
#' lm = caretModelSpec(method = "lm")
#' )
#' )
#'
#' model_list2 <- caretList(Class ~ .,
#' data = Sonar,
#' model_list2 <- caretList(Sepal.Width ~ .,
#' data = iris, tuneLength = 1L,
#' tuneList = list(
#' glm = caretModelSpec(method = "rpart"),
#' rpart = caretModelSpec(method = "rf")
#' rf = caretModelSpec(method = "rf")
#' )
#' )
#'
#' bigList <- c(model_list1, model_list2)
#' }
#'
c.caretList <- function(...) {
new_model_list <- unlist(lapply(list(...), function(x) {
if (inherits(x, "caretList")) {
Expand Down Expand Up @@ -418,6 +398,7 @@ methodCheck <- function(x) {
#' @description This function extracts the y variable from a set of arguments headed to a caret::train model.
#' Since there are 2 methods to call caret::train, this function also has 2 methods.
#' @param ... a set of arguments, as in the caret::train function
#' @keywords internal
extractCaretTarget <- function(...) {
UseMethod("extractCaretTarget")
}
Expand All @@ -428,6 +409,7 @@ extractCaretTarget <- function(...) {
#' or other type (e.g. sparse matrix). See Details below.
#' @param y a numeric or factor vector containing the outcome for each sample.
#' @param ... ignored
#' @keywords internal
extractCaretTarget.default <- function(x, y, ...) {
y
}
Expand All @@ -437,6 +419,8 @@ extractCaretTarget.default <- function(x, y, ...) {
#' @param form A formula of the form y ~ x1 + x2 + ...
#' @param data Data frame from which variables specified in formula are preferentially to be taken.
#' @param ... ignored
#' @method extractCaretTarget formula
#' @keywords internal
extractCaretTarget.formula <- function(form, data, ...) {
y <- stats::model.response(stats::model.frame(form, data))
names(y) <- NULL
Expand All @@ -448,8 +432,8 @@ extractCaretTarget.formula <- function(form, data, ...) {
#' @param x a caretList object
#' @param ... passed to extractMetric.train
#' @return A data.table with metrics from each model.
#' @export
#' @method extractMetric caretList
#' @export
extractMetric.caretList <- function(x, ...) {
metrics <- lapply(x, extractMetric.train, ...)
metrics <- data.table::rbindlist(metrics, use.names = TRUE, fill = TRUE)
Expand Down Expand Up @@ -485,7 +469,7 @@ plot.caretList <- function(x, metric = NULL, ...) {
#' @title Summarize a caretList
#' @description This function summarizes the performance of each model in a caretList object.
#' @param object a caretList object
#' @param metric The metric to show. If NULL will use the metric used to train each model
#' @param metric The metric to show. If NULL will use the metric used to train each model
#' @param ... passed to extractMetric
#' @return A data.table with metrics from each model.
#' @method summary caretList
Expand Down
23 changes: 10 additions & 13 deletions R/caretPredict.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim =
}

#' @title Aggregate mean or first
#' @description For numeric data take the mean. For character data take the first value.
#' @description For numeric data take the mean. For character data take the first value.
#' @param x a train object
#' @return a data.table::data.table with predictions
#' @keywords internal
Expand Down Expand Up @@ -239,22 +239,19 @@ dropExcludedClass <- function(x, all_classes, excluded_class_id) {
#' @return a \code{\link{caretList}} object
#' @export
#' @examples
#' \dontrun{
#' rpartTrain <- train(Class ~ .,
#' data = Sonar,
#' trControl = ctrl1,
#' method = "rpart"
#' data(iris)
#' model_lm <- caret::train(Sepal.Length ~ .,
#' data = iris,
#' method = "lm"
#' )
#'
#' rfTrain <- train(Class ~ .,
#' data = Sonar,
#' trControl = ctrl1,
#' method = "rf"
#' model_rf <- caret::train(Sepal.Length ~ .,
#' data = iris,
#' method = "rf",
#' tuneLength = 1L
#' )
#'
#' bigList <- c(model_list1, model_list2)
#' }
#'
#' model_list <- c(model_lm, model_rf)
c.train <- function(...) {
new_model_list <- unlist(lapply(list(...), function(x) {
if (inherits(x, "caretList")) {
Expand Down
Loading

0 comments on commit d59760b

Please sign in to comment.