diff --git a/DESCRIPTION b/DESCRIPTION index 4a51e39b..09935786 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -53,6 +53,7 @@ Suggests: rmarkdown, nnet, kernlab, + mockr, h2o, SuperLearner Config/Needs/website: diff --git a/R/fit_members.R b/R/fit_members.R index fc45f4ea..bfe8c846 100644 --- a/R/fit_members.R +++ b/R/fit_members.R @@ -69,6 +69,7 @@ #' @export fit_members <- function(model_stack, ...) { check_model_stack(model_stack) + check_for_required_packages(model_stack) dat <- model_stack[["train"]] @@ -221,3 +222,59 @@ check_model_stack <- function(model_stack) { check_inherits(model_stack, "model_stack") } } + +# given a model stack, find the packages required to fit members and predict +# on new values, and error if any of them are not loaded +check_for_required_packages <- function(x) { + # for dispatch to required_pkgs.workflow when model + # is loaded in a fresh environment + suppressPackageStartupMessages(requireNamespace("workflows")) + + pkgs <- + purrr::map( + x$model_defs, + parsnip::required_pkgs + ) %>% + unlist() %>% + unique() + + installed <- purrr::map_lgl( + pkgs, + is_installed_ + ) + + if (any(!installed)) { + error_needs_install(pkgs, installed) + } + + purrr::map( + pkgs, + ~suppressPackageStartupMessages(requireNamespace(.x, quietly = TRUE)) + ) + + invisible(TRUE) +} + +# takes in a vector of package names and a logical vector giving +# whether or not each is installed +error_needs_install <- function(pkgs, installed) { + plural <- sum(!installed) != 1 + + last_sep <- if (sum(!installed) == 2) {"` and `"} else {"`, and `"} + + need_install <- paste0( + "`", + glue::glue_collapse(pkgs[!installed], sep = "`, `", last = last_sep), + "`" + ) + + glue_stop( + "The following package{if (plural) 's' else ''} ", + "need{if (plural) '' else 's'} to be installed before ", + "fitting members: {need_install}" + ) +} + +is_installed_ <- function(pkg) { + rlang::is_installed(pkg) +} diff --git a/tests/testthat/_snaps/fit_members.md b/tests/testthat/_snaps/fit_members.md new file mode 100644 index 00000000..e94dd2b9 --- /dev/null +++ b/tests/testthat/_snaps/fit_members.md @@ -0,0 +1,16 @@ +# fit_members checks for required packages + + The following package needs to be installed before fitting members: `a` + +--- + + The following packages need to be installed before fitting members: `a` and `b` + +--- + + The following packages need to be installed before fitting members: `a`, `b`, and `c` + +--- + + The following packages need to be installed before fitting members: `recipes`, `parsnip`, and `kernlab` + diff --git a/tests/testthat/test_fit_members.R b/tests/testthat/test_fit_members.R index e2761e44..74c9536d 100644 --- a/tests/testthat/test_fit_members.R +++ b/tests/testthat/test_fit_members.R @@ -18,7 +18,6 @@ library(kernlab) skip_if_not_installed("nnet") library(nnet) - test_that("basic fit_members works", { skip_on_cran() @@ -121,3 +120,36 @@ test_that("fit_members errors informatively with a bad model_stack arg", { "`model_stack` have already been fitted and need not" ) }) + +test_that("fit_members checks for required packages", { + skip_on_cran() + + skip_if_not_installed("mockr") + + library(mockr) + + # check pluralization of error + expect_snapshot_error(error_needs_install(letters[1], rep(FALSE, 1))) + expect_snapshot_error(error_needs_install(letters[1:2], rep(FALSE, 2))) + expect_snapshot_error(error_needs_install(letters[1:3], rep(FALSE, 3))) + + # loads dependency when it's installed but not loaded + unloadNamespace("kernlab") + + expect_s3_class( + st_reg_1_ %>% + fit_members(), + "model_stack" + ) + + expect_true(isNamespaceLoaded("kernlab")) + + # errors informatively when it's not installed + mockr::with_mock( + is_installed_ = function(x) {FALSE}, + {expect_snapshot_error( + st_reg_1_ %>% + fit_members() + )} + ) +})