From 0823fbfd1a135491581969364266d89bf80546f9 Mon Sep 17 00:00:00 2001 From: bjaeger Date: Fri, 6 Oct 2023 07:54:49 -0400 Subject: [PATCH] test updates moving old helper functions from misc. using setup to create list of items for tests. including lung as third survival dataset. using a preproc function to make x/y/w mats for testing cpp code --- R/RcppExports.R | 8 ++ R/misc.R | 49 ---------- man/orsf_vi.Rd | 34 +++---- man/predict.orsf_fit.Rd | 10 +- src/RcppExports.cpp | 29 ++++++ src/Tree.cpp | 8 +- src/Tree.h | 32 +++++++ src/TreeSurvival.cpp | 138 ++++++++------------------- src/TreeSurvival.h | 10 +- src/orsf_oop.cpp | 48 ++++++++++ tests/testthat/helper-orsf.R | 134 ++++++++++++++++++++++++++ tests/testthat/setup.R | 29 +++++- tests/testthat/test-cp_find_bounds.R | 138 --------------------------- tests/testthat/test-find_cutpoints.R | 108 +++++++++++++++++++++ tests/testthat/test-leaf_kaplan.R | 42 -------- tests/testthat/test-orsf_vi.R | 10 +- tests/testthat/test-sprout_node.R | 45 +++++++++ 17 files changed, 502 insertions(+), 370 deletions(-) create mode 100644 tests/testthat/test-find_cutpoints.R delete mode 100644 tests/testthat/test-leaf_kaplan.R create mode 100644 tests/testthat/test-sprout_node.R diff --git a/R/RcppExports.R b/R/RcppExports.R index f9b0c296..14b8bcc9 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -21,6 +21,14 @@ is_col_splittable_exported <- function(x, y, r, j) { .Call(`_aorsf_is_col_splittable_exported`, x, y, r, j) } +find_cutpoints_survival_exported <- function(y, w, lincomb, leaf_min_events, leaf_min_obs) { + .Call(`_aorsf_find_cutpoints_survival_exported`, y, w, lincomb, leaf_min_events, leaf_min_obs) +} + +sprout_node_survival_exported <- function(y, w) { + .Call(`_aorsf_sprout_node_survival_exported`, y, w) +} + cph_scale <- function(x, w) { .Call(`_aorsf_cph_scale`, x, w) } diff --git a/R/misc.R b/R/misc.R index d735910b..149a0e34 100644 --- a/R/misc.R +++ b/R/misc.R @@ -104,55 +104,6 @@ paste_collapse <- function(x, 
sep=', ', last = ' or '){ } -#' Find cut-point boundaries (R version) -#' -#' Used to test the cpp version for finding cutpoints -#' -#' @param y_node outcome matrix -#' @param w_node weight vector -#' @param XB linear combination of predictors -#' @param xb_uni unique values in XB -#' @param leaf_min_events min no. of events in a leaf -#' @param leaf_min_obs min no. of observations in a leaf -#' -#' @noRd -#' -#' @return data.frame with description of valid cutpoints -cp_find_bounds_R <- function(y_node, - w_node, - XB, - xb_uni, - leaf_min_events, - leaf_min_obs){ - - status = y_node[, 'status'] - - cp_stats <- - sapply( - X = xb_uni, - FUN = function(x){ - c( - cp = x, - e_right = sum(status[XB > x]), - e_left = sum(status[XB <= x]), - n_right = sum(XB > x), - n_left = sum(XB <= x) - ) - } - ) - - cp_stats <- as.data.frame(t(cp_stats)) - - cp_stats$valid_cp = with( - cp_stats, - e_right >= leaf_min_events & e_left >= leaf_min_events & - n_right >= leaf_min_obs & n_left >= leaf_min_obs - ) - - cp_stats - -} - has_units <- function(x){ inherits(x, 'units') } diff --git a/man/orsf_vi.Rd b/man/orsf_vi.Rd index f130cd94..3e82c020 100644 --- a/man/orsf_vi.Rd +++ b/man/orsf_vi.Rd @@ -221,26 +221,26 @@ orsf_vi_negate(fit_no_vi) }\if{html}{\out{}} \if{html}{\out{
}}\preformatted{## bili copper sex protime stage -## 0.1139657923 0.0498712200 0.0355366377 0.0283554322 0.0263792287 +## 0.1140229590 0.0498125548 0.0355460350 0.0281214750 0.0261880947 ## albumin age ascites chol ast -## 0.0231636378 0.0195791833 0.0175120075 0.0148252414 0.0104918262 +## 0.0231317226 0.0195454979 0.0174983334 0.0147644716 0.0106376518 ## edema spiders hepato trt trig -## 0.0084871358 0.0070608860 0.0067054788 0.0052040792 0.0030363455 +## 0.0084608002 0.0069190613 0.0065709504 0.0052988663 0.0030146941 ## alk.phos platelet -## 0.0029918139 -0.0003309069 +## 0.0029465729 -0.0003607455 }\if{html}{\out{
}} \if{html}{\out{
}}\preformatted{orsf_vi_permute(fit_no_vi) }\if{html}{\out{
}} \if{html}{\out{
}}\preformatted{## bili copper protime albumin stage -## 0.0511641625 0.0244676999 0.0160869571 0.0133334120 0.0130092352 -## ascites age hepato edema chol -## 0.0127421184 0.0113532728 0.0050851329 0.0050477457 0.0049382275 -## ast spiders sex alk.phos trig -## 0.0047345189 0.0038719163 0.0025231267 0.0018408350 0.0011528848 +## 0.0512565237 0.0244154211 0.0159740365 0.0133163811 0.0130200869 +## ascites age edema hepato ast +## 0.0127617348 0.0113466862 0.0050618672 0.0050513510 0.0048291999 +## chol spiders sex alk.phos trig +## 0.0048061978 0.0037750082 0.0024935019 0.0018249747 0.0011756814 ## platelet trt -## -0.0002875319 -0.0024330707 +## -0.0003313926 -0.0024983274 }\if{html}{\out{
}} } @@ -257,13 +257,13 @@ orsf_vi_permute(fit_permute_vi) }\if{html}{\out{}} \if{html}{\out{
}}\preformatted{## bili copper age albumin protime -## 0.0502725526 0.0201473283 0.0135888938 0.0127241082 0.0126629150 +## 0.0503462348 0.0202245782 0.0136242456 0.0127255613 0.0127084993 ## stage ascites ast edema chol -## 0.0124866976 0.0123508555 0.0060741690 0.0059166139 0.0053767371 +## 0.0124821760 0.0123368881 0.0061390036 0.0059648179 0.0053888106 ## spiders sex hepato trig alk.phos -## 0.0042600602 0.0028177750 0.0023470782 0.0021331719 0.0016874102 +## 0.0042214032 0.0027713550 0.0022973825 0.0019936938 0.0017116920 ## platelet trt -## 0.0002117061 -0.0005790547 +## 0.0002320012 -0.0006040636 }\if{html}{\out{
}} You can still get negation VI from this fit, but it needs to be computed @@ -272,11 +272,11 @@ You can still get negation VI from this fit, but it needs to be computed }\if{html}{\out{}} \if{html}{\out{
}}\preformatted{## bili copper sex stage age protime -## 0.1106715167 0.0456031656 0.0306666098 0.0304383573 0.0252136203 0.0224838590 +## 0.1106480407 0.0455070935 0.0307476652 0.0304715880 0.0252079956 0.0224192033 ## albumin ascites chol ast edema trt -## 0.0212630703 0.0168893963 0.0134174671 0.0132075752 0.0099681058 0.0088378768 +## 0.0214131931 0.0168756851 0.0133727008 0.0132732736 0.0100002899 0.0087589773 ## spiders hepato trig alk.phos platelet -## 0.0078776082 0.0062877323 0.0043076141 0.0030432581 0.0005571111 +## 0.0079457577 0.0062060793 0.0041321875 0.0030525590 0.0005921052 }\if{html}{\out{
}} } } diff --git a/man/predict.orsf_fit.Rd b/man/predict.orsf_fit.Rd index bafe78ca..f766dd48 100644 --- a/man/predict.orsf_fit.Rd +++ b/man/predict.orsf_fit.Rd @@ -169,11 +169,11 @@ prediction horizon }\if{html}{\out{}} \if{html}{\out{
}}\preformatted{## [,1] -## [1,] 83.08611 -## [2,] 27.48146 -## [3,] 43.52432 -## [4,] 15.20281 -## [5,] 10.56334 +## [1,] 81.42491 +## [2,] 26.90304 +## [3,] 42.62956 +## [4,] 14.88943 +## [5,] 10.35110 }\if{html}{\out{
}} } diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 7f74671b..5f786d22 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -82,6 +82,33 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// find_cutpoints_survival_exported +arma::uvec find_cutpoints_survival_exported(arma::mat& y, arma::vec& w, arma::vec& lincomb, double leaf_min_events, double leaf_min_obs); +RcppExport SEXP _aorsf_find_cutpoints_survival_exported(SEXP ySEXP, SEXP wSEXP, SEXP lincombSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP); + Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP); + Rcpp::traits::input_parameter< arma::vec& >::type lincomb(lincombSEXP); + Rcpp::traits::input_parameter< double >::type leaf_min_events(leaf_min_eventsSEXP); + Rcpp::traits::input_parameter< double >::type leaf_min_obs(leaf_min_obsSEXP); + rcpp_result_gen = Rcpp::wrap(find_cutpoints_survival_exported(y, w, lincomb, leaf_min_events, leaf_min_obs)); + return rcpp_result_gen; +END_RCPP +} +// sprout_node_survival_exported +List sprout_node_survival_exported(arma::mat& y, arma::vec& w); +RcppExport SEXP _aorsf_sprout_node_survival_exported(SEXP ySEXP, SEXP wSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP); + Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP); + rcpp_result_gen = Rcpp::wrap(sprout_node_survival_exported(y, w)); + return rcpp_result_gen; +END_RCPP +} // cph_scale List cph_scale(arma::mat& x, arma::vec& w); RcppExport SEXP _aorsf_cph_scale(SEXP xSEXP, SEXP wSEXP) { @@ -155,6 +182,8 @@ static const R_CallMethodDef CallEntries[] = { {"_aorsf_compute_cstat_exported_uvec", (DL_FUNC) &_aorsf_compute_cstat_exported_uvec, 4}, {"_aorsf_compute_logrank_exported", (DL_FUNC) &_aorsf_compute_logrank_exported, 3}, 
{"_aorsf_is_col_splittable_exported", (DL_FUNC) &_aorsf_is_col_splittable_exported, 4}, + {"_aorsf_find_cutpoints_survival_exported", (DL_FUNC) &_aorsf_find_cutpoints_survival_exported, 5}, + {"_aorsf_sprout_node_survival_exported", (DL_FUNC) &_aorsf_sprout_node_survival_exported, 2}, {"_aorsf_cph_scale", (DL_FUNC) &_aorsf_cph_scale, 2}, {"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 44}, {NULL, NULL, 0} diff --git a/src/Tree.cpp b/src/Tree.cpp index f7927e03..a637f867 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -562,20 +562,14 @@ Rcout << std::endl; } - // do not split if best stat < minimum stat - if(stat_best < split_min_stat){ - - return(R_PosInf); - - } + if(stat_best < split_min_stat){ return(R_PosInf); } // backtrack g_node to be what it was when best it was found if(it_best < it_start){ g_node.elem(lincomb_sort.subvec(it_best+1, it_start)).fill(1); } - // return the cut-point from best split return(lincomb[lincomb_sort[it_best]]); diff --git a/src/Tree.h b/src/Tree.h index 79011abe..4602f53b 100644 --- a/src/Tree.h +++ b/src/Tree.h @@ -139,10 +139,42 @@ this->y_inbag = y; } + void set_w_inbag(arma::vec w){ + this->w_inbag = w; + } + + void set_x_node(arma::mat x){ + this->x_node = x; + } + + void set_y_node(arma::mat y){ + this->y_node = y; + } + + void set_w_node(arma::vec w){ + this->w_node = w; + } + void set_rows_node(arma::uvec rows){ this->rows_node = rows; } + void set_lincomb(arma::vec lc){ + this->lincomb = lc; + } + + void set_lincomb_sort(arma::uvec lc_sort){ + this->lincomb_sort = lc_sort; + } + + void set_leaf_min_obs(double value){ + this->leaf_min_obs = value; + } + + void set_verbosity(int value){ + this->verbosity = value; + } + protected: diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp index b38ebd02..8c3075a0 100644 --- a/src/TreeSurvival.cpp +++ b/src/TreeSurvival.cpp @@ -157,24 +157,12 @@ double n_events = 0, n_risk = 0; - if(VERBOSITY > 1){ - Rcout << "----- finding lower bound for cut-points -----" << std::endl; - } - // 
stop at end-1 b/c we access it+1 in lincomb_sort for(it = lincomb_sort.begin(); it < lincomb_sort.end()-1; ++it){ n_events += y_status[*it] * w_node[*it]; n_risk += w_node[*it]; - - if(VERBOSITY > 2){ - Rcout << "current value: "<< lincomb(*it) << " -- "; - Rcout << "next: "<< lincomb(*(it+1)) << " -- "; - Rcout << "N events: " << n_events << " -- "; - Rcout << "N risk: " << n_risk << std::endl; - } - // If we want to make the current value of lincomb a cut-point, we need // to make sure the next value of lincomb isn't equal to this current value. // Otherwise, we will have the same value of lincomb in both groups! @@ -184,11 +172,11 @@ if( n_events >= leaf_min_events && n_risk >= leaf_min_obs ) { - if(VERBOSITY > 0){ + if(verbosity > 2){ Rcout << std::endl; - Rcout << "lower cutpoint: " << lincomb(*it) << std::endl; - Rcout << " - n_events, left node: " << n_events << std::endl; - Rcout << " - n_risk, left node: " << n_risk << std::endl; + Rcout << " -- lower cutpoint: " << lincomb(*it) << std::endl; + Rcout << " - n_events, left node: " << n_events << std::endl; + Rcout << " - n_risk, left node: " << n_risk << std::endl; Rcout << std::endl; } @@ -204,8 +192,8 @@ if(it == lincomb_sort.end()-1) { - if(VERBOSITY > 1){ - Rcout << "Could not find a valid cut-point" << std::endl; + if(verbosity > 2){ + Rcout << " -- Could not find a valid cut-point" << std::endl; } return(output); @@ -218,23 +206,12 @@ // reset before finding the upper limit n_events=0, n_risk=0; - if(VERBOSITY > 1){ - Rcout << "----- finding upper bound for cut-points -----" << std::endl; - } - // stop at beginning+1 b/c we access it-1 in lincomb_sort for(it = lincomb_sort.end()-1; it >= lincomb_sort.begin()+1; --it){ n_events += y_status[*it] * w_node[*it]; n_risk += w_node[*it]; - if(VERBOSITY > 2){ - Rcout << "current value: "<< lincomb(*it) << " ---- "; - Rcout << "next value: "<< lincomb(*(it-1)) << " ---- "; - Rcout << "N events: " << n_events << " ---- "; - Rcout << "N risk: " << n_risk << 
std::endl; - } - if(lincomb[*it] != lincomb[*(it-1)]){ if( n_events >= leaf_min_events && @@ -250,11 +227,11 @@ --it; - if(VERBOSITY > 0){ + if(verbosity > 2){ Rcout << std::endl; - Rcout << "upper cutpoint: " << lincomb(*it) << std::endl; - Rcout << " - n_events, right node: " << n_events << std::endl; - Rcout << " - n_risk, right node: " << n_risk << std::endl; + Rcout << " -- upper cutpoint: " << lincomb(*it) << std::endl; + Rcout << " - n_events, right node: " << n_events << std::endl; + Rcout << " - n_risk, right node: " << n_risk << std::endl; Rcout << std::endl; } @@ -273,7 +250,7 @@ if(j > k){ - if(VERBOSITY > 0) { + if(verbosity > 2) { Rcout << "Could not find valid cut-points" << std::endl; } @@ -333,68 +310,6 @@ } - double TreeSurvival::score_logrank(){ - - double - n_risk=0, - g_risk=0, - observed=0, - expected=0, - V=0, - temp1, - temp2, - n_events; - - vec y_time = y_node.unsafe_col(0); - vec y_status = y_node.unsafe_col(1); - - bool break_loop = false; - - uword i = y_node.n_rows-1; - - // breaking condition of outer loop governed by inner loop - for (; ;){ - - temp1 = y_time[i]; - - n_events = 0; - - for ( ; y_time[i] == temp1; i--) { - - n_risk += w_node[i]; - n_events += y_status[i] * w_node[i]; - g_risk += g_node[i] * w_node[i]; - observed += y_status[i] * g_node[i] * w_node[i]; - - if(i == 0){ - break_loop = true; - break; - } - - } - - // should only do these calculations if n_events > 0, - // but in practice its often faster to multiply by 0 - // versus check if n_events is > 0. 
- - temp2 = g_risk / n_risk; - expected += n_events * temp2; - - // update variance if n_risk > 1 (if n_risk == 1, variance is 0) - // definitely check if n_risk is > 1 b/c otherwise divide by 0 - if (n_risk > 1){ - temp1 = n_events * temp2 * (n_risk-n_events) / (n_risk-1); - V += temp1 * (1 - temp2); - } - - if(break_loop) break; - - } - - return(pow(expected-observed, 2) / V); - - } - void TreeSurvival::sprout_leaf(uword node_id){ if(verbosity > 2){ @@ -404,6 +319,7 @@ Rcout << std::endl; } + // reserve as much size as could be needed (probably more) mat leaf_data(y_node.n_rows, 3); @@ -414,6 +330,7 @@ person++; } + // person corresponds to first event or last censor time leaf_data.at(0, 0) = y_node.at(person, 0); @@ -421,8 +338,8 @@ // (TODO: should this case even occur? consider removing) if(person == y_node.n_rows){ - vec temp_surv(1, arma::fill::ones); - vec temp_chf(1, arma::fill::zeros); + vec temp_surv(1, fill::ones); + vec temp_chf(1, fill::zeros); leaf_pred_indx[node_id] = leaf_data.col(0); leaf_pred_prob[node_id] = temp_surv; @@ -514,8 +431,11 @@ for( ; i < (*unique_event_times).size(); i++){ - if((*unique_event_times)[i] >= leaf_data.at(j, 0) && - j < (leaf_data.n_rows-1)) {j++;} + + while((*unique_event_times)[i] > leaf_data.at(j, 0) && + j < (leaf_data.n_rows-1)) { + j++; + } result += leaf_data.at(j, 2); @@ -525,6 +445,24 @@ } + // double TreeSurvival::compute_mortality(arma::mat& leaf_data){ + // + // double result = 0; + // uword i=0, j=0; + // + // for( ; i < (*unique_event_times).size(); i++){ + // + // if((*unique_event_times)[i] >= leaf_data.at(j, 0) && + // j < (leaf_data.n_rows-1)) {j++;} + // + // result += leaf_data.at(j, 2); + // + // } + // + // return(result); + // + // } + void TreeSurvival::predict_value(arma::mat* pred_output, arma::vec* pred_denom, PredType pred_type, diff --git a/src/TreeSurvival.h b/src/TreeSurvival.h index 0306616c..941aee24 100644 --- a/src/TreeSurvival.h +++ b/src/TreeSurvival.h @@ -51,8 +51,6 @@ double 
compute_split_score() override; - double score_logrank(); - double compute_mortality(arma::mat& leaf_data); void sprout_leaf(uword node_id) override; @@ -74,6 +72,14 @@ return(leaf_pred_chaz); } + void set_unique_event_times(arma::vec event_times){ + this->unique_event_times = &event_times; + } + + void set_leaf_min_events(double value){ + this->leaf_min_events = value; + } + double compute_prediction_accuracy_internal(arma::vec& preds) override; std::vector leaf_pred_indx; diff --git a/src/orsf_oop.cpp b/src/orsf_oop.cpp index c83a73d9..5d8672e0 100644 --- a/src/orsf_oop.cpp +++ b/src/orsf_oop.cpp @@ -93,6 +93,54 @@ } + // [[Rcpp::export]] + arma::uvec find_cutpoints_survival_exported(arma::mat& y, + arma::vec& w, + arma::vec& lincomb, + double leaf_min_events, + double leaf_min_obs){ + + TreeSurvival tree; + + arma::uvec lincomb_sort = sort_index(lincomb); + + tree.set_y_node(y); + tree.set_w_node(w); + tree.set_lincomb(lincomb); + tree.set_lincomb_sort(lincomb_sort); + tree.set_leaf_min_obs(leaf_min_obs); + tree.set_leaf_min_events(leaf_min_events); + return(tree.find_cutpoints()); + + } + + // [[Rcpp::export]] + List sprout_node_survival_exported(arma::mat& y, + arma::vec& w){ + + TreeSurvival tree; + + arma::uword node_id = 0; + arma::uword leaf_size = 1; + + tree.unique_event_times = new vec(find_unique_event_times(y)); + tree.set_y_node(y); + tree.set_w_node(w); + + tree.resize_leaves(leaf_size); + tree.sprout_leaf(node_id); + + List result; + + result.push_back(tree.get_leaf_pred_indx(), "indx"); + result.push_back(tree.get_leaf_pred_prob(), "prob"); + result.push_back(tree.get_leaf_pred_chaz(), "chaz"); + result.push_back(tree.get_leaf_summary(), "mort"); + + return(result); + + } + // [[Rcpp::export]] List cph_scale(arma::mat& x, arma::vec& w){ diff --git a/tests/testthat/helper-orsf.R b/tests/testthat/helper-orsf.R index ae247233..fea821a2 100644 --- a/tests/testthat/helper-orsf.R +++ b/tests/testthat/helper-orsf.R @@ -27,6 +27,57 @@ change_scale <- 
function(x, mult_by = 1/2){ x * mult_by } +#' Find cut-point boundaries (R version) +#' +#' Used to test the cpp version for finding cutpoints +#' +#' @param y_node outcome matrix +#' @param w_node weight vector +#' @param XB linear combination of predictors +#' @param xb_uni unique values in XB +#' @param leaf_min_events min no. of events in a leaf +#' @param leaf_min_obs min no. of observations in a leaf +#' +#' @noRd +#' +#' @return data.frame with description of valid cutpoints +cp_find_bounds_R <- function(y_node, + w_node, + XB, + xb_uni, + leaf_min_events, + leaf_min_obs){ + + status = y_node[, 'status'] + + cp_stats <- + sapply( + X = xb_uni, + FUN = function(x){ + c( + cp = x, + e_right = sum(status[XB > x] * w_node[XB > x]), + e_left = sum(status[XB <= x] * w_node[XB <= x]), + n_right = sum(as.numeric(XB > x) * w_node), + n_left = sum(as.numeric(XB <= x) * w_node) + ) + } + ) + + cp_stats <- as.data.frame(t(cp_stats)) + + cp_stats$valid_cp = with( + cp_stats, + e_right >= leaf_min_events & + e_left >= leaf_min_events & + n_right >= leaf_min_obs & + n_left >= leaf_min_obs + ) + + cp_stats + +} + # oobag functions ---- @@ -138,3 +189,86 @@ expect_equal_oobag_eval <- function(x, y){ y$eval_oobag$stat_values, tolerance = 1e-9) } + +# data processing ---- + +prep_test_matrices <- function(data, outcomes = c("time", "status")){ + + names_y_data <- outcomes + names_x_data <- setdiff(names(data), outcomes) + + fi <- fctr_info(data, names_x_data) + + types_x_data <- check_var_types(data, + names_x_data, + valid_types = c('numeric', + 'integer', + 'units', + 'factor', + 'ordered')) + + names_x_numeric <- grep(pattern = "^integer$|^numeric$|^units$", + x = types_x_data) + + means <- standard_deviations<- modes <- numeric_bounds <- NULL + + numeric_cols <- names_x_data[names_x_numeric] + nominal_cols <- fi$cols + + if(!is_empty(nominal_cols)){ + + modes <- vapply( + select_cols(data, nominal_cols), + collapse::fmode, + FUN.VALUE = integer(1) + ) + + } + + 
if(!is_empty(numeric_cols)){ + + numeric_data <- select_cols(data, numeric_cols) + + numeric_bounds <- matrix( + data = c( + collapse::fnth(numeric_data, 0.1), + collapse::fnth(numeric_data, 0.25), + collapse::fnth(numeric_data, 0.5), + collapse::fnth(numeric_data, 0.75), + collapse::fnth(numeric_data, 0.9) + ), + nrow =5, + byrow = TRUE, + dimnames = list(c('10%', '25%', '50%', '75%', '90%'), + names(numeric_data)) + ) + + means <- collapse::fmean(numeric_data) + + standard_deviations <- collapse::fsd(numeric_data) + + } + + if(any(is.na(select_cols(data, names_y_data)))) + stop("Please remove missing values from the outcome variable(s)", + call. = FALSE) + + cc <- stats::complete.cases(data[, names_x_data]) + data <- data[cc, ] + + y <- prep_y(data, names_y_data) + x <- prep_x(data, fi, names_x_data, means, standard_deviations) + w <- sample(1:3, nrow(y), replace = TRUE) + + sorted <- collapse::radixorder(y[, 1], -y[, 2]) + + return( + list( + x = x[sorted, ], + y = y[sorted, ], + w = w[sorted] + ) + ) + + +} diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 4026a764..fab5f4fb 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -1,4 +1,6 @@ +set.seed(329) + #' @srrstats {G5.0} *use pbc/flchain data, standards for survival analysis* library(survival) @@ -8,17 +10,25 @@ library(survival) data("flchain", package = 'survival') flc <- flchain - flc$chapter <- NULL - flc <- na.omit(flc) - - flc <- flc[flc$futime > 0, ] names(flc)[names(flc) == 'futime'] <- 'time' names(flc)[names(flc) == 'death'] <- 'status' +# make sorted x and y matrices for testing internal cpp functions +flc_mats <- prep_test_matrices(flc, outcomes = c("time", "status")) + +# lung ---- + +lcd <- survival::lung +lcd$inst <- NULL +lcd <- na.omit(lcd) + +# make sorted x and y matrices for testing internal cpp functions +lcd_mats <- prep_test_matrices(lcd, outcomes = c("time", "status")) + # pbc ---- pbc <- pbc_orsf @@ -37,6 +47,9 @@ for(i in vars){ pbc_scale[[i]] <- 
change_scale(pbc_scale[[i]]) } +# make sorted x and y matrices for testing internal cpp functions +pbc_mats <- prep_test_matrices(pbc, outcomes = c("time", "status")) + # data lists ---- data_list_pbc <- list(pbc_standard = pbc, @@ -44,10 +57,16 @@ data_list_pbc <- list(pbc_standard = pbc, pbc_scaled = pbc_scale, pbc_noiced = pbc_noise) -# standards used to check validity of other fits +# matrix lists ---- + +mat_list_surv <- list(pbc = pbc_mats, + flc = flc_mats, + lcd = lcd_mats) # standard fits ---- +# standards used to check validity of other fits + seeds_standard <- c(5, 20, 1000, 30, 50, 98, 22, 100, 329, 10) fit_standard_pbc <- orsf(pbc, diff --git a/tests/testthat/test-cp_find_bounds.R b/tests/testthat/test-cp_find_bounds.R index ec81c687..e69de29b 100644 --- a/tests/testthat/test-cp_find_bounds.R +++ b/tests/testthat/test-cp_find_bounds.R @@ -1,138 +0,0 @@ - -# tests are deprecated since node functions are now internal to Tree class. - -# run_cp_bounds_test <- function(test_values, XB){ -# -# xb_uni <- unique(XB) -# -# cp_stats <- cp_find_bounds_R(y, w, XB, xb_uni, leaf_min_events, leaf_min_obs) -# -# if(!any(cp_stats$valid_cp)){ -# return(NULL) -# } -# -# cps_true_values <- sort(xb_uni[cp_stats$valid_cp]) -# cps_test_values <- XB[test_values$XB_sorted+1][test_values$cp_index+1] -# -# test_that( -# desc = 'cutpoints identified are unique and valid', -# code = { -# -# expect_equal( -# length(cps_test_values), length(unique(cps_test_values)) -# ) -# -# expect_equal(cps_true_values, cps_test_values) -# -# } -# -# ) -# -# test_that( -# desc = "group values are filled corresponding to the given cut-point", -# code = { -# -# group_cpp <- rep(0, length(XB)) -# XB_sorted <- order(XB)-1 -# -# for(i in seq_along(cps_true_values)){ -# -# group_R = XB <= cps_true_values[i] -# -# if(i == 1) start <- 0 else start <- test_values$cp_index[i-1]+1 -# -# node_fill_group_exported( -# group = group_cpp, -# XB_sorted = XB_sorted, -# start = start, -# stop = 
test_values$cp_index[i], -# value = 1 -# ) -# -# expect_equal(as.numeric(group_R), -# as.numeric(group_cpp)) -# -# } -# } -# ) -# -# } -# -# .leaf_min_events <- c(1, 5, 50, nrow(pbc_orsf)) -# -# # leaf_min_events = 1 -# -# for(leaf_min_events in .leaf_min_events){ -# -# leaf_min_obs <- leaf_min_events + 10 -# -# XB_ctns <- pbc_orsf$age -# XB_catg <- round(pbc_orsf$bili) -# XB_bnry <- as.numeric(pbc_orsf$sex) -# -# status <- pbc_orsf$status -# time <- pbc_orsf$time -# -# t_sort <- order(time) -# status <- status[t_sort] -# XB_ctns <- XB_ctns[t_sort] -# XB_catg <- XB_catg[t_sort] -# XB_bnry <- XB_bnry[t_sort] -# time <- time[t_sort] -# -# y <- cbind(time=time, status=status) -# w <- rep(1, nrow(pbc_orsf)) -# -# cp_bounds <- lapply( -# X = list(ctns = XB_ctns, -# catg = XB_catg, -# bnry = XB_bnry), -# FUN = function(XB){ -# node_find_cps_exported(y_node = y, -# w_node = w, -# XB = XB, -# leaf_min_events = leaf_min_events, -# leaf_min_obs = leaf_min_obs) -# } -# ) -# -# run_cp_bounds_test(test_values = cp_bounds$ctns, XB = XB_ctns) -# run_cp_bounds_test(cp_bounds$catg, XB = XB_catg) -# run_cp_bounds_test(cp_bounds$bnry, XB = XB_bnry) -# -# -# -# } - - -# benchmark does not need to be tested every time - -# bm <- microbenchmark::microbenchmark( -# -# R = { -# xb_uni = unique(XB_ctns) -# cp_find_bounds_R(y_node = y, -# w_node = w, -# XB = XB_ctns, -# xb_uni = xb_uni, -# leaf_min_events = 5, -# leaf_min_obs = 10) -# }, -# -# cpp = cp_find_bounds_exported(y_node = y, -# w_node = w, -# XB = XB_ctns, -# leaf_min_events = 5, -# leaf_min_obs = 10), -# -# times = 50 -# -# ) -# -# expect_lt( -# median(bm$time[bm$expr == 'cpp']), -# median(bm$time[bm$expr == 'R']) -# ) - - - diff --git a/tests/testthat/test-find_cutpoints.R b/tests/testthat/test-find_cutpoints.R new file mode 100644 index 00000000..27e1fcef --- /dev/null +++ b/tests/testthat/test-find_cutpoints.R @@ -0,0 +1,108 @@ + +test_that( + desc = 'cutpoints are unique and correct', + code = { + + for(i in 
seq_along(mat_list_surv)){ + + y <- mat_list_surv[[i]]$y + w <- mat_list_surv[[i]]$w + + for(cp_type in c("ctns", "bnry", "catg")){ + + xb <- switch( + cp_type, + 'ctns' = rnorm(nrow(y)), + 'bnry' = rbinom(nrow(y), size = 1, prob = 1/2), + 'catg' = rbinom(nrow(y), size = 5, prob = 1/2) + ) + + xb_uni <- unique(xb) + + for(leaf_min_events in c(1, 5, 10)){ + + for(leaf_min_obs in c(leaf_min_events + c(0, 5, 10))){ + + cp_stats <- cp_find_bounds_R(y, w, xb, xb_uni, leaf_min_events, leaf_min_obs) + + cp_index <- find_cutpoints_survival_exported(y, w, xb, + leaf_min_events, + leaf_min_obs) + + + cps_r <- cp_stats$cp[cp_stats$valid_cp] + + cps_cpp <- sort(xb)[cp_index+1] + + expect_equal(length(cps_cpp), length(unique(cps_cpp))) + expect_true(is_equivalent(cps_r, cps_cpp)) + + } + + } + + } + + } + + } +) + + + +# +# test_that( +# desc = "group values are filled corresponding to the given cut-point", +# code = { +# +# group_cpp <- rep(0, length(XB)) +# XB_sorted <- order(XB)-1 +# +# for(i in seq_along(cps_true_values)){ +# +# group_R = XB <= cps_true_values[i] +# +# if(i == 1) start <- 0 else start <- test_values$cp_index[i-1]+1 +# +# node_fill_group_exported( +# group = group_cpp, +# XB_sorted = XB_sorted, +# start = start, +# stop = test_values$cp_index[i], +# value = 1 +# ) +# +# expect_equal(as.numeric(group_R), +# as.numeric(group_cpp)) +# +# } +# } +# ) +# +# } +# + + + +# benchmark does not need to be tested every time + +# bm <- microbenchmark::microbenchmark( +# +# cp_stats = cp_find_bounds_R(y, w, xb, xb_uni, leaf_min_events, leaf_min_obs), +# +# cp_index = find_cutpoints_survival_exported(y, w, xb, +# leaf_min_events, +# leaf_min_obs), +# +# times = 50 +# +# ) +# +# expect_lt( +# median(bm$time[bm$expr == 'cpp']), +# median(bm$time[bm$expr == 'R']) +# ) + + + + diff --git a/tests/testthat/test-leaf_kaplan.R b/tests/testthat/test-leaf_kaplan.R deleted file mode 100644 index f9156c47..00000000 --- a/tests/testthat/test-leaf_kaplan.R +++ /dev/null @@ -1,42 +0,0 
@@ - -#' @srrstats {G5.4} **Correctness tests** *test that statistical algorithms produce expected results to some fixed test data sets. I use the flchain data and compare the aorsf kaplan meier routine to that of the survival package.* - -#' @srrstats {G5.4b} *Correctness tests include tests against previous implementations, explicitly calling those implementations in testing.* - -#' @srrstats {G5.5} *Correctness tests are run with a fixed random seed* -set.seed(329) - -flc <- flc[flc$status==1, ] - -weights <- sample(1:5, nrow(flc), replace = TRUE) - -# fit a normal tree with no bootstrap weights -fit <- orsf(flc, - time + status ~ ., - n_tree = 1, - weights = weights, - tree_seeds = 1, - oobag_pred_type = 'none', - # this makes every observation part of the in-bag data - sample_fraction = 1, - sample_with_replacement = FALSE, - # this forces the tree to make a leaf at the root - split_rule = 'cstat', - split_min_stat = 0.999) - -# so the result should be equivalent to fitting a kaplan-meier curve -# to the original training data, using replicate weights -aorsf_surv <- fit$forest$leaf_pred_prob[[1]][[1]] -aorsf_time <- fit$forest$leaf_pred_indx[[1]][[1]] - -kap <- survfit(Surv(time, status) ~ 1, data = flc, weights = weights) - -test_that( - desc = 'aorsf kaplan has same time values as survfit', - code = {expect_equal(kap$time, aorsf_time, tolerance = 1e-9)} -) - -test_that( - desc = 'aorsf kaplan has same surv values as survfit', - code = {expect_equal(kap$surv, aorsf_surv, tolerance = 1e-9)} -) diff --git a/tests/testthat/test-orsf_vi.R b/tests/testthat/test-orsf_vi.R index e66026b6..89f772e0 100644 --- a/tests/testthat/test-orsf_vi.R +++ b/tests/testthat/test-orsf_vi.R @@ -32,7 +32,7 @@ test_that( fit_with_vi <- orsf(pbc_vi, formula = formula, importance = importance, - n_tree = 50, + n_tree = 75, group_factors = group_factors, tree_seeds = tree_seeds) @@ -54,7 +54,7 @@ test_that( fit_no_vi <- orsf(pbc_vi, formula = formula, importance = 'none', - n_tree = 50, 
+ n_tree = 75, group_factors = group_factors, tree_seeds = tree_seeds) @@ -69,7 +69,7 @@ test_that( fit_custom_oobag <- orsf(pbc_vi, formula = formula, importance = importance, - n_tree = 50, + n_tree = 75, oobag_fun = oobag_c_risk, group_factors = group_factors, tree_seeds = tree_seeds) @@ -86,7 +86,7 @@ test_that( fit_threads <- orsf(pbc_vi, formula = formula, importance = importance, - n_tree = 50, + n_tree = 75, n_thread = 0, group_factors = group_factors, tree_seeds = tree_seeds) @@ -98,7 +98,7 @@ test_that( good_vars <- c('bili', 'protime', - if(group_factors) 'edema' else "edema_1") + if(group_factors) 'edema' else c("edema_1", "edema_0.5")) bad_vars <- setdiff(names(vi_during_fit), good_vars) diff --git a/tests/testthat/test-sprout_node.R b/tests/testthat/test-sprout_node.R new file mode 100644 index 00000000..420da30f --- /dev/null +++ b/tests/testthat/test-sprout_node.R @@ -0,0 +1,45 @@ + +test_that( + desc = 'leaf node stats have same time/surv/chaz as survfit', + code = { + + for(i in seq_along(mat_list_surv)){ + + y <- mat_list_surv[[i]]$y + w <- mat_list_surv[[i]]$w + r <- sprout_node_survival_exported(y, w) + + aorsf_surv <- r$prob[[1]] + aorsf_chaz <- r$chaz[[1]] + aorsf_time <- r$indx[[1]] + aorsf_mort <- r$mort[[1]] + + kap_fit <- survfit(Surv(time, status) ~ 1, + data = as.data.frame(y), + weights = w) + + kap_data <- data.frame(time = kap_fit$time, + surv = kap_fit$surv, + cumhaz = kap_fit$cumhaz, + n_event = kap_fit$n.event) + + kap_data <- subset(kap_data, n_event > 0) + + expect_equal(kap_data$time, aorsf_time, tolerance = 1e-9) + expect_equal(kap_data$surv, aorsf_surv, tolerance = 1e-9) + expect_equal(kap_data$cumhaz, aorsf_chaz, tolerance = 1e-9) + expect_equal(sum(kap_data$cumhaz), aorsf_mort, tolerance = 1e-9) + + } + + } +) + + + + + + + + +