Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Version: 1.2.1.9000
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
person(given = "Posit Software, PBC", role = c("cph", "fnd"))
person("Posit Software, PBC", role = c("cph", "fnd"))
)
Description: The ability to tune models is important. 'tune' contains
functions and classes to be used in conjunction with other
Expand All @@ -27,12 +27,12 @@ Imports:
ggplot2,
glue (>= 1.6.2),
GPfit,
hardhat (>= 1.2.0),
hardhat (>= 1.4.0.9002),
lifecycle (>= 1.0.0),
parsnip (>= 1.2.0),
parsnip (>= 1.2.1.9003),
purrr (>= 1.0.0),
recipes (>= 1.0.4),
rlang (>= 1.1.0),
recipes (>= 1.1.0.9001),
rlang (>= 1.1.4),
rsample (>= 1.2.1.9000),
tailor,
tibble (>= 3.1.0),
Expand All @@ -57,8 +57,11 @@ Suggests:
xgboost,
xml2
Remotes:
tidymodels/hardhat,
tidymodels/parsnip,
tidymodels/recipes,
tidymodels/rsample,
tidymodels/tailor,
tidymodels/tailor,
tidymodels/workflows
Config/Needs/website: pkgdown, tidymodels, kknn, doParallel, doFuture,
tidyverse/tidytemplate
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ export(tune_bayes)
export(tune_grid)
export(val_class_and_single)
export(val_class_or_null)
import(rlang)
import(vctrs)
import(workflows)
importFrom(GPfit,GP_fit)
Expand Down
2 changes: 1 addition & 1 deletion R/0_imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#' @importFrom cli cli_inform cli_warn cli_abort qty
#' @importFrom foreach foreach getDoParName %dopar%
#' @importFrom tibble obj_sum size_sum

#' @import rlang

# ------------------------------------------------------------------------------
# Only a small number of functions in workflows.
Expand Down
10 changes: 5 additions & 5 deletions R/acquisition.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ print.prob_improve <- function(x, ...) {
#' @export
predict.prob_improve <-
function(object, new_data, maximize, iter, best, ...) {
check_direction(maximize)
check_best(best)
check_bool(maximize)
check_number_decimal(best, allow_infinite = FALSE)

if (is.function(object$trade_off)) {
trade_off <- object$trade_off(iter)
Expand Down Expand Up @@ -126,8 +126,8 @@ exp_improve <- function(trade_off = 0, eps = .Machine$double.eps) {

#' @export
predict.exp_improve <- function(object, new_data, maximize, iter, best, ...) {
check_direction(maximize)
check_best(best)
check_bool(maximize)
check_number_decimal(best, allow_infinite = FALSE)

if (is.function(object$trade_off)) {
trade_off <- object$trade_off(iter)
Expand Down Expand Up @@ -177,7 +177,7 @@ conf_bound <- function(kappa = 0.1) {

#' @export
predict.conf_bound <- function(object, new_data, maximize, iter, ...) {
check_direction(maximize)
check_bool(maximize)

if (is.function(object$kappa)) {
kappa <- object$kappa(iter)
Expand Down
31 changes: 10 additions & 21 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -493,26 +493,6 @@ get_objective_name <- function(x, metrics) {
x
}


# ------------------------------------------------------------------------------
# acq functions

check_direction <- function(x) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DELETE

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SO SATISFYING

if (!is.logical(x) || length(x) != 1) {
rlang::abort("`maximize` should be a single logical.")
}
invisible(NULL)
}


check_best <- function(x) {
if (!is.numeric(x) || length(x) != 1 || is.na(x)) {
rlang::abort("`best` should be a single, non-missing numeric.")
}
invisible(NULL)
}


# ------------------------------------------------------------------------------

check_class_or_null <- function(x, cls = "numeric") {
Expand All @@ -537,6 +517,7 @@ val_class_or_null <- function(x, cls = "numeric", where = NULL) {
}
invisible(NULL)
}
# TODO remove this once finetune is updated
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to myself and @simonpcouch to make similar changes to finetune so that we can remove these three checking functions


check_class_and_single <- function(x, cls = "numeric") {
isTRUE(inherits(x, cls) & length(x) == 1)
Expand All @@ -558,7 +539,7 @@ val_class_and_single <- function(x, cls = "numeric", where = NULL) {
}
invisible(NULL)
}

# TODO remove this once finetune is updated

# Check the data going into the GP. If there are all missing values, fail. If some
# are missing, remove them and send a warning. If all metrics are the same, fail.
Expand Down Expand Up @@ -644,3 +625,11 @@ check_eval_time <- function(eval_time, metrics) {
invisible(NULL)

}

check_time_limit_arg <- function(x, call = rlang::caller_env()) {
if (!inherits(x, c("logical", "numeric")) || length(x) != 1L) {
cli::cli_abort("{.arg time_limit} should be either a single numeric or
logical value.", call = call)
}
invisible(NULL)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL that this control option existed😆

5 changes: 1 addition & 4 deletions R/compute_metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ compute_metrics.tune_results <- function(x,
summarize = TRUE,
event_level = "first") {
rlang::check_dots_empty()
check_bool(summarize)
if (!".predictions" %in% names(x)) {
rlang::abort(paste0(
"`x` must have been generated with the ",
Expand Down Expand Up @@ -114,10 +115,6 @@ compute_metrics.tune_results <- function(x,
))
}

if (!inherits(summarize, "logical") || length(summarize) != 1L) {
rlang::abort("The `summarize` argument must be a single logical value.")
}

param_names <- .get_tune_parameter_names(x)
outcome_name <- .get_tune_outcome_names(x)

Expand Down
55 changes: 27 additions & 28 deletions R/control.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ control_grid <- function(verbose = FALSE, allow_par = TRUE,
# Any added arguments should also be added in superset control functions
# in other packages

# add options for seeds per resample
# add options for seeds per resample
check_bool(verbose)
check_bool(allow_par)
check_bool(save_pred)
check_bool(save_workflow)
check_string(event_level)
check_character(pkgs, allow_null = TRUE)
check_function(extract, allow_null = TRUE)

val_class_and_single(verbose, "logical", "control_grid()")
val_class_and_single(allow_par, "logical", "control_grid()")
val_class_and_single(save_pred, "logical", "control_grid()")
val_class_and_single(save_workflow, "logical", "control_grid()")
val_class_and_single(event_level, "character", "control_grid()")
val_class_or_null(pkgs, "character", "control_grid()")
val_class_or_null(extract, "function", "control_grid()")
val_parallel_over(parallel_over, "control_grid()")


Expand Down Expand Up @@ -241,26 +241,27 @@ control_bayes <-
# in other packages

# add options for seeds per resample
check_bool(verbose)
check_bool(verbose_iter)
check_bool(allow_par)
check_bool(save_pred)
check_bool(save_workflow)
check_bool(save_gp_scoring)
check_character(pkgs, allow_null = TRUE)
check_function(extract, allow_null = TRUE)
check_number_whole(no_improve, min = 0, allow_infinite = TRUE)
check_number_whole(uncertain, min = 0, allow_infinite = TRUE)
check_number_whole(seed)

check_time_limit_arg(time_limit)

val_class_and_single(verbose, "logical", "control_bayes()")
val_class_and_single(verbose_iter, "logical", "control_bayes()")
val_class_and_single(save_pred, "logical", "control_bayes()")
val_class_and_single(save_gp_scoring, "logical", "control_bayes()")
val_class_and_single(save_workflow, "logical", "control_bayes()")
val_class_and_single(no_improve, c("numeric", "integer"), "control_bayes()")
val_class_and_single(uncertain, c("numeric", "integer"), "control_bayes()")
val_class_and_single(seed, c("numeric", "integer"), "control_bayes()")
val_class_or_null(extract, "function", "control_bayes()")
val_class_and_single(time_limit, c("logical", "numeric"), "control_bayes()")
val_class_or_null(pkgs, "character", "control_bayes()")
val_class_and_single(event_level, "character", "control_bayes()")
val_parallel_over(parallel_over, "control_bayes()")
val_class_and_single(allow_par, "logical", "control_bayes()")


if (!is.infinite(uncertain) && uncertain > no_improve) {
cli::cli_alert_warning(
"Uncertainty sample scheduled after {uncertain} poor iterations but the search will stop after {no_improve}."
cli::cli_warn(
"Uncertainty sample scheduled after {uncertain} poor iterations but the
search will stop after {no_improve}."
)
}

Expand Down Expand Up @@ -296,13 +297,11 @@ print.control_bayes <- function(x, ...) {
# ------------------------------------------------------------------------------

val_parallel_over <- function(parallel_over, where) {
if (is.null(parallel_over)) {
return(invisible(NULL))
check_string(parallel_over, allow_null = TRUE)
if (!is.null(parallel_over)) {
rlang::arg_match0(parallel_over, c("resamples", "everything"), "parallel_over")
}

val_class_and_single(parallel_over, "character", where)
rlang::arg_match0(parallel_over, c("resamples", "everything"), "parallel_over")

invisible(NULL)
}

Expand Down
4 changes: 1 addition & 3 deletions R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,7 @@ extract_spec_parsnip.tune_results <- function(x, ...) {
#' @rdname extract-tune
extract_recipe.tune_results <- function(x, ..., estimated = TRUE) {
check_empty_dots(...)
if (!rlang::is_bool(estimated)) {
rlang::abort("`estimated` must be a single `TRUE` or `FALSE`.")
}
check_bool(estimated)
extract_recipe(extract_workflow(x), estimated = estimated)
}
check_empty_dots <- function(...) {
Expand Down
Loading