Skip to content

Commit

Permalink
Merge pull request #52 from ropensci/issue51
Browse files Browse the repository at this point in the history
Issue51
  • Loading branch information
bcjaeger authored Apr 18, 2024
2 parents 8d86f63 + f9f45fc commit 684e544
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
37 changes: 37 additions & 0 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -2883,6 +2883,7 @@ ObliqueForest <- R6::R6Class(
# Pre-allocate the results table with one row per possible number of
# predictors. Scalar columns start as NA; the two *_included columns are
# list columns so that a row can hold a set of names.
oob_data <- data.table(
  n_predictors = seq(n_predictors),
  stat_value = rep_len(NA_real_, n_predictors),
  variables_included = vector("list", n_predictors),
  predictors_included = vector("list", n_predictors),
  predictor_dropped = rep_len(NA_character_, n_predictors)
)
Expand Down Expand Up @@ -2915,6 +2916,34 @@ ObliqueForest <- R6::R6Class(
importance_group_factors = TRUE,
write_forest = FALSE)


# Build `variable_key`, a lookup table with one row per referent-coded
# predictor name and the name of the variable that predictor belongs to.

# For every factor, keep all keys except the first one.
# NOTE(review): presumably the first key corresponds to the reference
# level, which has no column of its own -- confirm against get_fctr_info().
fctr_key <- lapply(self$get_fctr_info()$keys, function(x) x[-1])

# Long format: one row per (variable, predictor) pair.
# - lengths() always returns an integer vector, whereas sapply(x, length)
#   returns list() when there are no factors.
# - stringsAsFactors = FALSE keeps the columns character on R < 4.0, so
#   the NA fill-in at the end cannot hit an invalid factor level.
fctr_key <- data.frame(
  variable = rep(names(fctr_key), lengths(fctr_key)),
  predictor = unlist(fctr_key),
  row.names = NULL,
  stringsAsFactors = FALSE
)

# Start from every referent-coded predictor name.
variable_key <- data.frame(
  predictor = self$get_names_x(ref_coded = TRUE),
  stringsAsFactors = FALSE
)

if (is_empty(fctr_key)) {

  # No factor-derived predictors: each predictor is its own variable.
  variable_key$variable <- variable_key$predictor

} else {

  # Left join so that predictors not derived from a factor are kept;
  # they get variable = NA here and are filled in below.
  variable_key <- merge(variable_key, fctr_key,
                        by = 'predictor',
                        all.x = TRUE)

}

# Predictors without a factor match are their own variable.
variable_key$variable[is.na(variable_key$variable)] <-
  variable_key$predictor[is.na(variable_key$variable)]

max_progress <- n_predictors - n_predictor_min
current_progress <- 0
start_time <- last_time <- Sys.time()
Expand Down Expand Up @@ -2963,9 +2992,17 @@ ObliqueForest <- R6::R6Class(
# The predictor with the smallest importance is the one recorded as dropped.
worst_index <- which.min(cpp_output$importance)
worst_predictor <- colnames(cpp_args$x)[worst_index]

# Predictors (columns of x) currently in the model, and the unique set of
# variables those predictors map back to according to variable_key.
# Dot-prefixed names avoid clashing with oob_data's column names inside
# the data.table call below.
.predictors_included <- colnames(cpp_args$x)

.variables_included <- unique(
  variable_key$variable[variable_key$predictor %in% .predictors_included]
)

# Record the results for the current number of predictors.
oob_data[
  n_predictors,
  `:=`(
    n_predictors = n_predictors,
    stat_value = cpp_output$eval_oobag$stat_values[1,1],
    variables_included = .variables_included,
    predictors_included = .predictors_included,
    predictor_dropped = worst_predictor
  )
]

Expand Down
8 changes: 8 additions & 0 deletions R/orsf_vs.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,19 @@
#' @return a [data.table][data.table::data.table-package] with five columns:
#' - *n_predictors*: the number of predictors used
#' - *stat_value*: the out-of-bag statistic
#' - *variables_included*: the names of the variables included
#' - *predictors_included*: the names of the predictors included
#' - *predictor_dropped*: the predictor selected to be dropped
#'
#' @details
#'
#' The difference between `variables_included` and `predictors_included` is
#' referent coding. A `variable` is the name of a column in the training
#' data, such as a factor, while a `predictor` is the name of that same
#' factor with a level of the factor appended. For example, if the
#' variable is `diabetes` with `levels = c("no", "yes")`, then the
#' variable name is `diabetes` and the predictor name is `diabetes_yes`.
#'
#' `tree_seeds` should be specified in `object` so that each successive run
#' of `orsf` will be evaluated in the same out-of-bag samples as the initial
#' run.
Expand Down

0 comments on commit 684e544

Please sign in to comment.