Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Suggests:
rmarkdown,
nnet,
kernlab,
mockr,
h2o,
SuperLearner
Config/Needs/website:
Expand Down
57 changes: 57 additions & 0 deletions R/fit_members.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#' @export
fit_members <- function(model_stack, ...) {
check_model_stack(model_stack)
check_for_required_packages(model_stack)

dat <- model_stack[["train"]]

Expand Down Expand Up @@ -221,3 +222,59 @@ check_model_stack <- function(model_stack) {
check_inherits(model_stack, "model_stack")
}
}

# given a model stack, find the packages required to fit members and predict
# on new values, and error if any of them are not loaded
check_for_required_packages <- function(x) {
# for dispatch to required_pkgs.workflow when model
# is loaded in a fresh environment
suppressPackageStartupMessages(requireNamespace("workflows"))

pkgs <-
purrr::map(
x$model_defs,
parsnip::required_pkgs
) %>%
unlist() %>%
unique()

installed <- purrr::map_lgl(
pkgs,
is_installed_
)

if (any(!installed)) {
error_needs_install(pkgs, installed)
}

purrr::map(
pkgs,
~suppressPackageStartupMessages(requireNamespace(.x, quietly = TRUE))
)

invisible(TRUE)
}

# takes in a vector of package names and a logical vector giving
# whether or not each is installed
error_needs_install <- function(pkgs, installed) {
plural <- sum(!installed) != 1

last_sep <- if (sum(!installed) == 2) {"` and `"} else {"`, and `"}

need_install <- paste0(
"`",
glue::glue_collapse(pkgs[!installed], sep = "`, `", last = last_sep),
"`"
)

glue_stop(
"The following package{if (plural) 's' else ''} ",
"need{if (plural) '' else 's'} to be installed before ",
"fitting members: {need_install}"
)
}

is_installed_ <- function(pkg) {
rlang::is_installed(pkg)
}
16 changes: 16 additions & 0 deletions tests/testthat/_snaps/fit_members.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# fit_members checks for required packages

The following package needs to be installed before fitting members: `a`

---

The following packages need to be installed before fitting members: `a` and `b`

---

The following packages need to be installed before fitting members: `a`, `b`, and `c`

---

The following packages need to be installed before fitting members: `recipes`, `parsnip`, and `kernlab`

34 changes: 33 additions & 1 deletion tests/testthat/test_fit_members.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ library(kernlab)
skip_if_not_installed("nnet")
library(nnet)


test_that("basic fit_members works", {
skip_on_cran()

Expand Down Expand Up @@ -121,3 +120,36 @@ test_that("fit_members errors informatively with a bad model_stack arg", {
"`model_stack` have already been fitted and need not"
)
})

test_that("fit_members checks for required packages", {
skip_on_cran()

skip_if_not_installed("mockr")

library(mockr)

# check pluralization of error
expect_snapshot_error(error_needs_install(letters[1], rep(FALSE, 1)))
expect_snapshot_error(error_needs_install(letters[1:2], rep(FALSE, 2)))
expect_snapshot_error(error_needs_install(letters[1:3], rep(FALSE, 3)))

# loads dependency when it's installed but not loaded
unloadNamespace("kernlab")

expect_s3_class(
st_reg_1_ %>%
fit_members(),
"model_stack"
)

expect_true(isNamespaceLoaded("kernlab"))

# errors informatively when it's not installed
mockr::with_mock(
is_installed_ = function(x) {FALSE},
{expect_snapshot_error(
st_reg_1_ %>%
fit_members()
)}
)
})