Skip to content

Commit

Permalink
use safer mtry for high dimension preds
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 13, 2023
1 parent 25b1e5f commit 62f71c6
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 4 deletions.
16 changes: 15 additions & 1 deletion src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,13 @@
this->cols_node.set_size(mtry);
uint cols_accepted = 0;

uword mtry_safe = find_safe_mtry();

if(mtry_safe == 0){
cols_node.resize(0);
return;
}

// Set all to not selected
std::vector<bool> temp;
temp.resize(n_cols_total, false);
Expand All @@ -245,7 +252,7 @@
cols_accepted++;
}

if(cols_accepted == mtry) break;
if(cols_accepted == mtry_safe) break;

}

Expand Down Expand Up @@ -573,6 +580,11 @@

}

uword Tree::find_safe_mtry(){
 // The base tree places no additional limit on how many predictors
 // are sampled: the configured mtry is handed back unchanged. The
 // method is virtual so that derived trees (currently only survival
 // trees) can return a smaller value.
 const uword safe_mtry = mtry;
 return safe_mtry;
}

void Tree::find_rows_inbag(arma::uword n_obs) {

// it is assumed that:
Expand Down Expand Up @@ -821,6 +833,8 @@

// determine rows in the current node and if it can be split
if(!is_node_splittable(*node)){
// this step creates y_node and w_node for the current node
// x_node is created once a set of columns are sampled.
sprout_leaf(*node);
continue;
}
Expand Down
2 changes: 2 additions & 0 deletions src/Tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@

virtual void find_all_cuts();

// Number of predictors that can safely be sampled for the current node.
// The base Tree implementation returns mtry unchanged; TreeSurvival
// overrides it to return a possibly smaller value.
virtual uword find_safe_mtry();

virtual double compute_split_score();

void sample_cuts();
Expand Down
20 changes: 20 additions & 0 deletions src/TreeSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,26 @@

}

// Returns the number of predictors that can safely be used for the
// current node. For every linear combination type other than
// Newton-Raphson this is simply mtry. For Newton-Raphson the value is
// capped so that there are at least 3 unweighted events per predictor;
// the result may be 0 when the node has fewer than 3 events (or when
// mtry itself is 0).
arma::uword TreeSurvival::find_safe_mtry(){

 uword safer_mtry = mtry;

 if(lincomb_type == LC_NEWTON_RAPHSON){

  // Need 3:1 ratio of unweighted events:predictors
  uword n_events_total = sum(y_node.col(1));

  // The largest value m <= mtry satisfying n_events_total / m >= 3
  // (integer division) is min(mtry, n_events_total / 3). Computing it
  // directly avoids the decrement loop and, unlike dividing by
  // safer_mtry, never divides by zero when mtry is 0.
  uword max_supported = n_events_total / 3;

  if(max_supported < safer_mtry){
   safer_mtry = max_supported;
  }

 }

 return(safer_mtry);

}

double TreeSurvival::compute_split_score(){

double result=0;
Expand Down
2 changes: 2 additions & 0 deletions src/TreeSurvival.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
this->leaf_min_events = value;
}

// Survival-specific cap on mtry for the current node: when the linear
// combination type is Newton-Raphson, the returned value is lowered so
// that there are at least 3 events (sum of column 1 of y_node) per
// predictor; otherwise mtry is returned as is. May return 0.
arma::uword find_safe_mtry() override;

double compute_prediction_accuracy_internal(arma::vec& preds) override;

std::vector<arma::vec> leaf_pred_indx;
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-orsf_vs.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ test_that(

fit <- orsf(pbc_with_junk,
time + status ~ .,
n_tree = 75,
importance = 'permute',
n_tree = 25,
importance = 'anova',
tree_seeds = seeds_standard)

fit_var_select <- orsf_vs(fit, n_predictor_min = 10)
fit_var_select <- orsf_vs(fit, n_predictor_min = 5)

vars_picked <- fit_var_select$predictors_included[[1]]

Expand Down

0 comments on commit 62f71c6

Please sign in to comment.