From 669973029f7620c18e94dc4550c3e485e5028df5 Mon Sep 17 00:00:00 2001
From: "Simon P. Couch"
Date: Mon, 9 Sep 2024 11:40:03 -0500
Subject: [PATCH] use type checkers in `aaa_models.R` (#1180)

Replace the hand-rolled `is.character()` / `length()` argument checks in
`R/aaa_models.R` with `check_string()`, `check_bool()`, `check_function()`
and `check_character()`, and thread a `call` argument through the internal
`check_*()` helpers so that errors are reported from the calling function
instead of from the helper (see the updated snapshots).

Review notes:
* `check_interface_val()` defaults `call` to `caller_env()` like every other
  checker; a `call = call` default refers to itself and cannot be evaluated
  when the argument is not supplied.
* Chained errors pass the condition stored on the `try()` result
  (`attr(updated, "condition")`) as `parent`, because the "try-error" value
  itself is a character vector rather than a condition object.

---
 R/aaa_models.R                        | 217 ++++++++++++--------------
 R/predict_classprob.R                 |   2 +-
 tests/testthat/_snaps/registration.md |   2 +-
 tests/testthat/_snaps/svm_linear.md   |   4 +-
 4 files changed, 108 insertions(+), 117 deletions(-)

diff --git a/R/aaa_models.R b/R/aaa_models.R
index 7af9c94b9..1ebe8b4ef 100644
--- a/R/aaa_models.R
+++ b/R/aaa_models.R
@@ -70,6 +70,7 @@ get_model_env <- function() {
 #' @keywords internal
 #' @export
 get_from_env <- function(items) {
+  check_character(items)
   mod_env <- get_model_env()
   rlang::env_get(mod_env, items, default = NULL)
 }
@@ -86,12 +87,7 @@ set_in_env <- function(...) {
 #' @keywords internal
 #' @export
 set_env_val <- function(name, value) {
-  if (length(name) != 1 || !is.character(name)) {
-    cli::cli_abort(
-      "{.arg name} should be a single character value, \\
-      not {.obj_type_friendly {name}}."
-    )
-  }
+  check_string(name, allow_empty = FALSE)
   mod_env <- get_model_env()
   x <- list(value)
   names(x) <- name
@@ -116,61 +112,43 @@ error_set_object <- function(object, func) {
   cli::cli_abort(msg, call = call2(func))
 }
 
-check_eng_val <- function(eng) {
-  if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng)) {
-    cli::cli_abort(
-      "Please supply a character string for an engine name (e.g. {.val lm}), \\
-      not {.obj_type_friendly {eng}}."
-    )
-  }
+check_eng_val <- function(eng, call = caller_env()) {
+  check_string(eng, allow_empty = FALSE, call = call)
   invisible(NULL)
 }
 
-check_model_exists <- function(model) {
-  if (rlang::is_missing(model) || length(model) != 1 || !is.character(model)) {
-    cli::cli_abort(
-      "Please supply a character string for a model name \\
-      (e.g. {.val linear_reg}), not {.obj_type_friendly {model}}."
-    )
-  }
+check_model_exists <- function(model, call = caller_env()) {
+  check_string(model, allow_empty = FALSE, call = call)
 
   current <- get_model_env()
 
   if (!any(current$models == model)) {
     cli::cli_abort(
-      "Model {.val {model}} has not been registered."
+      "Model {.val {model}} has not been registered.",
+      call = call
     )
   }
 
   invisible(NULL)
 }
 
-check_model_doesnt_exist <- function(model) {
-  if (rlang::is_missing(model) || length(model) != 1 || !is.character(model)) {
-    cli::cli_abort(
-      "Please supply a character string for a model name \\
-      (e.g. {.val linear_reg}), not {.obj_type_friendly {model}}."
-    )
-  }
+check_model_doesnt_exist <- function(model, call = caller_env()) {
+  check_string(model, allow_empty = FALSE, call = call)
 
   current <- get_model_env()
 
   if (any(current$models == model)) {
     cli::cli_abort(
-      "Model {.val {model}} already exists."
+      "Model {.val {model}} already exists.",
+      call = call
     )
   }
 
   invisible(NULL)
 }
 
-check_mode_val <- function(mode) {
-  if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode)) {
-    cli::cli_abort(
-      "Please supply a character string for a mode \\
-      (e.g. {.val regression}), not {.obj_type_friendly {mode}}."
-    )
-  }
+check_mode_val <- function(mode, call = caller_env()) {
+  check_string(mode, allow_empty = FALSE, call = call)
   invisible(NULL)
 }
 
@@ -216,11 +194,12 @@ stop_missing_engine <- function(cls, call) {
   )
 }
 
-check_mode_for_new_engine <- function(cls, eng, mode) {
+check_mode_for_new_engine <- function(cls, eng, mode, call = caller_env()) {
   all_modes <- get_from_env(paste0(cls, "_modes"))
   if (!(mode %in% all_modes)) {
     cli::cli_abort(
-      "{.val {mode}} is not a known mode for model {.fn {cls}}."
+      "{.val {mode}} is not a known mode for model {.fn {cls}}.",
+      call = call
     )
   }
   invisible(NULL)
@@ -316,44 +295,25 @@ check_mode_with_no_engine <- function(cls, mode, call) {
   }
 }
 
-check_arg_val <- function(arg) {
-  if (rlang::is_missing(arg) || length(arg) != 1 || !is.character(arg))
-    cli::cli_abort(
-      "Please supply a character string for the argument, \\
-      not {.obj_type_friendly {arg}}."
-    )
-  invisible(NULL)
-}
-
-check_submodels_val <- function(has_submodel, call = caller_env()) {
-  if (!is.logical(has_submodel) || length(has_submodel) != 1) {
-    cli::cli_abort(
-      "The {.arg submodels} argument should be a single logical. \\
-      not {.obj_type_friendly {has_submodel}}",
-      call = call
-    )
-  }
-  invisible(NULL)
-}
-
-check_func_val <- function(func) {
+check_func_val <- function(func, call = caller_env()) {
   msg <- "{.arg func} should be a named vector with element {.field fun} and \\
     the optional elements {.field pkg}, {.field range}, {.field trans}, \\
     and {.field values}. {.field func} and {.field pkg} should both be \\
     single character strings."
 
-  if (rlang::is_missing(func) || !is.vector(func))
-    cli::cli_abort(msg)
+  if (rlang::is_missing(func) || !is.vector(func)) {
+    cli::cli_abort(msg, call = call)
+  }
 
   nms <- sort(names(func))
 
   if (all(is.null(nms))) {
-    cli::cli_abort(msg)
+    cli::cli_abort(msg, call = call)
   }
 
   if (length(func) == 1) {
     if (isTRUE(any(nms != "fun"))) {
-      cli::cli_abort(msg)
+      cli::cli_abort(msg, call = call)
     }
   } else {
     # check for extra names:
@@ -361,23 +321,31 @@ check_func_val <- function(func) {
     nm_check <- nms %in% c("fun", "pkg", "range", "trans", "values")
     not_allowed <- nms[!(nms %in% allow_nms)]
     if (length(not_allowed) > 0) {
-      cli::cli_abort(msg)
+      cli::cli_abort(msg, call = call)
     }
   }
 
-  if (!is.character(func[["fun"]])) {
-    cli::cli_abort(msg)
-  }
-  if (any(nms == "pkg") && !is.character(func[["pkg"]])) {
-    cli::cli_abort(msg)
+  check_string(
+    func[["fun"]],
+    allow_empty = FALSE,
+    arg = I("The `fun` element of `func`"),
+    call = call
+  )
+  if (any(nms == "pkg")) {
+    check_string(
+      func[["pkg"]],
+      allow_null = TRUE,
+      arg = I("The `pkg` element of `func`"),
+      call = call
+    )
   }
 
   invisible(NULL)
 }
 
-check_fit_info <- function(fit_obj) {
+check_fit_info <- function(fit_obj, call = caller_env()) {
   if (is.null(fit_obj)) {
-    cli::cli_abort("The {.arg fit_obj} argument cannot be NULL.")
+    cli::cli_abort("The {.arg fit_obj} argument cannot be NULL.", call = call)
   }
 
   # check required data elements
@@ -386,7 +354,8 @@ check_fit_info <- function(fit_obj) {
 
   if (!all(has_req_nms)) {
     cli::cli_abort(
-      "The {.arg value} argument should have elements: {.field {exp_nms}}."
+      "The {.arg value} argument should have elements: {.field {exp_nms}}.",
+      call = call
     )
   }
 
@@ -397,14 +366,16 @@ check_fit_info <- function(fit_obj) {
   if (any(!has_opt_nms)) {
     cli::cli_abort(
       "The {.arg value} argument can only have optional elements: \\
-      {.field {exp_nms}}."
+      {.field {exp_nms}}.",
+      call = call
     )
   }
   if (any(other_nms == "data")) {
     data_nms <- names(fit_obj$data)
     if (length(data_nms == 0) || any(data_nms == "")) {
       cli::cli_abort(
-        "All elements of the {.field data} argument vector must be named."
+        "All elements of the {.field data} argument vector must be named.",
+        call = call
       )
     }
   }
@@ -413,37 +384,49 @@ check_fit_info <- function(fit_obj) {
   check_func_val(fit_obj$func)
 
   if (!is.list(fit_obj$defaults)) {
-    cli::cli_abort("The {.field defaults} element should be a list.")
+    cli::cli_abort("The {.field defaults} element should be a list.", call = call)
   }
 
   invisible(NULL)
 }
 
-check_pred_info <- function(pred_obj, type) {
+check_pred_info <- function(pred_obj, type, call = caller_env()) {
   if (all(type != pred_types)) {
     cli::cli_abort(
-      "The prediction type should be one of: {.val {pred_types}}."
+      "The prediction type should be one of: {.val {pred_types}}.",
+      call = call
     )
   }
 
   exp_nms <- c("args", "func", "post", "pre")
   if (!isTRUE(all.equal(sort(names(pred_obj)), exp_nms))) {
     cli::cli_abort(
-      "The {.field predict} module should have elements: {.val {exp_nms}}."
+      "The {.field predict} module should have elements: {.val {exp_nms}}.",
+      call = call
     )
   }
 
-  if (!is.null(pred_obj$pre) & !is.function(pred_obj$pre)) {
-    cli::cli_abort("The {.field pre} module should be null or a function.")
-  }
-  if (!is.null(pred_obj$post) & !is.function(pred_obj$post)) {
-    cli::cli_abort("The {.field post} module should be null or a function.")
-  }
+  check_function(
+    pred_obj$pre,
+    allow_null = TRUE,
+    arg = I("The `pre` element of `pred_obj`"),
+    call = call
+  )
+
+  check_function(
+    pred_obj$post,
+    allow_null = TRUE,
+    arg = I("The `post` element of `pred_obj`"),
+    call = call
+  )
 
   check_func_val(pred_obj$func)
 
   if (!is.list(pred_obj$args)) {
-    cli::cli_abort("The {.field args} element should be a list.")
+    cli::cli_abort(
+      "The {.field args} element should be a list.",
+      call = call
+    )
   }
 
   invisible(NULL)
@@ -454,31 +437,24 @@ spec_has_pred_type <- function(object, type) {
   any(possible_preds == type)
 }
 
-check_spec_pred_type <- function(object, type) {
+check_spec_pred_type <- function(object, type, call = caller_env()) {
   if (!spec_has_pred_type(object, type)) {
     possible_preds <- names(object$spec$method$pred)
     cli::cli_abort(
       "No {.val {type}} prediction method available for this model. \\
-      {.arg type} should be one of: {.val {possible_preds}}."
-    )
-  }
-  invisible(NULL)
-}
-
-check_pkg_val <- function(pkg) {
-  if (rlang::is_missing(pkg) || length(pkg) != 1 || !is.character(pkg)) {
-    cli::cli_abort(
-      "Please supply a single character value for the package name."
+      {.arg type} should be one of: {.val {possible_preds}}.",
+      call = call
     )
   }
   invisible(NULL)
 }
 
-check_interface_val <- function(x) {
+check_interface_val <- function(x, call = caller_env()) {
   exp_interf <- c("data.frame", "formula", "matrix")
   if (length(x) != 1 || !(x %in% exp_interf)) {
     cli::cli_abort(
-      "The {.field interface} element should have a single of: {exp_interf}."
+      "The {.field interface} element should have a single of: {exp_interf}.",
+      call = call
     )
   }
   invisible(NULL)
@@ -680,10 +656,10 @@ set_model_engine <- function(model, mode, eng) {
 set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) {
   check_model_exists(model)
   check_eng_val(eng)
-  check_arg_val(parsnip)
-  check_arg_val(original)
+  check_string(parsnip, allow_empty = FALSE)
+  check_string(original, allow_empty = FALSE)
   check_func_val(func)
-  check_submodels_val(has_submodel)
+  check_bool(has_submodel)
 
   old_args <- get_from_env(paste0(model, "_args"))
 
@@ -698,7 +674,10 @@ set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) {
 
   updated <- try(dplyr::bind_rows(old_args, new_arg), silent = TRUE)
   if (inherits(updated, "try-error")) {
-    cli::cli_abort("An error occurred when adding the new argument.")
+    cli::cli_abort(
+      "An error occurred when adding the new argument.",
+      parent = attr(updated, "condition")
+    )
   }
 
   updated <- vctrs::vec_unique(updated)
@@ -716,7 +695,7 @@ set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) {
 set_dependency <- function(model, eng, pkg = "parsnip", mode = NULL) {
   check_model_exists(model)
   check_eng_val(eng)
-  check_pkg_val(pkg)
+  check_string(pkg, allow_empty = FALSE)
 
   model_info <- get_from_env(model)
   pkg_info <- get_from_env(paste0(model, "_pkgs"))
@@ -801,7 +780,8 @@ get_dependency <- function(model) {
 # same model/mode/engine (and prediction type). If it already exists and the
 # new information is different, fail with a message. See issue #653
 is_discordant_info <- function(model, mode, eng, candidate,
-                               pred_type = NULL, component = "fit") {
+                               pred_type = NULL, component = "fit",
+                               call = caller_env()) {
   current <- get_from_env(paste0(model, "_", component))
 
   # For older versions of parsnip before set_encoding()
@@ -831,7 +811,8 @@ is_discordant_info <- function(model, mode, eng, candidate,
     cli::cli_abort(
       "The combination of engine {.var {eng}} and mode {.var {mode}} \\
       {.val {p_type}} already has {component} data for model {.var {model}} \\
-      and the new information being registered is different."
+      and the new information being registered is different.",
+      call = call
     )
   }
 
@@ -840,7 +821,7 @@ is_discordant_info <- function(model, mode, eng, candidate,
 
 # Also check for general registration
 
-check_unregistered <- function(model, mode, eng) {
+check_unregistered <- function(model, mode, eng, call = caller_env()) {
   model_info <- get_from_env(model)
   has_engine <-
     model_info %>%
@@ -849,7 +830,9 @@ check_unregistered <- function(model, mode, eng) {
   if (has_engine != 1) {
     cli::cli_abort(
       "The combination of engine {.var {eng}} and mode {.var {mode}} has not \\
-      been registered for model {.var {model}}.")
+      been registered for model {.var {model}}.",
+      call = call
+    )
   }
   invisible(NULL)
 }
@@ -880,7 +863,10 @@ set_fit <- function(model, mode, eng, value) {
   old_fits <- get_from_env(paste0(model, "_fit"))
   updated <- try(dplyr::bind_rows(old_fits, new_fit), silent = TRUE)
   if (inherits(updated, "try-error")) {
-    cli::cli_abort("An error occurred when adding the new fit module.")
+    cli::cli_abort(
+      "An error occurred when adding the new fit module.",
+      parent = attr(updated, "condition")
+    )
   }
 
   set_env_val(
@@ -931,7 +917,10 @@ set_pred <- function(model, mode, eng, type, value) {
   old_pred <- get_from_env(paste0(model, "_predict"))
   updated <- try(dplyr::bind_rows(old_pred, new_pred), silent = TRUE)
   if (inherits(updated, "try-error")) {
-    cli::cli_abort("An error occurred when adding the new fit module.")
+    cli::cli_abort(
+      "An error occurred when adding the new fit module.",
+      parent = attr(updated, "condition")
+    )
   }
 
   set_env_val(paste0(model, "_predict"), updated)
@@ -1088,9 +1077,9 @@ pred_value_template <- function(pre = NULL, post = NULL, func, ...) {
 
 # ------------------------------------------------------------------------------
 
-check_encodings <- function(x) {
+check_encodings <- function(x, call = caller_env()) {
   if (!is.list(x)) {
-    cli::cli_abort("{.arg values} should be a list.")
+    cli::cli_abort("{.arg values} should be a list.", call = call)
   }
   req_args <- list(predictor_indicators = rlang::na_chr,
                    compute_intercept = rlang::na_lgl,
@@ -1101,14 +1090,16 @@ check_encodings <- function(x) {
   if (length(missing_args) > 0) {
     cli::cli_abort(
       "The values passed to {.fn set_encoding} are missing arguments: \\
-      {.field {missing_args}}."
+      {.field {missing_args}}.",
+      call = call
     )
   }
   extra_args <- setdiff(names(x), names(req_args))
   if (length(extra_args) > 0) {
     cli::cli_abort(
       "The values passed to {.fn set_encoding} had extra arguments: \\
-      {.arg {extra_args}}."
+      {.arg {extra_args}}.",
+      call = call
     )
   }
   invisible(x)
diff --git a/R/predict_classprob.R b/R/predict_classprob.R
index 4fb5fb957..86bea49d3 100644
--- a/R/predict_classprob.R
+++ b/R/predict_classprob.R
@@ -9,7 +9,7 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
     cli::cli_abort("{.fun predict.model_fit()} is for predicting factor outcomes.")
   }
 
-  check_spec_pred_type(object, "prob")
+  check_spec_pred_type(object, "prob", call = caller_env())
   check_spec_levels(object)
 
   if (inherits(object$fit, "try-error")) {
diff --git a/tests/testthat/_snaps/registration.md b/tests/testthat/_snaps/registration.md
index 723bacac8..f4f5bf243 100644
--- a/tests/testthat/_snaps/registration.md
+++ b/tests/testthat/_snaps/registration.md
@@ -3,7 +3,7 @@
     Code
       set_model_engine("sponge", mode = "regression", eng = "gum")
     Condition
-      Error in `check_mode_for_new_engine()`:
+      Error in `set_model_engine()`:
       ! "regression" is not a known mode for model `sponge()`.
 
 # showing model info
diff --git a/tests/testthat/_snaps/svm_linear.md b/tests/testthat/_snaps/svm_linear.md
index d493980dd..cce55900c 100644
--- a/tests/testthat/_snaps/svm_linear.md
+++ b/tests/testthat/_snaps/svm_linear.md
@@ -20,7 +20,7 @@
     Code
       predict(cls_form, hpc_no_m[ind, -5], type = "prob")
     Condition
-      Error in `check_spec_pred_type()`:
+      Error in `predict()`:
       ! No "prob" prediction method available for this model. `type` should be one of: "class" and "raw".
 
 ---
@@ -28,6 +28,6 @@
     Code
       predict(cls_xy_form, hpc_no_m[ind, -5], type = "prob")
     Condition
-      Error in `check_spec_pred_type()`:
+      Error in `predict()`:
       ! No "prob" prediction method available for this model. `type` should be one of: "class" and "raw".
 