diff --git a/R/greedyOpt.R b/R/greedyOpt.R index 75cab8ec..1441809d 100644 --- a/R/greedyOpt.R +++ b/R/greedyOpt.R @@ -160,6 +160,7 @@ predict.greedyMSE <- function(object, newdata, return_labels = FALSE, ...) { greedyMSE_caret <- function() { list( label = "Greedy Mean Squared Error Optimizer", + method = "greedyMSE", library = NULL, loop = NULL, type = c("Regression", "Classification"), diff --git a/coverage.rds b/coverage.rds index 247b11a1..32488b6c 100644 Binary files a/coverage.rds and b/coverage.rds differ diff --git a/tests/testthat/test-caretList.R b/tests/testthat/test-caretList.R index 8d2ce95e..54606c7f 100644 --- a/tests/testthat/test-caretList.R +++ b/tests/testthat/test-caretList.R @@ -254,3 +254,44 @@ testthat::test_that("caretList supports combined regression, binary, multiclass" testthat::expect_identical(nrow(stacked_p), nrow(iris)) testthat::expect_identical(nrow(new_p), 10L) }) + +testthat::test_that("caretList supports custom models", { + set.seed(42L) + + # Use the custom greedyMSE model + custom_list <- list( + custom.mse = caretModelSpec(method = greedyMSE_caret(), tuneLength = 1L) + ) + + # Fit it reg/bin/multi (it supports all 3!) + reg_models <- caretList(Sepal.Length ~ Sepal.Width, iris, tuneList = custom_list) + bin_models <- caretList(factor(ifelse(Species == "setosa", "Y", "N")) ~ Sepal.Width, iris, tuneList = custom_list) + multi_models <- caretList(Species ~ Sepal.Width, iris, tuneList = custom_list) + + # Check the fit + all_models <- c(reg_models, bin_models, multi_models) + testthat::expect_s3_class(all_models, "caretList") + + # Check predictions + stacked_p <- predict(all_models) + new_p <- predict(all_models, newdata = iris[1L:10L, ]) + testthat::expect_is(stacked_p, "data.table") + testthat::expect_is(new_p, "data.table") + testthat::expect_identical(nrow(stacked_p), nrow(iris)) + testthat::expect_identical(nrow(new_p), 10L) + + # Check we can stack it + # Note that caretStack with method=greedyMSE_caret() + # is what caretEnsemble does under the hood + ens <- caretStack( + all_models, + method = greedyMSE_caret(), + trControl = trainControl(method = "cv", number = 2L, savePredictions = "final") + ) + stacked_p <- predict(ens) + new_p <- predict(ens, newdata = iris[1L:10L, ]) + testthat::expect_is(stacked_p, "data.table") + testthat::expect_is(new_p, "data.table") + testthat::expect_identical(nrow(stacked_p), nrow(iris)) + testthat::expect_identical(nrow(new_p), 10L) +})