
surv.aorsf: mtry_ratio can result in mtry = 1, which causes an error #259

Closed

jemus42 opened this issue Dec 7, 2022 · 11 comments

jemus42 (Member) commented Dec 7, 2022

Description

mtry_ratio can take valid values that, depending on the task, result in invalid values for mtry.

This is maybe less a bug than undesired behavior, and I'm not sure whether it can/should be addressed in the learner or left as an exercise to the user, so to speak.

Since mtry_ratio effectively has an implicit lower bound greater than 0 (which depends on the number of features), this can result in unintended behavior in benchmark scenarios where feature counts vary widely across tasks but mtry_ratio is tuned within [0, 1].

I'm not sure if there's anything that could be done learner-wise without causing other issues though :/

That being said, I guess at least the mtry lower bound should be increased, as it's currently set to 1:

mtry = p_int(default = NULL, lower = 1L,

Reproducible example

library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

testlrn = lrn("surv.aorsf", n_tree = 50, control_type = "fast", mtry_ratio = 0.1)

# task "whas" has 9 features, mtry_ratio = 0.1 leads to mtry = 1
testlrn$train(tsk("whas"))
#> Error: mtry = 1 should be >= 2

Created on 2022-12-07 with reprex v2.0.2
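For reference, this is just the mtry_ratio-to-mtry conversion used by the learner, mtry = max(ceiling(mtry_ratio * n_features), 1) (quoted further down), running into aorsf's minimum of 2 (illustrative arithmetic only):

# "whas" has 9 features, so mtry_ratio = 0.1 maps to
max(ceiling(0.1 * 9), 1)
#> [1] 1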

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.2 (2022-10-31)
#>  os       macOS Ventura 13.0.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/Berlin
#>  date     2022-12-07
#>  pandoc   2.19.2 @ /System/Volumes/Data/Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package           * version     date (UTC) lib source
#>  aorsf               0.0.4       2022-11-07 [1] CRAN (R 4.2.1)
#>  assertthat          0.2.1       2019-03-21 [1] CRAN (R 4.2.0)
#>  backports           1.4.1       2021-12-13 [1] CRAN (R 4.2.0)
#>  checkmate           2.1.0       2022-04-21 [1] CRAN (R 4.2.0)
#>  cli                 3.4.1       2022-09-23 [1] CRAN (R 4.2.0)
#>  codetools           0.2-18      2020-11-04 [1] CRAN (R 4.2.0)
#>  collapse            1.8.9       2022-10-07 [1] CRAN (R 4.2.0)
#>  colorspace          2.0-3       2022-02-21 [1] CRAN (R 4.2.0)
#>  crayon              1.5.2       2022-09-29 [1] CRAN (R 4.2.0)
#>  data.table          1.14.6      2022-11-16 [1] CRAN (R 4.2.1)
#>  DBI                 1.1.3       2022-06-18 [1] CRAN (R 4.2.0)
#>  dictionar6          0.1.3       2021-09-13 [1] CRAN (R 4.2.0)
#>  digest              0.6.30      2022-10-18 [1] CRAN (R 4.2.0)
#>  distr6              1.6.11      2022-09-07 [1] Github (alan-turing-institute/distr6@8cae431)
#>  dplyr               1.0.10      2022-09-01 [1] CRAN (R 4.2.0)
#>  evaluate            0.18        2022-11-07 [1] CRAN (R 4.2.1)
#>  fansi               1.0.3       2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap             1.1.0       2021-01-25 [1] CRAN (R 4.2.0)
#>  fs                  1.5.2       2021-12-08 [1] CRAN (R 4.2.0)
#>  future              1.29.0      2022-11-06 [1] CRAN (R 4.2.1)
#>  generics            0.1.3       2022-07-05 [1] CRAN (R 4.2.0)
#>  ggplot2             3.4.0       2022-11-04 [1] CRAN (R 4.2.1)
#>  globals             0.16.2      2022-11-21 [1] CRAN (R 4.2.2)
#>  glue                1.6.2       2022-02-24 [1] CRAN (R 4.2.0)
#>  gtable              0.3.1       2022-09-01 [1] CRAN (R 4.2.0)
#>  highr               0.9         2021-04-16 [1] CRAN (R 4.2.0)
#>  htmltools           0.5.3       2022-07-18 [1] CRAN (R 4.2.1)
#>  knitr               1.41        2022-11-18 [1] CRAN (R 4.2.1)
#>  lattice             0.20-45     2021-09-22 [1] CRAN (R 4.2.0)
#>  lgr                 0.4.4       2022-09-05 [1] CRAN (R 4.2.1)
#>  lifecycle           1.0.3       2022-10-07 [1] CRAN (R 4.2.1)
#>  listenv             0.8.0       2019-12-05 [1] CRAN (R 4.2.0)
#>  magrittr            2.0.3       2022-03-30 [1] CRAN (R 4.2.0)
#>  Matrix              1.5-3       2022-11-11 [1] CRAN (R 4.2.0)
#>  mlr3              * 0.14.1      2022-12-01 [1] Github (mlr-org/mlr3@6cd6ee3)
#>  mlr3extralearners * 0.6.0-9000  2022-12-03 [1] Github (mlr-org/mlr3extralearners@5c7780e)
#>  mlr3misc            0.11.0      2022-09-22 [1] CRAN (R 4.2.0)
#>  mlr3pipelines       0.4.2-9000  2022-11-23 [1] Github (mlr-org/mlr3pipelines@5d8d8ab)
#>  mlr3proba         * 0.4.15      2022-12-07 [1] Github (mlr-org/mlr3proba@e208ec0)
#>  mlr3viz             0.5.10      2022-08-15 [1] CRAN (R 4.2.0)
#>  munsell             0.5.0       2018-06-12 [1] CRAN (R 4.2.0)
#>  ooplah              0.2.0       2022-01-21 [1] CRAN (R 4.2.0)
#>  palmerpenguins      0.1.1       2022-08-15 [1] CRAN (R 4.2.0)
#>  paradox             0.10.0.9000 2022-11-20 [1] Github (mlr-org/paradox@3be7ba6)
#>  parallelly          1.32.1      2022-07-21 [1] CRAN (R 4.2.0)
#>  param6              0.2.4       2022-09-07 [1] Github (xoopR/param6@0fa3577)
#>  pillar              1.8.1       2022-08-19 [1] CRAN (R 4.2.0)
#>  pkgconfig           2.0.3       2019-09-22 [1] CRAN (R 4.2.0)
#>  pracma              2.4.2       2022-09-22 [1] CRAN (R 4.2.0)
#>  purrr               0.3.5       2022-10-06 [1] CRAN (R 4.2.1)
#>  R.cache             0.16.0      2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3         1.8.2       2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo                1.25.0      2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils             2.12.2      2022-11-11 [1] CRAN (R 4.2.1)
#>  R6                  2.5.1       2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp                1.0.9       2022-07-08 [1] CRAN (R 4.2.1)
#>  reprex              2.0.2       2022-08-17 [1] CRAN (R 4.2.0)
#>  rlang               1.0.6       2022-09-24 [1] CRAN (R 4.2.0)
#>  rmarkdown           2.18        2022-11-09 [1] CRAN (R 4.2.1)
#>  rstudioapi          0.14        2022-08-22 [1] CRAN (R 4.2.1)
#>  scales              1.2.1       2022-08-20 [1] CRAN (R 4.2.0)
#>  sessioninfo         1.2.2       2021-12-06 [1] CRAN (R 4.2.0)
#>  set6                0.2.5       2022-09-07 [1] Github (xoopR/set6@e65ffee)
#>  stringi             1.7.8       2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr             1.5.0       2022-12-02 [1] CRAN (R 4.2.0)
#>  styler              1.8.1.9000  2022-12-03 [1] Github (r-lib/styler@d137eb6)
#>  survival            3.4-0       2022-08-09 [1] CRAN (R 4.2.0)
#>  tibble              3.1.8       2022-07-22 [1] CRAN (R 4.2.0)
#>  tidyselect          1.2.0       2022-10-10 [1] CRAN (R 4.2.0)
#>  utf8                1.2.2       2021-07-24 [1] CRAN (R 4.2.0)
#>  uuid                1.1-0       2022-04-19 [1] CRAN (R 4.2.0)
#>  vctrs               0.5.1       2022-11-16 [1] CRAN (R 4.2.1)
#>  withr               2.5.0       2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun                0.35        2022-11-16 [1] CRAN (R 4.2.1)
#>  yaml                2.3.6       2022-10-18 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Users/Lukas/Library/R/arm64/4.2/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────
sebffischer (Member) commented:

So I suggest that we:

  1. Replace the value 1 in mtry = max(ceiling(mtry_ratio * n_features), 1) with 2 (see the sketch below)
  2. Increase the lower bound as you suggested
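A minimal sketch of what suggestion 1 would look like (illustrative only, not the actual learner code):

n_features = 9                                     # e.g. the "whas" task from the reprex
mtry_ratio = 0.1
mtry = max(ceiling(mtry_ratio * n_features), 2)    # floor of 2 instead of 1
mtry
#> [1] 2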

@bcjaeger Do you think that is a good idea?

jemus42 (Member, Author) commented Dec 8, 2022

I'm a little worried about tuning though.
Let's say you do a random search over mtry_ratio in [0, 1]: uniformly sampling in that range will then lead to a "clump" at mtry = 2 for low values of mtry_ratio (depending on n_features), and I'm not sure whether that's problematic or acceptable.
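To illustrate the clump (assuming the floor of 2 proposed above and a 9-feature task like "whas"):

set.seed(42)
ratios = runif(1e4)                     # uniform random search over [0, 1]
mtry = pmax(ceiling(ratios * 9), 2)     # conversion with the proposed floor of 2
mean(mtry == 2)                         # roughly 2/9 (~22%) of the draws collapse onto mtry = 2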

bcjaeger (Contributor) commented Dec 8, 2022

Hi! Thanks for noticing this,

@bcjaeger Do you think that is a good idea?

I do, both 1 and 2 make sense. I could remove this error from orsf() if that would be more helpful. It is possible to fit oblique random forests with one predictor variable; it's just not really 'oblique'.

Let's say you do a random search over mtry_ratio in [0, 1], uniformly sampling in that range will then lead to a "clump" at mtry = 2 for low values of mtry_ratio (depending on n_features), and I'm not sure whether that's problematic or acceptable.

This is a good point. For orsf learners, could we restrict the mtry_ratio search to cover [alpha, 1], where alpha is the minimally acceptable mtry_ratio? (the value of mtry_ratio that would give mtry = 2.)
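For the 9-feature example above, that alpha would be (a sketch, assuming the conversion quoted earlier):

n_features = 9
alpha = 2 / n_features                 # smallest ratio guaranteed to give mtry = 2
max(ceiling(alpha * n_features), 1)
#> [1] 2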

jemus42 (Member, Author) commented Dec 8, 2022

For orsf learners, could we restrict the mtry_ratio search to cover [alpha, 1], where alpha is the minimally acceptable mtry_ratio?

I don't know if/how we can, since alpha would depend on n_features, which I don't think can be used in a search space / ParamSet for tuning. We can specify e.g. mtry_ratio = p_dbl(0, 1), which I guess is a thing in the first place because tuning mtry directly would require knowledge of n_features to find a meaningful upper bound.
And in a benchmark where n_features varies, that's a little icky :/

Allowing mtry = 1 would be "easy", I agree, but you are of course right that at this point it would be something like a degenerate case - kind of like running a random forest but setting n_tree = 1 🥴
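For a single task one could derive the bound manually (a sketch, assuming to_tune() from paradox; it doesn't help the benchmark case, since alpha still has to be computed from each task's feature count):

library(paradox)   # for to_tune()

task = tsk("whas")
alpha = 2 / task$n_features                                  # task-specific lower bound for mtry_ratio
testlrn = lrn("surv.aorsf", mtry_ratio = to_tune(alpha, 1))  # tune over [alpha, 1]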

bcjaeger (Contributor) commented Dec 8, 2022

That definitely makes sense.

I will check out aorsf and see if allowing mtry = 1 causes any unexpected failures in the C++ routines. If it doesn't, I think I can change the error for 1 predictor to be a warning that indicates oblique splitting doesn't really apply if only one predictor is supplied.

If aorsf gives a warning instead of an error in this scenario, would it resolve this issue?

sebffischer (Member) commented Dec 8, 2022

I think a warning would be better and would solve the issue.

bcjaeger (Contributor) commented Dec 8, 2022

It looks like allowing mtry = 1 is going to work out for aorsf. I've pushed this change to the main branch in ropensci/aorsf and will submit to CRAN next week.

As far as changes in mlr3 go, I would still be in support of @sebffischer's solution if you would like to avoid seeing a warning message about 1 predictor being used in orsf().

sebffischer (Member) commented:

Unfortunately I don't really know the algorithm. Is it possible that there are situations where one predictor might be better than two? In that case I would even argue that the warning should be omitted and that users should know what they are doing when they set parameters. I would argue the warning belongs in the documentation, not in the code.

But I think it is your call @bcjaeger

bcjaeger (Contributor) commented Dec 9, 2022

Decision trees can be axis-based or oblique. Axis-based trees use one predictor to partition data into two new branches of the tree, while oblique trees use a linear combination of predictors. A linear combination of one predictor is the same thing as the original predictor as far as a decision tree split is concerned, so an oblique tree with one predictor is the same thing as an axis-based one.
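(As a toy illustration of that last point: a single-predictor "oblique" split w * x < cutpoint is just an axis-based split x < cutpoint / w, assuming w > 0.)

x = c(0.5, 1.5, 2.5, 3.0)
w = 0.7; cutpoint = 1.4
identical(w * x < cutpoint, x < cutpoint / w)
#> [1] TRUE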

I'm okay with putting the warning into documentation. I would guess that using mtry = 1 while fitting an orsf() model is not likely to be common outside of tuning routines, so a note in the docs is probably all we need. I have to travel today and this weekend but I should be able to get this done early next week.

sebffischer (Member) commented:

Thanks a lot! :)

bcjaeger (Contributor) commented:

Hello! Just wanted to let you know orsf() allows mtry = 1 now and does not present a warning:

library(aorsf)
orsf(pbc_orsf, time + status ~ . -id, mtry = 1)
#> ---------- Oblique random survival forest
#> 
#>      Linear combinations: Accelerated
#>           N observations: 276
#>                 N events: 111
#>                  N trees: 500
#>       N predictors total: 17
#>    N predictors per node: 1
#>  Average leaves per tree: 20
#> Min observations in leaf: 5
#>       Min events in leaf: 1
#>           OOB stat value: 0.84
#>            OOB stat type: Harrell's C-statistic
#>      Variable importance: anova
#> 
#> -----------------------------------------

Created on 2022-12-14 with reprex v2.0.2

This change will be in version 0.0.5 of aorsf, which I submitted to CRAN earlier today.

sebffischer added a commit that referenced this issue Dec 15, 2022
The learner failed when setting mtry to 1. This was addressed
in the new aorsf release.

#259
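With aorsf >= 0.0.5 installed, the original reprex from this issue is therefore expected to train without error (a sketch, not re-run here):

library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

testlrn = lrn("surv.aorsf", n_tree = 50, control_type = "fast", mtry_ratio = 0.1)
testlrn$train(tsk("whas"))   # mtry resolves to 1, which aorsf now accepts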