diff --git a/src/Tree.cpp b/src/Tree.cpp index 36969c39..8e0119e3 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -226,6 +226,13 @@ this->cols_node.set_size(mtry); uint cols_accepted = 0; + uword mtry_safe = find_safe_mtry(); + + if(mtry_safe == 0){ + cols_node.resize(0); + return; + } + // Set all to not selected std::vector temp; temp.resize(n_cols_total, false); @@ -245,7 +252,7 @@ cols_accepted++; } - if(cols_accepted == mtry) break; + if(cols_accepted == mtry_safe) break; } @@ -573,6 +580,11 @@ } + uword Tree::find_safe_mtry(){ + // only relevant for survival trees at the moment + return(this->mtry); + } + void Tree::find_rows_inbag(arma::uword n_obs) { // it is assumed that: @@ -821,6 +833,8 @@ // determine rows in the current node and if it can be split if(!is_node_splittable(*node)){ + // this step creates y_node and w_node for the current node + // x_node is created once a set of columns are sampled. sprout_leaf(*node); continue; } diff --git a/src/Tree.h b/src/Tree.h index 3ffe7d95..021f673d 100644 --- a/src/Tree.h +++ b/src/Tree.h @@ -77,6 +77,8 @@ virtual void find_all_cuts(); + virtual uword find_safe_mtry(); + virtual double compute_split_score(); void sample_cuts(); diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp index 4c0ae8fc..6e217780 100644 --- a/src/TreeSurvival.cpp +++ b/src/TreeSurvival.cpp @@ -299,6 +299,26 @@ } + arma::uword TreeSurvival::find_safe_mtry(){ + + uword safer_mtry = mtry; + + if(lincomb_type == LC_NEWTON_RAPHSON){ + + // Need 3:1 ratio of unweighted events:predictors + uword n_events_total = sum(y_node.col(1)); + + while(n_events_total / safer_mtry < 3){ + --safer_mtry; + if(safer_mtry == 0) break; + } + + } + + return(safer_mtry); + + } + double TreeSurvival::compute_split_score(){ double result=0; diff --git a/src/TreeSurvival.h b/src/TreeSurvival.h index d3a92d41..8b035461 100644 --- a/src/TreeSurvival.h +++ b/src/TreeSurvival.h @@ -81,6 +81,8 @@ this->leaf_min_events = value; } + arma::uword find_safe_mtry() override; + double compute_prediction_accuracy_internal(arma::vec& preds) override; std::vector leaf_pred_indx; diff --git a/tests/testthat/test-orsf_vs.R b/tests/testthat/test-orsf_vs.R index 0c4e52dd..5f736fb3 100644 --- a/tests/testthat/test-orsf_vs.R +++ b/tests/testthat/test-orsf_vs.R @@ -16,11 +16,11 @@ test_that( fit <- orsf(pbc_with_junk, time + status ~ ., - n_tree = 75, - importance = 'permute', + n_tree = 25, + importance = 'anova', tree_seeds = seeds_standard) - fit_var_select <- orsf_vs(fit, n_predictor_min = 10) + fit_var_select <- orsf_vs(fit, n_predictor_min = 5) vars_picked <- fit_var_select$predictors_included[[1]]