Skip to content

Commit

Permalink
make variable selection routine compatible with new orsf_cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 8, 2023
1 parent eccf536 commit 68dbd4b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
17 changes: 3 additions & 14 deletions R/orsf_vs.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ orsf_vs <- function(object,
forest_weights <- NULL

forest_outcomes <- get_names_y(object)

formula <- stats::as.formula(
paste( paste(forest_outcomes, collapse = ' + '), "~ ." )
)
Expand All @@ -62,19 +63,7 @@ orsf_vs <- function(object,
predictor_dropped = rep(NA_character_, n_predictors)
)

oob_last_stat_value <- get_last_oob_stat_value(forest_object)
oob_worst_predictor <- get_last_vi(forest_object)

oob_data[n_predictors,
`:=`(n_predictors = n_predictors,
stat_value = oob_last_stat_value,
predictors_included = forest_predictors,
predictor_dropped = oob_worst_predictor)]

cols_kept <- c(
forest_outcomes,
setdiff(forest_predictors, oob_worst_predictor)
)
cols_kept <- c(forest_outcomes, forest_predictors)

while(n_predictors > n_predictor_min){

Expand All @@ -92,7 +81,7 @@ orsf_vs <- function(object,
split_min_events = get_split_min_events(object),
split_min_obs = get_split_min_obs(object),
split_min_stat = get_split_min_stat(object),
oobag_pred_type = get_oobag_pred_type(object),
oobag_pred_type = 'mort',
oobag_pred_horizon = get_oobag_pred_horizon(object),
oobag_eval_every = get_n_tree(object),
oobag_fun = get_oobag_fun(object),
Expand Down
31 changes: 31 additions & 0 deletions tests/testthat/test-orsf_vs.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@


# Checks that orsf_vs() does not retain pure-noise predictors: 50 standard
# normal columns are appended to the data, a forest is fit on everything,
# and none of the noise columns may appear in the predictor set read from
# the first row of the variable selection output.
test_that(
  desc = "variable selection filters junk preds",
  code = {

    # Fix the RNG state so the simulated noise columns are identical on
    # every run; without this a failure could not be reproduced.
    set.seed(329)

    pbc_with_junk <- pbc

    n_junk_preds <- 50

    # seq_len() stays well-defined (empty) if n_junk_preds is ever set to 0,
    # whereas seq(0) would yield c(1, 0).
    junk_names <- paste("junk", seq_len(n_junk_preds), sep = '_')

    for (i in junk_names) {
      pbc_with_junk[[i]] <- rnorm(nrow(pbc))
    }

    fit <- orsf(pbc_with_junk, time + status ~ ., n_tree = n_tree_test)

    fit_var_select <- orsf_vs(fit, n_predictor_min = 10)

    # NOTE(review): row 1 is presumably the smallest predictor set that
    # orsf_vs() evaluated -- confirm against the ordering of its output.
    vars_picked <- fit_var_select$predictors_included[[1]]

    expect_false(
      any(junk_names %in% vars_picked)
    )

  }
)




0 comments on commit 68dbd4b

Please sign in to comment.