Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue51 #52

Merged
merged 3 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -2883,6 +2883,7 @@ ObliqueForest <- R6::R6Class(
oob_data <- data.table(
n_predictors = seq(n_predictors),
stat_value = rep(NA_real_, n_predictors),
variables_included = vector(mode = 'list', length = n_predictors),
predictors_included = vector(mode = 'list', length = n_predictors),
predictor_dropped = rep(NA_character_, n_predictors)
)
Expand Down Expand Up @@ -2915,6 +2916,34 @@ ObliqueForest <- R6::R6Class(
importance_group_factors = TRUE,
write_forest = FALSE)


fctr_key <- lapply(self$get_fctr_info()$keys, function(x) x[-1])

fctr_key <- data.frame(
variable = rep(names(fctr_key), sapply(fctr_key, length)),
predictor = unlist(fctr_key),
row.names = NULL
)

variable_key <- data.frame(
predictor = self$get_names_x(ref_coded = TRUE)
)

if(is_empty(fctr_key)){

variable_key$variable <- variable_key$predictor

} else {

variable_key <- merge(variable_key, fctr_key,
by = 'predictor',
all.x = TRUE)

}

variable_key$variable[is.na(variable_key$variable)] <-
variable_key$predictor[is.na(variable_key$variable)]

max_progress <- n_predictors - n_predictor_min
current_progress <- 0
start_time <- last_time <- Sys.time()
Expand Down Expand Up @@ -2963,9 +2992,17 @@ ObliqueForest <- R6::R6Class(
worst_index <- which.min(cpp_output$importance)
worst_predictor <- colnames(cpp_args$x)[worst_index]


.variables_included <- with(
variable_key,
unique(variable[predictor %in% colnames(cpp_args$x)])
)


oob_data[n_predictors,
`:=`(n_predictors = n_predictors,
stat_value = cpp_output$eval_oobag$stat_values[1,1],
variables_included = .variables_included,
predictors_included = colnames(cpp_args$x),
predictor_dropped = worst_predictor)]

Expand Down
8 changes: 8 additions & 0 deletions R/orsf_vs.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,19 @@
#' @return a [data.table][data.table::data.table-package] with four columns:
#' - *n_predictors*: the number of predictors used
#' - *stat_value*: the out-of-bag statistic
#' - *variables_included*: the names of the variables included
#' - *predictors_included*: the names of the predictors included
#' - *predictor_dropped*: the predictor selected to be dropped
#'
#' @details
#'
#' The difference between `variables_included` and `predictors_included` is
#' referent coding. The `variable` would be the name of a factor variable
#' in the training data, while the `predictor` would be the name of that
#' same factor with the levels of the factor appended. For example, if
#' the variable is `diabetes` with `levels = c("no", "yes")`, then the
#' variable name is `diabetes` and the predictor name is `diabetes_yes`.
#'
#' `tree_seeds` should be specified in `object` so that each successive run
#' of `orsf` will be evaluated in the same out-of-bag samples as the initial
#' run.
Expand Down
Loading