tunable.model_spec() includes rows for non-tunable arguments marked with tune() #1104

Open
simonpcouch opened this issue Apr 4, 2024 · 1 comment · May be fixed by #1105

@simonpcouch
Contributor

library(tidymodels)

tunable(
  linear_reg() %>% 
  set_engine("glmnet", dfmax = tune())
)
#> # A tibble: 3 × 5
#>   name    call_info        source     component  component_id
#>   <chr>   <list>           <chr>      <chr>      <chr>       
#> 1 penalty <named list [2]> model_spec linear_reg main        
#> 2 mixture <named list [3]> model_spec linear_reg main        
#> 3 dfmax   <NULL>           model_spec linear_reg engine

Created on 2024-04-04 with reprex v2.1.0

@simonpcouch
Contributor Author

Okay, a few more notes here.

It seems like the architecture here is:

  1. tunable() gets called on a boost_tree() model spec (or other model type)
  2. tunable.boost_tree() kicks in
    2a) it first calls tunable.model_spec(),

parsnip/R/tunable.R

Lines 279 to 280 in eb526fa

tunable.boost_tree <- function(x, ...) {
  res <- NextMethod()

...which only excludes rows that are both main arguments and have no tunable information associated, so any engine argument without tunable information is kept with a NULL call_info

parsnip/R/tunable.R

Lines 43 to 46 in eb526fa

has_info <- purrr::map_lgl(res$call_info, is.null)
rm_list <- !(has_info & (res$component_id == "main"))
res <- res[rm_list, ]
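To see why the engine row survives, here is a minimal sketch that applies the same filter to a hand-built tibble mimicking the reprex above. The toy `res` and its `call_info` contents are assumptions for illustration, not parsnip's real intermediate values, and `vapply()` stands in for `purrr::map_lgl()` to keep the example dependency-light. Note that, despite its name, `has_info` is `TRUE` for rows whose `call_info` is `NULL`.

```r
library(tibble)

# Toy stand-in for the intermediate `res` in tunable.model_spec();
# the call_info contents are placeholders, not parsnip's real values.
res <- tibble(
  name = c("penalty", "mixture", "dfmax"),
  call_info = list(
    list(pkg = "dials", fun = "penalty"),
    list(pkg = "dials", fun = "mixture"),
    NULL
  ),
  component_id = c("main", "main", "engine")
)

# Same logic as the quoted lines: `has_info` is TRUE when call_info is NULL.
has_info <- vapply(res$call_info, is.null, logical(1))
rm_list <- !(has_info & (res$component_id == "main"))
kept <- res[rm_list, ]

# The engine row with NULL call_info is not dropped, because the filter
# only removes rows that are also main arguments.
kept$name
```

Only a row that is both `"main"` and `NULL` would be removed here, so `dfmax` passes through untouched.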

    2b) then, it adds engine-specific parameter information. One way is via add_engine_parameters():

res <- add_engine_parameters(res, xgboost_engine_args)

parsnip/R/tunable.R

Lines 56 to 65 in eb526fa

add_engine_parameters <- function(pset, engines) {
  is_engine_param <- pset$name %in% engines$name
  if (any(is_engine_param)) {
    engine_names <- pset$name[is_engine_param]
    pset <- pset[!is_engine_param,]
    pset <-
      dplyr::bind_rows(pset, engines %>% dplyr::filter(name %in% engines$name))
  }
  pset
}

xgboost_engine_args is an inlined tibble in the source:

parsnip/R/tunable.R

Lines 91 to 106 in eb526fa

xgboost_engine_args <-
  tibble::tibble(
    name = c(
      "alpha",
      "lambda",
      "scale_pos_weight"
    ),
    call_info = list(
      list(pkg = "dials", fun = "penalty_L1"),
      list(pkg = "dials", fun = "penalty_L2"),
      list(pkg = "dials", fun = "scale_pos_weight")
    ),
    source = "model_spec",
    component = "boost_tree",
    component_id = "engine"
  )

...the other possible way is to manually insert call_info entries:

parsnip/R/tunable.R

Lines 283 to 286 in eb526fa

res$call_info[res$name == "sample_size"] <-
  list(list(pkg = "dials", fun = "sample_prop"))
res$call_info[res$name == "learn_rate"] <-
  list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/2)))
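The double `list(list(...))` in that assignment is easy to misread, so here is a small sketch of the same pattern on a toy tibble (the two-row `res` is an assumption for illustration). The outer list supplies one replacement element per selected row; the inner list is the value stored in that row's `call_info` cell.

```r
library(tibble)

# Toy parameter table with empty call_info cells.
res <- tibble(
  name = c("sample_size", "learn_rate"),
  call_info = list(NULL, NULL)
)

# Assign into the list-column by logical index. The outer list() has
# length 1 (one selected row); the inner list() is the stored value.
res$call_info[res$name == "sample_size"] <-
  list(list(pkg = "dials", fun = "sample_prop"))

# The matching row now carries the dials reference; the other is untouched.
res$call_info[[1]]$fun
is.null(res$call_info[[2]])
```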

Barring a rewrite that stores tunable information in the model environment (see #826), I think our solution might be to filter out rows with NULL call_info at the end of existing non-model_spec tunable() methods.
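A rough sketch of what that final filtering step could look like is below. `drop_null_call_info()` is a hypothetical helper name, not an existing parsnip function, and the toy `res` (including the untunable `nthread` engine argument) is made up for illustration.

```r
library(tibble)

# Hypothetical helper (not part of parsnip): drop every row without
# tunable information, whether it is a main or an engine argument.
drop_null_call_info <- function(res) {
  has_no_info <- vapply(res$call_info, is.null, logical(1))
  res[!has_no_info, ]
}

# Toy result as it might look at the end of a tunable.boost_tree()-style
# method, after engine-specific call_info has been filled in.
res <- tibble(
  name = c("trees", "alpha", "nthread"),
  call_info = list(
    list(pkg = "dials", fun = "trees"),
    list(pkg = "dials", fun = "penalty_L1"),
    NULL
  ),
  source = "model_spec",
  component = "boost_tree",
  component_id = c("main", "engine", "engine")
)

filtered <- drop_null_call_info(res)
filtered$name
```

Running the filter last, rather than inside tunable.model_spec(), would give each model-specific method the chance to attach engine information before rows are discarded.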
