diff --git a/R/RcppExports.R b/R/RcppExports.R index f7afbef5..f99dbe0e 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -33,6 +33,10 @@ find_rows_inbag_exported <- function(rows_oobag, n_obs) { .Call(`_aorsf_find_rows_inbag_exported`, rows_oobag, n_obs) } +x_submat_mult_beta_exported <- function(x, y, w, x_rows, x_cols, beta) { + .Call(`_aorsf_x_submat_mult_beta_exported`, x, y, w, x_rows, x_cols, beta) +} + cph_scale <- function(x, w) { .Call(`_aorsf_cph_scale`, x, w) } diff --git a/src/Data.h b/src/Data.h index c561be3c..e3804c60 100644 --- a/src/Data.h +++ b/src/Data.h @@ -89,6 +89,28 @@ return(w(indices)); } + // multiply X matrix by lincomb coefficients + // without taking a sub-matrix of X + arma::vec x_submat_mult_beta(arma::uvec& x_rows, + arma::uvec& x_cols, + arma::vec& beta){ + + arma::vec out(x_rows.size()); + arma::uword i = 0; + + for(auto row : x_rows){ + arma::uword j = 0; + for(auto col : x_cols){ + out[i] += x.at(row, col) * beta[j]; + j++; + } + i++; + } + + return(out); + + } + void permute_col(arma::uword j, std::mt19937_64& rng){ arma::vec x_j = x.unsafe_col(j); diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 663f2be7..cf64be34 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -122,6 +122,22 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// x_submat_mult_beta_exported +arma::vec x_submat_mult_beta_exported(arma::mat& x, arma::mat& y, arma::vec& w, arma::uvec& x_rows, arma::uvec& x_cols, arma::vec& beta); +RcppExport SEXP _aorsf_x_submat_mult_beta_exported(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP x_rowsSEXP, SEXP x_colsSEXP, SEXP betaSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< arma::mat& >::type x(xSEXP); + Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP); + Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP); + Rcpp::traits::input_parameter< arma::uvec& >::type x_rows(x_rowsSEXP); + Rcpp::traits::input_parameter< arma::uvec& >::type x_cols(x_colsSEXP); + Rcpp::traits::input_parameter< arma::vec& >::type beta(betaSEXP); + rcpp_result_gen = Rcpp::wrap(x_submat_mult_beta_exported(x, y, w, x_rows, x_cols, beta)); + return rcpp_result_gen; +END_RCPP +} // cph_scale List cph_scale(arma::mat& x, arma::vec& w); RcppExport SEXP _aorsf_cph_scale(SEXP xSEXP, SEXP wSEXP) { @@ -198,6 +214,7 @@ static const R_CallMethodDef CallEntries[] = { {"_aorsf_find_cuts_survival_exported", (DL_FUNC) &_aorsf_find_cuts_survival_exported, 6}, {"_aorsf_sprout_node_survival_exported", (DL_FUNC) &_aorsf_sprout_node_survival_exported, 2}, {"_aorsf_find_rows_inbag_exported", (DL_FUNC) &_aorsf_find_rows_inbag_exported, 2}, + {"_aorsf_x_submat_mult_beta_exported", (DL_FUNC) &_aorsf_x_submat_mult_beta_exported, 6}, {"_aorsf_cph_scale", (DL_FUNC) &_aorsf_cph_scale, 2}, {"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 44}, {NULL, NULL, 0} diff --git a/src/Tree.cpp b/src/Tree.cpp index 8e0119e3..a111db3b 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -1128,13 +1128,15 @@ if(obs_in_node.size() > 0){ - x_node = prediction_data->x_submat(obs_in_node, coef_indices[i]); + lincomb = prediction_data->x_submat_mult_beta(obs_in_node, + coef_indices[i], + coef_values[i]); it = obs_in_node.begin(); for(uword j = 0; j < obs_in_node.size(); ++j, ++it){ - if(dot(x_node.row(j), coef_values[i]) <= cutpoint[i]) { + if(lincomb[j] <= cutpoint[i]) { pred_leaf[*it] = child_left[i]; diff --git a/src/orsf_oop.cpp b/src/orsf_oop.cpp index f22b401a..ec9d836f 100644 --- a/src/orsf_oop.cpp +++ b/src/orsf_oop.cpp @@ -6,7 +6,6 @@ Authors: - Byron C. Jaeger (http://byronjaeger.com) - - test #----------------------------------------------------------------------------*/ #include @@ -171,6 +170,22 @@ } + // [[Rcpp::export]] + arma::vec x_submat_mult_beta_exported(arma::mat& x, + arma::mat& y, + arma::vec& w, + arma::uvec& x_rows, + arma::uvec& x_cols, + arma::vec& beta){ + + std::unique_ptr data = std::make_unique(x, y, w); + + vec out = data->x_submat_mult_beta(x_rows, x_cols, beta); + + return(out); + + } + // [[Rcpp::export]] List cph_scale(arma::mat& x, arma::vec& w){ diff --git a/tests/testthat/test-DataCpp.R b/tests/testthat/test-DataCpp.R new file mode 100644 index 00000000..f7748ef9 --- /dev/null +++ b/tests/testthat/test-DataCpp.R @@ -0,0 +1,24 @@ + + +x_rows <- sample(nrow(pbc_mats$x), size = 100) +x_cols <- sample(ncol(pbc_mats$x), size = 10) + +beta <- runif(n = length(x_cols)) + +test_that( + desc = "submatrix multiplication is correct", + code = { + + data_cpp_answer <- x_submat_mult_beta_exported(x = pbc_mats$x, + y = pbc_mats$y, + w = pbc_mats$w, + x_rows = x_rows - 1, + x_cols = x_cols - 1, + beta = beta) + + target <- pbc_mats$x[x_rows, x_cols] %*% beta + + expect_equal(data_cpp_answer, target) + + }) + diff --git a/tests/testthat/test-performance.R b/tests/testthat/test-performance.R index 47d0be28..923715fc 100644 --- a/tests/testthat/test-performance.R +++ b/tests/testthat/test-performance.R @@ -48,7 +48,7 @@ # rfsrc = predict(fit_rfsrc, newdata = pbc), # times = 50 # ) -# +# # # fit_orsf <- orsf(flc, time + status ~ ., n_thread = 0, leaf_min_obs = 10) # # fit_rfsrc <- randomForestSRC::rfsrc(Surv(time, status) ~ .,