Skip to content

Commit

Permalink
test updates
Browse files Browse the repository at this point in the history
moving old helper functions from misc. using setup to create list of
items for tests. including lung as third survival dataset. using a
preproc function to make x/y/w mats for testing cpp code
  • Loading branch information
bcjaeger committed Oct 6, 2023
1 parent a2d4e48 commit 0823fbf
Show file tree
Hide file tree
Showing 17 changed files with 502 additions and 370 deletions.
8 changes: 8 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ is_col_splittable_exported <- function(x, y, r, j) {
.Call(`_aorsf_is_col_splittable_exported`, x, y, r, j)
}

find_cutpoints_survival_exported <- function(y, w, lincomb, leaf_min_events, leaf_min_obs) {
.Call(`_aorsf_find_cutpoints_survival_exported`, y, w, lincomb, leaf_min_events, leaf_min_obs)
}

sprout_node_survival_exported <- function(y, w) {
.Call(`_aorsf_sprout_node_survival_exported`, y, w)
}

cph_scale <- function(x, w) {
.Call(`_aorsf_cph_scale`, x, w)
}
Expand Down
49 changes: 0 additions & 49 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,55 +104,6 @@ paste_collapse <- function(x, sep=', ', last = ' or '){

}

#' Find cut-point boundaries (R version)
#'
#' Used to test the cpp version for finding cutpoints
#'
#' @param y_node outcome matrix
#' @param w_node weight vector
#' @param XB linear combination of predictors
#' @param xb_uni unique values in XB
#' @param leaf_min_events min no. of events in a leaf
#' @param leaf_min_obs min no. of observations in a leaf
#'
#' @noRd
#'
#' @return data.frame with description of valid cutpoints
cp_find_bounds_R <- function(y_node,
w_node,
XB,
xb_uni,
leaf_min_events,
leaf_min_obs){

status = y_node[, 'status']

cp_stats <-
sapply(
X = xb_uni,
FUN = function(x){
c(
cp = x,
e_right = sum(status[XB > x]),
e_left = sum(status[XB <= x]),
n_right = sum(XB > x),
n_left = sum(XB <= x)
)
}
)

cp_stats <- as.data.frame(t(cp_stats))

cp_stats$valid_cp = with(
cp_stats,
e_right >= leaf_min_events & e_left >= leaf_min_events &
n_right >= leaf_min_obs & n_left >= leaf_min_obs
)

cp_stats

}

has_units <- function(x){
inherits(x, 'units')
}
Expand Down
34 changes: 17 additions & 17 deletions man/orsf_vi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions man/predict.orsf_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,33 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// find_cutpoints_survival_exported
arma::uvec find_cutpoints_survival_exported(arma::mat& y, arma::vec& w, arma::vec& lincomb, double leaf_min_events, double leaf_min_obs);
RcppExport SEXP _aorsf_find_cutpoints_survival_exported(SEXP ySEXP, SEXP wSEXP, SEXP lincombSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP);
Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
Rcpp::traits::input_parameter< arma::vec& >::type lincomb(lincombSEXP);
Rcpp::traits::input_parameter< double >::type leaf_min_events(leaf_min_eventsSEXP);
Rcpp::traits::input_parameter< double >::type leaf_min_obs(leaf_min_obsSEXP);
rcpp_result_gen = Rcpp::wrap(find_cutpoints_survival_exported(y, w, lincomb, leaf_min_events, leaf_min_obs));
return rcpp_result_gen;
END_RCPP
}
// sprout_node_survival_exported
List sprout_node_survival_exported(arma::mat& y, arma::vec& w);
RcppExport SEXP _aorsf_sprout_node_survival_exported(SEXP ySEXP, SEXP wSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP);
Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
rcpp_result_gen = Rcpp::wrap(sprout_node_survival_exported(y, w));
return rcpp_result_gen;
END_RCPP
}
// cph_scale
List cph_scale(arma::mat& x, arma::vec& w);
RcppExport SEXP _aorsf_cph_scale(SEXP xSEXP, SEXP wSEXP) {
Expand Down Expand Up @@ -155,6 +182,8 @@ static const R_CallMethodDef CallEntries[] = {
{"_aorsf_compute_cstat_exported_uvec", (DL_FUNC) &_aorsf_compute_cstat_exported_uvec, 4},
{"_aorsf_compute_logrank_exported", (DL_FUNC) &_aorsf_compute_logrank_exported, 3},
{"_aorsf_is_col_splittable_exported", (DL_FUNC) &_aorsf_is_col_splittable_exported, 4},
{"_aorsf_find_cutpoints_survival_exported", (DL_FUNC) &_aorsf_find_cutpoints_survival_exported, 5},
{"_aorsf_sprout_node_survival_exported", (DL_FUNC) &_aorsf_sprout_node_survival_exported, 2},
{"_aorsf_cph_scale", (DL_FUNC) &_aorsf_cph_scale, 2},
{"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 44},
{NULL, NULL, 0}
Expand Down
8 changes: 1 addition & 7 deletions src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,20 +562,14 @@
Rcout << std::endl;
}


// do not split if best stat < minimum stat
if(stat_best < split_min_stat){

return(R_PosInf);

}
if(stat_best < split_min_stat){ return(R_PosInf); }

// backtrack g_node to be what it was when best it was found
if(it_best < it_start){
g_node.elem(lincomb_sort.subvec(it_best+1, it_start)).fill(1);
}


// return the cut-point from best split
return(lincomb[lincomb_sort[it_best]]);

Expand Down
32 changes: 32 additions & 0 deletions src/Tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,42 @@
this->y_inbag = y;
}

void set_w_inbag(arma::vec w){
this->w_inbag = w;
}

void set_x_node(arma::mat x){
this->x_node = x;
}

void set_y_node(arma::mat y){
this->y_node = y;
}

void set_w_node(arma::vec w){
this->w_node = w;
}

void set_rows_node(arma::uvec rows){
this->rows_node = rows;
}

void set_lincomb(arma::vec lc){
this->lincomb = lc;
}

void set_lincomb_sort(arma::uvec lc_sort){
this->lincomb_sort = lc_sort;
}

void set_leaf_min_obs(double value){
this->leaf_min_obs = value;
}

void set_verbosity(int value){
this->verbosity = value;
}


protected:

Expand Down
Loading

0 comments on commit 0823fbf

Please sign in to comment.