Skip to content

Commit

Permalink
tests for classification vi
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Nov 21, 2023
1 parent 33fce56 commit ad15c8a
Show file tree
Hide file tree
Showing 11 changed files with 292 additions and 39 deletions.
4 changes: 2 additions & 2 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ Comparisons between `aorsf` and existing software are presented in our [JCGS pap

- reports prediction accuracy and computational efficiency of all learners.

- runs a simulation study comparing variable importance techniques with ORSFs, axis based RSFs, and boosted trees.
- runs a simulation study comparing variable importance techniques with oblique survival RFs, axis-based survival RFs, and boosted trees.

- reports the probability that each variable importance technique will rank a relevant variable with higher importance than an irrelevant variable.

Expand All @@ -181,7 +181,7 @@ cat("3. ", aorsf:::roxy_cite_menze_2011())

## Funding

The developers of `aorsf` receive financial support from the Center for Biomedical Informatics, Wake Forest University School of Medicine. We also receive support from the National Center for Advancing Translational Sciences of the National Institutes of Health under Award Number UL1TR001420.
The developers of `aorsf` received financial support from the Center for Biomedical Informatics, Wake Forest University School of Medicine. We also received support from the National Center for Advancing Translational Sciences of the National Institutes of Health under Award Number UL1TR001420.

The content is solely the responsibility of the authors and does not necessarily represent the official views of the National Institutes of Health.

31 changes: 24 additions & 7 deletions src/ForestClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,14 @@ void ForestClassification::compute_prediction_accuracy_internal(
arma::uword row_fill
) {

double result = 0;
double result = 0, denom = 0;

if(oobag_eval_type == EVAL_R_FUNCTION){

// initialize function from tree object
// (Functions can't be stored in C++ classes, but Robjects can)
Rcpp::Function f_oobag_eval = Rcpp::as<Rcpp::Function>(oobag_R_function);


// go through all columns if multi-class y,
// but only go through one column if y is binary
// uword start = 0;
Expand Down Expand Up @@ -129,13 +128,31 @@ void ForestClassification::compute_prediction_accuracy_internal(

}

for(uword i = 0; i < predictions.n_cols; i++){
vec y_i = y.unsafe_col(i);
vec p_i = predictions.unsafe_col(i);
result += compute_cstat_clsf(y_i, w, p_i);
if(pred_type == PRED_PROBABILITY){

denom = predictions.n_cols;

for(uword i = 0; i < predictions.n_cols; i++){
vec y_i = y.unsafe_col(i);
vec p_i = predictions.unsafe_col(i);
result += compute_cstat_clsf(y_i, w, p_i);
}

} else if (pred_type == PRED_CLASS){

for(uword i = 0; i < y.n_rows; i++){

if(predictions.at(i, 0) == y.at(i, 0)){
result += w[i];
}

denom += w[i];

}

}

oobag_eval(row_fill, 0) = result / predictions.n_cols;
oobag_eval(row_fill, 0) = result / denom;

}

Expand Down
26 changes: 3 additions & 23 deletions src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,9 +672,7 @@
// using oobag = false for predict b/c data_oobag is already subsetted
predict_leaf(data_oobag.get(), false);

uword n_col_vi = get_n_col_vi();

mat pred_values(data_oobag->n_rows, n_col_vi);
mat pred_values(data_oobag->n_rows, get_n_col_vi());

fill_pred_values_vi(pred_values);

Expand All @@ -683,9 +681,9 @@

if(verbosity > 1){
// # nocov start
Rcout << " -- prediction accuracy before noising: ";
Rcout << " -- prediction accuracy before noising: ";
Rcout << accuracy_normal << std::endl;
Rcout << " -- mean leaf pred: ";
Rcout << " -- mean leaf pred: ";
Rcout << mean(conv_to<vec>::from(pred_leaf));
Rcout << std::endl << std::endl;
// # nocov end
Expand Down Expand Up @@ -1185,24 +1183,6 @@

// Score a matrix of predictions against the out-of-bag outcomes.
//
// When oobag_eval_type is EVAL_R_FUNCTION, the R function held in
// oobag_R_function is called with (y_oobag, w_oobag, first column of preds)
// and the first element of its result is returned. For every other
// evaluation type the work is delegated to
// compute_prediction_accuracy_internal().
double Tree::compute_prediction_accuracy(arma::mat& preds){

  // built-in evaluation: hand off to the tree-specific implementation
  if (oobag_eval_type != EVAL_R_FUNCTION){
    return(compute_prediction_accuracy_internal(preds));
  }

  // only the first column of predictions is passed to the R function
  vec first_col = preds.unsafe_col(0);

  // convert the Armadillo objects into their R counterparts
  NumericMatrix y_for_R = wrap(y_oobag);
  NumericVector w_for_R = wrap(w_oobag);
  NumericVector p_for_R = wrap(first_col);

  // the evaluator is kept as a generic R object on the tree
  // (Functions can't be stored in C++ classes, but RObjects can),
  // so it is converted to a callable Function here
  Function r_evaluator = as<Function>(oobag_R_function);

  NumericVector r_value = r_evaluator(y_for_R, w_for_R, p_for_R);

  return(r_value[0]);

}
Expand Down
1 change: 1 addition & 0 deletions src/Tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@
virtual arma::mat user_fit();

virtual uword get_n_col_vi()=0;
virtual PredType get_pred_type_vi()=0;

virtual void fill_pred_values_vi(arma::mat& pred_values)=0;

Expand Down
62 changes: 56 additions & 6 deletions src/TreeClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@
vec y_i = y_node.unsafe_col(y_col_split);
result = compute_cstat_clsf(y_i, w_node, g_node);

// if the split has good 'anti-prediction' properties:
// if the split has good 'anti-prediction':
if(result < 0.50){ result = 1 - result; }

break;

}

default:
Rcpp::stop("invalid split rule");
stop("invalid split rule");
break;

}
Expand Down Expand Up @@ -202,17 +202,51 @@
arma::mat& preds
){

double cstat_sum = 0;
double result = 0, denom = preds.n_cols;

uword start = 0; if(binary) start = 1;
uword start = 0;

if(binary){

start = 1;
denom = 1;

}

if (oobag_eval_type == EVAL_R_FUNCTION){

// initialize function from tree object
// (Functions can't be stored in C++ classes, but RObjects can)
Function f_oobag_eval = as<Function>(oobag_R_function);

NumericVector w_ = wrap(w_oobag);

for(uword i = start; i < preds.n_cols; ++i){

vec y_i = y_oobag.unsafe_col(i);
vec p_i = preds.unsafe_col(i);

NumericVector y_ = wrap(y_i);
NumericVector p_ = wrap(p_i);
NumericVector R_result = f_oobag_eval(y_, w_, p_);

double result_addon = R_result[0];

result += result_addon;

}

return(result / denom);

}

for(uword i = start; i < y_oobag.n_cols; i++){
vec y_i = y_oobag.unsafe_col(i);
vec p_i = preds.unsafe_col(i);
cstat_sum += compute_cstat_clsf(y_i, w_oobag, p_i);
result += compute_cstat_clsf(y_i, w_oobag, p_i);
}

return cstat_sum / preds.n_cols;
return result / denom;

}

Expand Down Expand Up @@ -365,7 +399,23 @@
}

// Column count for the prediction matrix allocated during the
// variable-importance step; classification trees report n_class.
uword TreeClassification::get_n_col_vi(){
  return n_class;
}

// Prediction type reported to the "_vi" (variable importance) helpers.
// PRED_CLASS is passed through as-is; every other value of pred_type
// is reported as PRED_PROBABILITY.
PredType TreeClassification::get_pred_type_vi(){
  return (pred_type == PRED_CLASS) ? PRED_CLASS : PRED_PROBABILITY;
}

void TreeClassification::fill_pred_values_vi(mat& pred_values){
Expand Down
1 change: 1 addition & 0 deletions src/TreeClassification.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
arma::mat user_fit() override;

uword get_n_col_vi() override;
PredType get_pred_type_vi() override;

void fill_pred_values_vi(arma::mat& pred_values) override;

Expand Down
26 changes: 26 additions & 0 deletions src/TreeRegression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,24 @@
arma::mat& preds
){

if (oobag_eval_type == EVAL_R_FUNCTION){

vec preds_vec = preds.unsafe_col(0);

NumericMatrix y_wrap = wrap(y_oobag);
NumericVector w_wrap = wrap(w_oobag);
NumericVector p_wrap = wrap(preds_vec);

// initialize function from tree object
// (Functions can't be stored in C++ classes, but RObjects can)
Function f_oobag = as<Function>(oobag_R_function);

NumericVector result_R = f_oobag(y_wrap, w_wrap, p_wrap);

return(result_R[0]);

}

double mse_sum = 0;

for(uword i = 0; i < y_oobag.n_cols; i++){
Expand Down Expand Up @@ -255,6 +273,14 @@
return(1);
}

// Prediction type reported to the "_vi" (variable importance) helpers.
// Regression trees always report PRED_MEAN.
PredType TreeRegression::get_pred_type_vi(){
  return PRED_MEAN;
}

void TreeRegression::fill_pred_values_vi(mat& pred_values){

for(uword i = 0; i < pred_values.n_rows; ++i){
Expand Down
1 change: 1 addition & 0 deletions src/TreeRegression.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
arma::mat user_fit() override;

uword get_n_col_vi() override;
PredType get_pred_type_vi() override;

bool is_node_splittable_internal() override;

Expand Down
26 changes: 26 additions & 0 deletions src/TreeSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,24 @@

double TreeSurvival::compute_prediction_accuracy_internal(arma::mat& preds){

if (oobag_eval_type == EVAL_R_FUNCTION){

vec preds_vec = preds.unsafe_col(0);

NumericMatrix y_wrap = wrap(y_oobag);
NumericVector w_wrap = wrap(w_oobag);
NumericVector p_wrap = wrap(preds_vec);

// initialize function from tree object
// (Functions can't be stored in C++ classes, but RObjects can)
Function f_oobag = as<Function>(oobag_R_function);

NumericVector result_R = f_oobag(y_wrap, w_wrap, p_wrap);

return(result_R[0]);

}

vec preds_vec = preds.unsafe_col(0);

return compute_cstat_surv(y_oobag, w_oobag, preds_vec, true);
Expand Down Expand Up @@ -735,6 +753,14 @@
return(1);
}

// Prediction type reported to the "_vi" (variable importance) helpers.
// Survival trees always report PRED_MORTALITY.
PredType TreeSurvival::get_pred_type_vi(){
  return PRED_MORTALITY;
}

void TreeSurvival::fill_pred_values_vi(mat& pred_values){

for(uword i = 0; i < pred_values.n_rows; ++i){
Expand Down
1 change: 1 addition & 0 deletions src/TreeSurvival.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
void sprout_leaf_internal(uword node_id) override;

uword get_n_col_vi() override;
PredType get_pred_type_vi() override;

void fill_pred_values_vi(arma::mat& pred_values) override;

Expand Down
Loading

0 comments on commit ad15c8a

Please sign in to comment.