diff --git a/R/RcppExports.R b/R/RcppExports.R
index f9b0c296..14b8bcc9 100644
--- a/R/RcppExports.R
+++ b/R/RcppExports.R
@@ -21,6 +21,14 @@ is_col_splittable_exported <- function(x, y, r, j) {
.Call(`_aorsf_is_col_splittable_exported`, x, y, r, j)
}
+find_cutpoints_survival_exported <- function(y, w, lincomb, leaf_min_events, leaf_min_obs) {
+ .Call(`_aorsf_find_cutpoints_survival_exported`, y, w, lincomb, leaf_min_events, leaf_min_obs)
+}
+
+sprout_node_survival_exported <- function(y, w) {
+ .Call(`_aorsf_sprout_node_survival_exported`, y, w)
+}
+
cph_scale <- function(x, w) {
.Call(`_aorsf_cph_scale`, x, w)
}
diff --git a/R/misc.R b/R/misc.R
index d735910b..149a0e34 100644
--- a/R/misc.R
+++ b/R/misc.R
@@ -104,55 +104,6 @@ paste_collapse <- function(x, sep=', ', last = ' or '){
}
-#' Find cut-point boundaries (R version)
-#'
-#' Used to test the cpp version for finding cutpoints
-#'
-#' @param y_node outcome matrix
-#' @param w_node weight vector
-#' @param XB linear combination of predictors
-#' @param xb_uni unique values in XB
-#' @param leaf_min_events min no. of events in a leaf
-#' @param leaf_min_obs min no. of observations in a leaf
-#'
-#' @noRd
-#'
-#' @return data.frame with description of valid cutpoints
-cp_find_bounds_R <- function(y_node,
- w_node,
- XB,
- xb_uni,
- leaf_min_events,
- leaf_min_obs){
-
- status = y_node[, 'status']
-
- cp_stats <-
- sapply(
- X = xb_uni,
- FUN = function(x){
- c(
- cp = x,
- e_right = sum(status[XB > x]),
- e_left = sum(status[XB <= x]),
- n_right = sum(XB > x),
- n_left = sum(XB <= x)
- )
- }
- )
-
- cp_stats <- as.data.frame(t(cp_stats))
-
- cp_stats$valid_cp = with(
- cp_stats,
- e_right >= leaf_min_events & e_left >= leaf_min_events &
- n_right >= leaf_min_obs & n_left >= leaf_min_obs
- )
-
- cp_stats
-
-}
-
has_units <- function(x){
inherits(x, 'units')
}
diff --git a/man/orsf_vi.Rd b/man/orsf_vi.Rd
index f130cd94..3e82c020 100644
--- a/man/orsf_vi.Rd
+++ b/man/orsf_vi.Rd
@@ -221,26 +221,26 @@ orsf_vi_negate(fit_no_vi)
}\if{html}{\out{}}
\if{html}{\out{
}}\preformatted{## bili copper sex protime stage
-## 0.1139657923 0.0498712200 0.0355366377 0.0283554322 0.0263792287
+## 0.1140229590 0.0498125548 0.0355460350 0.0281214750 0.0261880947
## albumin age ascites chol ast
-## 0.0231636378 0.0195791833 0.0175120075 0.0148252414 0.0104918262
+## 0.0231317226 0.0195454979 0.0174983334 0.0147644716 0.0106376518
## edema spiders hepato trt trig
-## 0.0084871358 0.0070608860 0.0067054788 0.0052040792 0.0030363455
+## 0.0084608002 0.0069190613 0.0065709504 0.0052988663 0.0030146941
## alk.phos platelet
-## 0.0029918139 -0.0003309069
+## 0.0029465729 -0.0003607455
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{orsf_vi_permute(fit_no_vi)
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{## bili copper protime albumin stage
-## 0.0511641625 0.0244676999 0.0160869571 0.0133334120 0.0130092352
-## ascites age hepato edema chol
-## 0.0127421184 0.0113532728 0.0050851329 0.0050477457 0.0049382275
-## ast spiders sex alk.phos trig
-## 0.0047345189 0.0038719163 0.0025231267 0.0018408350 0.0011528848
+## 0.0512565237 0.0244154211 0.0159740365 0.0133163811 0.0130200869
+## ascites age edema hepato ast
+## 0.0127617348 0.0113466862 0.0050618672 0.0050513510 0.0048291999
+## chol spiders sex alk.phos trig
+## 0.0048061978 0.0037750082 0.0024935019 0.0018249747 0.0011756814
## platelet trt
-## -0.0002875319 -0.0024330707
+## -0.0003313926 -0.0024983274
}\if{html}{\out{
}}
}
@@ -257,13 +257,13 @@ orsf_vi_permute(fit_permute_vi)
}\if{html}{\out{}}
\if{html}{\out{}}\preformatted{## bili copper age albumin protime
-## 0.0502725526 0.0201473283 0.0135888938 0.0127241082 0.0126629150
+## 0.0503462348 0.0202245782 0.0136242456 0.0127255613 0.0127084993
## stage ascites ast edema chol
-## 0.0124866976 0.0123508555 0.0060741690 0.0059166139 0.0053767371
+## 0.0124821760 0.0123368881 0.0061390036 0.0059648179 0.0053888106
## spiders sex hepato trig alk.phos
-## 0.0042600602 0.0028177750 0.0023470782 0.0021331719 0.0016874102
+## 0.0042214032 0.0027713550 0.0022973825 0.0019936938 0.0017116920
## platelet trt
-## 0.0002117061 -0.0005790547
+## 0.0002320012 -0.0006040636
}\if{html}{\out{
}}
You can still get negation VI from this fit, but it needs to be computed
@@ -272,11 +272,11 @@ You can still get negation VI from this fit, but it needs to be computed
}\if{html}{\out{}}
\if{html}{\out{}}\preformatted{## bili copper sex stage age protime
-## 0.1106715167 0.0456031656 0.0306666098 0.0304383573 0.0252136203 0.0224838590
+## 0.1106480407 0.0455070935 0.0307476652 0.0304715880 0.0252079956 0.0224192033
## albumin ascites chol ast edema trt
-## 0.0212630703 0.0168893963 0.0134174671 0.0132075752 0.0099681058 0.0088378768
+## 0.0214131931 0.0168756851 0.0133727008 0.0132732736 0.0100002899 0.0087589773
## spiders hepato trig alk.phos platelet
-## 0.0078776082 0.0062877323 0.0043076141 0.0030432581 0.0005571111
+## 0.0079457577 0.0062060793 0.0041321875 0.0030525590 0.0005921052
}\if{html}{\out{
}}
}
}
diff --git a/man/predict.orsf_fit.Rd b/man/predict.orsf_fit.Rd
index bafe78ca..f766dd48 100644
--- a/man/predict.orsf_fit.Rd
+++ b/man/predict.orsf_fit.Rd
@@ -169,11 +169,11 @@ prediction horizon
}\if{html}{\out{}}
\if{html}{\out{}}\preformatted{## [,1]
-## [1,] 83.08611
-## [2,] 27.48146
-## [3,] 43.52432
-## [4,] 15.20281
-## [5,] 10.56334
+## [1,] 81.42491
+## [2,] 26.90304
+## [3,] 42.62956
+## [4,] 14.88943
+## [5,] 10.35110
}\if{html}{\out{
}}
}
diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp
index 7f74671b..5f786d22 100644
--- a/src/RcppExports.cpp
+++ b/src/RcppExports.cpp
@@ -82,6 +82,33 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
+// find_cutpoints_survival_exported
+arma::uvec find_cutpoints_survival_exported(arma::mat& y, arma::vec& w, arma::vec& lincomb, double leaf_min_events, double leaf_min_obs);
+RcppExport SEXP _aorsf_find_cutpoints_survival_exported(SEXP ySEXP, SEXP wSEXP, SEXP lincombSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP) {
+BEGIN_RCPP
+ Rcpp::RObject rcpp_result_gen;
+ Rcpp::RNGScope rcpp_rngScope_gen;
+ Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type lincomb(lincombSEXP);
+ Rcpp::traits::input_parameter< double >::type leaf_min_events(leaf_min_eventsSEXP);
+ Rcpp::traits::input_parameter< double >::type leaf_min_obs(leaf_min_obsSEXP);
+ rcpp_result_gen = Rcpp::wrap(find_cutpoints_survival_exported(y, w, lincomb, leaf_min_events, leaf_min_obs));
+ return rcpp_result_gen;
+END_RCPP
+}
+// sprout_node_survival_exported
+List sprout_node_survival_exported(arma::mat& y, arma::vec& w);
+RcppExport SEXP _aorsf_sprout_node_survival_exported(SEXP ySEXP, SEXP wSEXP) {
+BEGIN_RCPP
+ Rcpp::RObject rcpp_result_gen;
+ Rcpp::RNGScope rcpp_rngScope_gen;
+ Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
+ rcpp_result_gen = Rcpp::wrap(sprout_node_survival_exported(y, w));
+ return rcpp_result_gen;
+END_RCPP
+}
// cph_scale
List cph_scale(arma::mat& x, arma::vec& w);
RcppExport SEXP _aorsf_cph_scale(SEXP xSEXP, SEXP wSEXP) {
@@ -155,6 +182,8 @@ static const R_CallMethodDef CallEntries[] = {
{"_aorsf_compute_cstat_exported_uvec", (DL_FUNC) &_aorsf_compute_cstat_exported_uvec, 4},
{"_aorsf_compute_logrank_exported", (DL_FUNC) &_aorsf_compute_logrank_exported, 3},
{"_aorsf_is_col_splittable_exported", (DL_FUNC) &_aorsf_is_col_splittable_exported, 4},
+ {"_aorsf_find_cutpoints_survival_exported", (DL_FUNC) &_aorsf_find_cutpoints_survival_exported, 5},
+ {"_aorsf_sprout_node_survival_exported", (DL_FUNC) &_aorsf_sprout_node_survival_exported, 2},
{"_aorsf_cph_scale", (DL_FUNC) &_aorsf_cph_scale, 2},
{"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 44},
{NULL, NULL, 0}
diff --git a/src/Tree.cpp b/src/Tree.cpp
index f7927e03..a637f867 100644
--- a/src/Tree.cpp
+++ b/src/Tree.cpp
@@ -562,20 +562,14 @@
Rcout << std::endl;
}
-
// do not split if best stat < minimum stat
- if(stat_best < split_min_stat){
-
- return(R_PosInf);
-
- }
+ if(stat_best < split_min_stat){ return(R_PosInf); }
// backtrack g_node to be what it was when best it was found
if(it_best < it_start){
g_node.elem(lincomb_sort.subvec(it_best+1, it_start)).fill(1);
}
-
// return the cut-point from best split
return(lincomb[lincomb_sort[it_best]]);
diff --git a/src/Tree.h b/src/Tree.h
index 79011abe..4602f53b 100644
--- a/src/Tree.h
+++ b/src/Tree.h
@@ -139,10 +139,42 @@
this->y_inbag = y;
}
+ void set_w_inbag(arma::vec w){
+ this->w_inbag = w;
+ }
+
+ void set_x_node(arma::mat x){
+ this->x_node = x;
+ }
+
+ void set_y_node(arma::mat y){
+ this->y_node = y;
+ }
+
+ void set_w_node(arma::vec w){
+ this->w_node = w;
+ }
+
void set_rows_node(arma::uvec rows){
this->rows_node = rows;
}
+ void set_lincomb(arma::vec lc){
+ this->lincomb = lc;
+ }
+
+ void set_lincomb_sort(arma::uvec lc_sort){
+ this->lincomb_sort = lc_sort;
+ }
+
+ void set_leaf_min_obs(double value){
+ this->leaf_min_obs = value;
+ }
+
+ void set_verbosity(int value){
+ this->verbosity = value;
+ }
+
protected:
diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp
index b38ebd02..8c3075a0 100644
--- a/src/TreeSurvival.cpp
+++ b/src/TreeSurvival.cpp
@@ -157,24 +157,12 @@
double n_events = 0, n_risk = 0;
- if(VERBOSITY > 1){
- Rcout << "----- finding lower bound for cut-points -----" << std::endl;
- }
-
// stop at end-1 b/c we access it+1 in lincomb_sort
for(it = lincomb_sort.begin(); it < lincomb_sort.end()-1; ++it){
n_events += y_status[*it] * w_node[*it];
n_risk += w_node[*it];
-
- if(VERBOSITY > 2){
- Rcout << "current value: "<< lincomb(*it) << " -- ";
- Rcout << "next: "<< lincomb(*(it+1)) << " -- ";
- Rcout << "N events: " << n_events << " -- ";
- Rcout << "N risk: " << n_risk << std::endl;
- }
-
// If we want to make the current value of lincomb a cut-point, we need
// to make sure the next value of lincomb isn't equal to this current value.
// Otherwise, we will have the same value of lincomb in both groups!
@@ -184,11 +172,11 @@
if( n_events >= leaf_min_events &&
n_risk >= leaf_min_obs ) {
- if(VERBOSITY > 0){
+ if(verbosity > 2){
Rcout << std::endl;
- Rcout << "lower cutpoint: " << lincomb(*it) << std::endl;
- Rcout << " - n_events, left node: " << n_events << std::endl;
- Rcout << " - n_risk, left node: " << n_risk << std::endl;
+ Rcout << " -- lower cutpoint: " << lincomb(*it) << std::endl;
+ Rcout << " - n_events, left node: " << n_events << std::endl;
+ Rcout << " - n_risk, left node: " << n_risk << std::endl;
Rcout << std::endl;
}
@@ -204,8 +192,8 @@
if(it == lincomb_sort.end()-1) {
- if(VERBOSITY > 1){
- Rcout << "Could not find a valid cut-point" << std::endl;
+ if(verbosity > 2){
+ Rcout << " -- Could not find a valid cut-point" << std::endl;
}
return(output);
@@ -218,23 +206,12 @@
// reset before finding the upper limit
n_events=0, n_risk=0;
- if(VERBOSITY > 1){
- Rcout << "----- finding upper bound for cut-points -----" << std::endl;
- }
-
// stop at beginning+1 b/c we access it-1 in lincomb_sort
for(it = lincomb_sort.end()-1; it >= lincomb_sort.begin()+1; --it){
n_events += y_status[*it] * w_node[*it];
n_risk += w_node[*it];
- if(VERBOSITY > 2){
- Rcout << "current value: "<< lincomb(*it) << " ---- ";
- Rcout << "next value: "<< lincomb(*(it-1)) << " ---- ";
- Rcout << "N events: " << n_events << " ---- ";
- Rcout << "N risk: " << n_risk << std::endl;
- }
-
if(lincomb[*it] != lincomb[*(it-1)]){
if( n_events >= leaf_min_events &&
@@ -250,11 +227,11 @@
--it;
- if(VERBOSITY > 0){
+ if(verbosity > 2){
Rcout << std::endl;
- Rcout << "upper cutpoint: " << lincomb(*it) << std::endl;
- Rcout << " - n_events, right node: " << n_events << std::endl;
- Rcout << " - n_risk, right node: " << n_risk << std::endl;
+ Rcout << " -- upper cutpoint: " << lincomb(*it) << std::endl;
+ Rcout << " - n_events, right node: " << n_events << std::endl;
+ Rcout << " - n_risk, right node: " << n_risk << std::endl;
Rcout << std::endl;
}
@@ -273,7 +250,7 @@
if(j > k){
- if(VERBOSITY > 0) {
+ if(verbosity > 2) {
Rcout << "Could not find valid cut-points" << std::endl;
}
@@ -333,68 +310,6 @@
}
- double TreeSurvival::score_logrank(){
-
- double
- n_risk=0,
- g_risk=0,
- observed=0,
- expected=0,
- V=0,
- temp1,
- temp2,
- n_events;
-
- vec y_time = y_node.unsafe_col(0);
- vec y_status = y_node.unsafe_col(1);
-
- bool break_loop = false;
-
- uword i = y_node.n_rows-1;
-
- // breaking condition of outer loop governed by inner loop
- for (; ;){
-
- temp1 = y_time[i];
-
- n_events = 0;
-
- for ( ; y_time[i] == temp1; i--) {
-
- n_risk += w_node[i];
- n_events += y_status[i] * w_node[i];
- g_risk += g_node[i] * w_node[i];
- observed += y_status[i] * g_node[i] * w_node[i];
-
- if(i == 0){
- break_loop = true;
- break;
- }
-
- }
-
- // should only do these calculations if n_events > 0,
- // but in practice its often faster to multiply by 0
- // versus check if n_events is > 0.
-
- temp2 = g_risk / n_risk;
- expected += n_events * temp2;
-
- // update variance if n_risk > 1 (if n_risk == 1, variance is 0)
- // definitely check if n_risk is > 1 b/c otherwise divide by 0
- if (n_risk > 1){
- temp1 = n_events * temp2 * (n_risk-n_events) / (n_risk-1);
- V += temp1 * (1 - temp2);
- }
-
- if(break_loop) break;
-
- }
-
- return(pow(expected-observed, 2) / V);
-
- }
-
void TreeSurvival::sprout_leaf(uword node_id){
if(verbosity > 2){
@@ -404,6 +319,7 @@
Rcout << std::endl;
}
+
// reserve as much size as could be needed (probably more)
mat leaf_data(y_node.n_rows, 3);
@@ -414,6 +330,7 @@
person++;
}
+
// person corresponds to first event or last censor time
leaf_data.at(0, 0) = y_node.at(person, 0);
@@ -421,8 +338,8 @@
// (TODO: should this case even occur? consider removing)
if(person == y_node.n_rows){
- vec temp_surv(1, arma::fill::ones);
- vec temp_chf(1, arma::fill::zeros);
+ vec temp_surv(1, fill::ones);
+ vec temp_chf(1, fill::zeros);
leaf_pred_indx[node_id] = leaf_data.col(0);
leaf_pred_prob[node_id] = temp_surv;
@@ -514,8 +431,11 @@
for( ; i < (*unique_event_times).size(); i++){
- if((*unique_event_times)[i] >= leaf_data.at(j, 0) &&
- j < (leaf_data.n_rows-1)) {j++;}
+
+ while((*unique_event_times)[i] > leaf_data.at(j, 0) &&
+ j < (leaf_data.n_rows-1)) {
+ j++;
+ }
result += leaf_data.at(j, 2);
@@ -525,6 +445,24 @@
}
+ // double TreeSurvival::compute_mortality(arma::mat& leaf_data){
+ //
+ // double result = 0;
+ // uword i=0, j=0;
+ //
+ // for( ; i < (*unique_event_times).size(); i++){
+ //
+ // if((*unique_event_times)[i] >= leaf_data.at(j, 0) &&
+ // j < (leaf_data.n_rows-1)) {j++;}
+ //
+ // result += leaf_data.at(j, 2);
+ //
+ // }
+ //
+ // return(result);
+ //
+ // }
+
void TreeSurvival::predict_value(arma::mat* pred_output,
arma::vec* pred_denom,
PredType pred_type,
diff --git a/src/TreeSurvival.h b/src/TreeSurvival.h
index 0306616c..941aee24 100644
--- a/src/TreeSurvival.h
+++ b/src/TreeSurvival.h
@@ -51,8 +51,6 @@
double compute_split_score() override;
- double score_logrank();
-
double compute_mortality(arma::mat& leaf_data);
void sprout_leaf(uword node_id) override;
@@ -74,6 +72,14 @@
return(leaf_pred_chaz);
}
+ void set_unique_event_times(arma::vec event_times){
+ this->unique_event_times = &event_times;
+ }
+
+ void set_leaf_min_events(double value){
+ this->leaf_min_events = value;
+ }
+
double compute_prediction_accuracy_internal(arma::vec& preds) override;
std::vector leaf_pred_indx;
diff --git a/src/orsf_oop.cpp b/src/orsf_oop.cpp
index c83a73d9..5d8672e0 100644
--- a/src/orsf_oop.cpp
+++ b/src/orsf_oop.cpp
@@ -93,6 +93,54 @@
}
+ // [[Rcpp::export]]
+ arma::uvec find_cutpoints_survival_exported(arma::mat& y,
+ arma::vec& w,
+ arma::vec& lincomb,
+ double leaf_min_events,
+ double leaf_min_obs){
+
+ TreeSurvival tree;
+
+ arma::uvec lincomb_sort = sort_index(lincomb);
+
+ tree.set_y_node(y);
+ tree.set_w_node(w);
+ tree.set_lincomb(lincomb);
+ tree.set_lincomb_sort(lincomb_sort);
+ tree.set_leaf_min_obs(leaf_min_obs);
+ tree.set_leaf_min_events(leaf_min_events);
+ return(tree.find_cutpoints());
+
+ }
+
+ // [[Rcpp::export]]
+ List sprout_node_survival_exported(arma::mat& y,
+ arma::vec& w){
+
+ TreeSurvival tree;
+
+ arma::uword node_id = 0;
+ arma::uword leaf_size = 1;
+
+ tree.unique_event_times = new vec(find_unique_event_times(y));
+ tree.set_y_node(y);
+ tree.set_w_node(w);
+
+ tree.resize_leaves(leaf_size);
+ tree.sprout_leaf(node_id);
+
+ List result;
+
+ result.push_back(tree.get_leaf_pred_indx(), "indx");
+ result.push_back(tree.get_leaf_pred_prob(), "prob");
+ result.push_back(tree.get_leaf_pred_chaz(), "chaz");
+ result.push_back(tree.get_leaf_summary(), "mort");
+
+ return(result);
+
+ }
+
// [[Rcpp::export]]
List cph_scale(arma::mat& x, arma::vec& w){
diff --git a/tests/testthat/helper-orsf.R b/tests/testthat/helper-orsf.R
index ae247233..fea821a2 100644
--- a/tests/testthat/helper-orsf.R
+++ b/tests/testthat/helper-orsf.R
@@ -27,6 +27,57 @@ change_scale <- function(x, mult_by = 1/2){
x * mult_by
}
+#' Find cut-point boundaries (R version)
+#'
+#' Used to test the cpp version for finding cutpoints
+#'
+#' @param y_node outcome matrix
+#' @param w_node weight vector
+#' @param XB linear combination of predictors
+#' @param xb_uni unique values in XB
+#' @param leaf_min_events min no. of events in a leaf
+#' @param leaf_min_obs min no. of observations in a leaf
+#'
+#' @noRd
+#'
+#' @return data.frame with description of valid cutpoints
+cp_find_bounds_R <- function(y_node,
+ w_node,
+ XB,
+ xb_uni,
+ leaf_min_events,
+ leaf_min_obs){
+
+ status = y_node[, 'status']
+
+ cp_stats <-
+ sapply(
+ X = xb_uni,
+ FUN = function(x){
+ c(
+ cp = x,
+ e_right = sum(status[XB > x] * w_node[XB > x]),
+ e_left = sum(status[XB <= x] * w_node[XB <= x]),
+ n_right = sum(as.numeric(XB > x) * w_node),
+ n_left = sum(as.numeric(XB <= x) * w_node)
+ )
+ }
+ )
+
+ cp_stats <- as.data.frame(t(cp_stats))
+
+ cp_stats$valid_cp = with(
+ cp_stats,
+ e_right >= leaf_min_events &
+ e_left >= leaf_min_events &
+ n_right >= leaf_min_obs &
+ n_left >= leaf_min_obs
+ )
+
+ cp_stats
+
+}
+
# oobag functions ----
@@ -138,3 +189,86 @@ expect_equal_oobag_eval <- function(x, y){
y$eval_oobag$stat_values,
tolerance = 1e-9)
}
+
+# data processing ----
+
+prep_test_matrices <- function(data, outcomes = c("time", "status")){
+
+ names_y_data <- outcomes
+ names_x_data <- setdiff(names(data), outcomes)
+
+ fi <- fctr_info(data, names_x_data)
+
+ types_x_data <- check_var_types(data,
+ names_x_data,
+ valid_types = c('numeric',
+ 'integer',
+ 'units',
+ 'factor',
+ 'ordered'))
+
+ names_x_numeric <- grep(pattern = "^integer$|^numeric$|^units$",
+ x = types_x_data)
+
+ means <- standard_deviations<- modes <- numeric_bounds <- NULL
+
+ numeric_cols <- names_x_data[names_x_numeric]
+ nominal_cols <- fi$cols
+
+ if(!is_empty(nominal_cols)){
+
+ modes <- vapply(
+ select_cols(data, nominal_cols),
+ collapse::fmode,
+ FUN.VALUE = integer(1)
+ )
+
+ }
+
+ if(!is_empty(numeric_cols)){
+
+ numeric_data <- select_cols(data, numeric_cols)
+
+ numeric_bounds <- matrix(
+ data = c(
+ collapse::fnth(numeric_data, 0.1),
+ collapse::fnth(numeric_data, 0.25),
+ collapse::fnth(numeric_data, 0.5),
+ collapse::fnth(numeric_data, 0.75),
+ collapse::fnth(numeric_data, 0.9)
+ ),
+ nrow =5,
+ byrow = TRUE,
+ dimnames = list(c('10%', '25%', '50%', '75%', '90%'),
+ names(numeric_data))
+ )
+
+ means <- collapse::fmean(numeric_data)
+
+ standard_deviations <- collapse::fsd(numeric_data)
+
+ }
+
+ if(any(is.na(select_cols(data, names_y_data))))
+ stop("Please remove missing values from the outcome variable(s)",
+ call. = FALSE)
+
+ cc <- stats::complete.cases(data[, names_x_data])
+ data <- data[cc, ]
+
+ y <- prep_y(data, names_y_data)
+ x <- prep_x(data, fi, names_x_data, means, standard_deviations)
+ w <- sample(1:3, nrow(y), replace = TRUE)
+
+ sorted <- collapse::radixorder(y[, 1], -y[, 2])
+
+ return(
+ list(
+ x = x[sorted, ],
+ y = y[sorted, ],
+ w = w[sorted]
+ )
+ )
+
+
+}
diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R
index 4026a764..fab5f4fb 100644
--- a/tests/testthat/setup.R
+++ b/tests/testthat/setup.R
@@ -1,4 +1,6 @@
+set.seed(329)
+
#' @srrstats {G5.0} *use pbc/flchain data, standards for survival analysis*
library(survival)
@@ -8,17 +10,25 @@ library(survival)
data("flchain", package = 'survival')
flc <- flchain
-
flc$chapter <- NULL
-
flc <- na.omit(flc)
-
-
flc <- flc[flc$futime > 0, ]
names(flc)[names(flc) == 'futime'] <- 'time'
names(flc)[names(flc) == 'death'] <- 'status'
+# make sorted x and y matrices for testing internal cpp functions
+flc_mats <- prep_test_matrices(flc, outcomes = c("time", "status"))
+
+# lung ----
+
+lcd <- survival::lung
+lcd$inst <- NULL
+lcd <- na.omit(lcd)
+
+# make sorted x and y matrices for testing internal cpp functions
+lcd_mats <- prep_test_matrices(lcd, outcomes = c("time", "status"))
+
# pbc ----
pbc <- pbc_orsf
@@ -37,6 +47,9 @@ for(i in vars){
pbc_scale[[i]] <- change_scale(pbc_scale[[i]])
}
+# make sorted x and y matrices for testing internal cpp functions
+pbc_mats <- prep_test_matrices(pbc, outcomes = c("time", "status"))
+
# data lists ----
data_list_pbc <- list(pbc_standard = pbc,
@@ -44,10 +57,16 @@ data_list_pbc <- list(pbc_standard = pbc,
pbc_scaled = pbc_scale,
pbc_noiced = pbc_noise)
-# standards used to check validity of other fits
+# matric lists ----
+
+mat_list_surv <- list(pbc = pbc_mats,
+ flc = flc_mats,
+ lcd = lcd_mats)
# standard fits ----
+# standards used to check validity of other fits
+
seeds_standard <- c(5, 20, 1000, 30, 50, 98, 22, 100, 329, 10)
fit_standard_pbc <- orsf(pbc,
diff --git a/tests/testthat/test-cp_find_bounds.R b/tests/testthat/test-cp_find_bounds.R
index ec81c687..e69de29b 100644
--- a/tests/testthat/test-cp_find_bounds.R
+++ b/tests/testthat/test-cp_find_bounds.R
@@ -1,138 +0,0 @@
-
-# tests are deprecated since node functions are now internal to Tree class.
-
-# run_cp_bounds_test <- function(test_values, XB){
-#
-# xb_uni <- unique(XB)
-#
-# cp_stats <- cp_find_bounds_R(y, w, XB, xb_uni, leaf_min_events, leaf_min_obs)
-#
-# if(!any(cp_stats$valid_cp)){
-# return(NULL)
-# }
-#
-# cps_true_values <- sort(xb_uni[cp_stats$valid_cp])
-# cps_test_values <- XB[test_values$XB_sorted+1][test_values$cp_index+1]
-#
-# test_that(
-# desc = 'cutpoints identified are unique and valid',
-# code = {
-#
-# expect_equal(
-# length(cps_test_values), length(unique(cps_test_values))
-# )
-#
-# expect_equal(cps_true_values, cps_test_values)
-#
-# }
-#
-# )
-#
-# test_that(
-# desc = "group values are filled corresponding to the given cut-point",
-# code = {
-#
-# group_cpp <- rep(0, length(XB))
-# XB_sorted <- order(XB)-1
-#
-# for(i in seq_along(cps_true_values)){
-#
-# group_R = XB <= cps_true_values[i]
-#
-# if(i == 1) start <- 0 else start <- test_values$cp_index[i-1]+1
-#
-# node_fill_group_exported(
-# group = group_cpp,
-# XB_sorted = XB_sorted,
-# start = start,
-# stop = test_values$cp_index[i],
-# value = 1
-# )
-#
-# expect_equal(as.numeric(group_R),
-# as.numeric(group_cpp))
-#
-# }
-# }
-# )
-#
-# }
-#
-# .leaf_min_events <- c(1, 5, 50, nrow(pbc_orsf))
-#
-# # leaf_min_events = 1
-#
-# for(leaf_min_events in .leaf_min_events){
-#
-# leaf_min_obs <- leaf_min_events + 10
-#
-# XB_ctns <- pbc_orsf$age
-# XB_catg <- round(pbc_orsf$bili)
-# XB_bnry <- as.numeric(pbc_orsf$sex)
-#
-# status <- pbc_orsf$status
-# time <- pbc_orsf$time
-#
-# t_sort <- order(time)
-# status <- status[t_sort]
-# XB_ctns <- XB_ctns[t_sort]
-# XB_catg <- XB_catg[t_sort]
-# XB_bnry <- XB_bnry[t_sort]
-# time <- time[t_sort]
-#
-# y <- cbind(time=time, status=status)
-# w <- rep(1, nrow(pbc_orsf))
-#
-# cp_bounds <- lapply(
-# X = list(ctns = XB_ctns,
-# catg = XB_catg,
-# bnry = XB_bnry),
-# FUN = function(XB){
-# node_find_cps_exported(y_node = y,
-# w_node = w,
-# XB = XB,
-# leaf_min_events = leaf_min_events,
-# leaf_min_obs = leaf_min_obs)
-# }
-# )
-#
-# run_cp_bounds_test(test_values = cp_bounds$ctns, XB = XB_ctns)
-# run_cp_bounds_test(cp_bounds$catg, XB = XB_catg)
-# run_cp_bounds_test(cp_bounds$bnry, XB = XB_bnry)
-#
-#
-#
-# }
-
-
-# benchmark does not need to be tested every time
-
-# bm <- microbenchmark::microbenchmark(
-#
-# R = {
-# xb_uni = unique(XB_ctns)
-# cp_find_bounds_R(y_node = y,
-# w_node = w,
-# XB = XB_ctns,
-# xb_uni = xb_uni,
-# leaf_min_events = 5,
-# leaf_min_obs = 10)
-# },
-#
-# cpp = cp_find_bounds_exported(y_node = y,
-# w_node = w,
-# XB = XB_ctns,
-# leaf_min_events = 5,
-# leaf_min_obs = 10),
-#
-# times = 50
-#
-# )
-#
-# expect_lt(
-# median(bm$time[bm$expr == 'cpp']),
-# median(bm$time[bm$expr == 'R'])
-# )
-
-
-
diff --git a/tests/testthat/test-find_cutpoints.R b/tests/testthat/test-find_cutpoints.R
new file mode 100644
index 00000000..27e1fcef
--- /dev/null
+++ b/tests/testthat/test-find_cutpoints.R
@@ -0,0 +1,108 @@
+
+test_that(
+ desc = 'cutpoints are unique and correct',
+ code = {
+
+ for(i in seq_along(mat_list_surv)){
+
+ y <- mat_list_surv[[i]]$y
+ w <- mat_list_surv[[i]]$w
+
+ for(cp_type in c("ctns", "bnry", "catg")){
+
+ xb <- switch(
+ cp_type,
+ 'ctns' = rnorm(nrow(y)),
+ 'bnry' = rbinom(nrow(y), size = 1, prob = 1/2),
+ 'catg' = rbinom(nrow(y), size = 5, prob = 1/2)
+ )
+
+ xb_uni <- unique(xb)
+
+ for(leaf_min_events in c(1, 5, 10)){
+
+ for(leaf_min_obs in c(leaf_min_events + c(0, 5, 10))){
+
+ cp_stats <- cp_find_bounds_R(y, w, xb, xb_uni, leaf_min_events, leaf_min_obs)
+
+ cp_index <- find_cutpoints_survival_exported(y, w, xb,
+ leaf_min_events,
+ leaf_min_obs)
+
+
+ cps_r <- cp_stats$cp[cp_stats$valid_cp]
+
+ cps_cpp <- sort(xb)[cp_index+1]
+
+ expect_equal(length(cps_cpp), length(unique(cps_cpp)))
+ expect_true(is_equivalent(cps_r, cps_cpp))
+
+ }
+
+ }
+
+ }
+
+ }
+
+ }
+)
+
+
+
+#
+# test_that(
+# desc = "group values are filled corresponding to the given cut-point",
+# code = {
+#
+# group_cpp <- rep(0, length(XB))
+# XB_sorted <- order(XB)-1
+#
+# for(i in seq_along(cps_true_values)){
+#
+# group_R = XB <= cps_true_values[i]
+#
+# if(i == 1) start <- 0 else start <- test_values$cp_index[i-1]+1
+#
+# node_fill_group_exported(
+# group = group_cpp,
+# XB_sorted = XB_sorted,
+# start = start,
+# stop = test_values$cp_index[i],
+# value = 1
+# )
+#
+# expect_equal(as.numeric(group_R),
+# as.numeric(group_cpp))
+#
+# }
+# }
+# )
+#
+# }
+#
+
+
+
+# benchmark does not need to be tested every time
+
+# bm <- microbenchmark::microbenchmark(
+#
+# cp_stats = cp_find_bounds_R(y, w, xb, xb_uni, leaf_min_events, leaf_min_obs),
+#
+# cp_index = find_cutpoints_survival_exported(y, w, xb,
+# leaf_min_events,
+# leaf_min_obs),
+#
+# times = 50
+#
+# )
+#
+# expect_lt(
+# median(bm$time[bm$expr == 'cpp']),
+# median(bm$time[bm$expr == 'R'])
+# )
+
+
+
+
diff --git a/tests/testthat/test-leaf_kaplan.R b/tests/testthat/test-leaf_kaplan.R
deleted file mode 100644
index f9156c47..00000000
--- a/tests/testthat/test-leaf_kaplan.R
+++ /dev/null
@@ -1,42 +0,0 @@
-
-#' @srrstats {G5.4} **Correctness tests** *test that statistical algorithms produce expected results to some fixed test data sets. I use the flchain data and compare the aorsf kaplan meier routine to that of the survival package.*
-
-#' @srrstats {G5.4b} *Correctness tests include tests against previous implementations, explicitly calling those implementations in testing.*
-
-#' @srrstats {G5.5} *Correctness tests are run with a fixed random seed*
-set.seed(329)
-
-flc <- flc[flc$status==1, ]
-
-weights <- sample(1:5, nrow(flc), replace = TRUE)
-
-# fit a normal tree with no bootstrap weights
-fit <- orsf(flc,
- time + status ~ .,
- n_tree = 1,
- weights = weights,
- tree_seeds = 1,
- oobag_pred_type = 'none',
- # this makes every observation part of the in-bag data
- sample_fraction = 1,
- sample_with_replacement = FALSE,
- # this forces the tree to make a leaf at the root
- split_rule = 'cstat',
- split_min_stat = 0.999)
-
-# so the result should be equivalent to fitting a kaplan-meier curve
-# to the original training data, using replicate weights
-aorsf_surv <- fit$forest$leaf_pred_prob[[1]][[1]]
-aorsf_time <- fit$forest$leaf_pred_indx[[1]][[1]]
-
-kap <- survfit(Surv(time, status) ~ 1, data = flc, weights = weights)
-
-test_that(
- desc = 'aorsf kaplan has same time values as survfit',
- code = {expect_equal(kap$time, aorsf_time, tolerance = 1e-9)}
-)
-
-test_that(
- desc = 'aorsf kaplan has same surv values as survfit',
- code = {expect_equal(kap$surv, aorsf_surv, tolerance = 1e-9)}
-)
diff --git a/tests/testthat/test-orsf_vi.R b/tests/testthat/test-orsf_vi.R
index e66026b6..89f772e0 100644
--- a/tests/testthat/test-orsf_vi.R
+++ b/tests/testthat/test-orsf_vi.R
@@ -32,7 +32,7 @@ test_that(
fit_with_vi <- orsf(pbc_vi,
formula = formula,
importance = importance,
- n_tree = 50,
+ n_tree = 75,
group_factors = group_factors,
tree_seeds = tree_seeds)
@@ -54,7 +54,7 @@ test_that(
fit_no_vi <- orsf(pbc_vi,
formula = formula,
importance = 'none',
- n_tree = 50,
+ n_tree = 75,
group_factors = group_factors,
tree_seeds = tree_seeds)
@@ -69,7 +69,7 @@ test_that(
fit_custom_oobag <- orsf(pbc_vi,
formula = formula,
importance = importance,
- n_tree = 50,
+ n_tree = 75,
oobag_fun = oobag_c_risk,
group_factors = group_factors,
tree_seeds = tree_seeds)
@@ -86,7 +86,7 @@ test_that(
fit_threads <- orsf(pbc_vi,
formula = formula,
importance = importance,
- n_tree = 50,
+ n_tree = 75,
n_thread = 0,
group_factors = group_factors,
tree_seeds = tree_seeds)
@@ -98,7 +98,7 @@ test_that(
good_vars <- c('bili',
'protime',
- if(group_factors) 'edema' else "edema_1")
+ if(group_factors) 'edema' else c("edema_1", "edema_0.5"))
bad_vars <- setdiff(names(vi_during_fit), good_vars)
diff --git a/tests/testthat/test-sprout_node.R b/tests/testthat/test-sprout_node.R
new file mode 100644
index 00000000..420da30f
--- /dev/null
+++ b/tests/testthat/test-sprout_node.R
@@ -0,0 +1,45 @@
+
+test_that(
+ desc = 'leaf node stats have same time/surv/chaz as survfit',
+ code = {
+
+ for(i in seq_along(mat_list_surv)){
+
+ y <- mat_list_surv[[i]]$y
+ w <- mat_list_surv[[i]]$w
+ r <- sprout_node_survival_exported(y, w)
+
+ aorsf_surv <- r$prob[[1]]
+ aorsf_chaz <- r$chaz[[1]]
+ aorsf_time <- r$indx[[1]]
+ aorsf_mort <- r$mort[[1]]
+
+ kap_fit <- survfit(Surv(time, status) ~ 1,
+ data = as.data.frame(y),
+ weights = w)
+
+ kap_data <- data.frame(time = kap_fit$time,
+ surv = kap_fit$surv,
+ cumhaz = kap_fit$cumhaz,
+ n_event = kap_fit$n.event)
+
+ kap_data <- subset(kap_data, n_event > 0)
+
+ expect_equal(kap_data$time, aorsf_time, tolerance = 1e-9)
+ expect_equal(kap_data$surv, aorsf_surv, tolerance = 1e-9)
+ expect_equal(kap_data$cumhaz, aorsf_chaz, tolerance = 1e-9)
+ expect_equal(sum(kap_data$cumhaz), aorsf_mort, tolerance = 1e-9)
+
+ }
+
+ }
+)
+
+
+
+
+
+
+
+
+