Skip to content

Commit

Permalink
submat mult in place and tested
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 14, 2023
1 parent 4941f7c commit db74186
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 4 deletions.
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ find_rows_inbag_exported <- function(rows_oobag, n_obs) {
.Call(`_aorsf_find_rows_inbag_exported`, rows_oobag, n_obs)
}

x_submat_mult_beta_exported <- function(x, y, w, x_rows, x_cols, beta) {
.Call(`_aorsf_x_submat_mult_beta_exported`, x, y, w, x_rows, x_cols, beta)
}

cph_scale <- function(x, w) {
.Call(`_aorsf_cph_scale`, x, w)
}
Expand Down
17 changes: 17 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,22 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// x_submat_mult_beta_exported
arma::vec x_submat_mult_beta_exported(arma::mat& x, arma::mat& y, arma::vec& w, arma::uvec& x_rows, arma::uvec& x_cols, arma::vec& beta);
RcppExport SEXP _aorsf_x_submat_mult_beta_exported(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP x_rowsSEXP, SEXP x_colsSEXP, SEXP betaSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type x(xSEXP);
Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP);
Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
Rcpp::traits::input_parameter< arma::uvec& >::type x_rows(x_rowsSEXP);
Rcpp::traits::input_parameter< arma::uvec& >::type x_cols(x_colsSEXP);
Rcpp::traits::input_parameter< arma::vec& >::type beta(betaSEXP);
rcpp_result_gen = Rcpp::wrap(x_submat_mult_beta_exported(x, y, w, x_rows, x_cols, beta));
return rcpp_result_gen;
END_RCPP
}
// cph_scale
List cph_scale(arma::mat& x, arma::vec& w);
RcppExport SEXP _aorsf_cph_scale(SEXP xSEXP, SEXP wSEXP) {
Expand Down Expand Up @@ -198,6 +214,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_aorsf_find_cuts_survival_exported", (DL_FUNC) &_aorsf_find_cuts_survival_exported, 6},
{"_aorsf_sprout_node_survival_exported", (DL_FUNC) &_aorsf_sprout_node_survival_exported, 2},
{"_aorsf_find_rows_inbag_exported", (DL_FUNC) &_aorsf_find_rows_inbag_exported, 2},
{"_aorsf_x_submat_mult_beta_exported", (DL_FUNC) &_aorsf_x_submat_mult_beta_exported, 6},
{"_aorsf_cph_scale", (DL_FUNC) &_aorsf_cph_scale, 2},
{"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 44},
{NULL, NULL, 0}
Expand Down
6 changes: 4 additions & 2 deletions src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1128,13 +1128,15 @@

if(obs_in_node.size() > 0){

x_node = prediction_data->x_submat(obs_in_node, coef_indices[i]);
lincomb = prediction_data->x_submat_mult_beta(obs_in_node,
coef_indices[i],
coef_values[i]);

it = obs_in_node.begin();

for(uword j = 0; j < obs_in_node.size(); ++j, ++it){

if(dot(x_node.row(j), coef_values[i]) <= cutpoint[i]) {
if(lincomb[j] <= cutpoint[i]) {

pred_leaf[*it] = child_left[i];

Expand Down
17 changes: 16 additions & 1 deletion src/orsf_oop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Authors:
- Byron C. Jaeger (http://byronjaeger.com)
- test
#----------------------------------------------------------------------------*/

#include <RcppArmadillo.h>
Expand Down Expand Up @@ -171,6 +170,22 @@

}

// [[Rcpp::export]]
arma::vec x_submat_mult_beta_exported(arma::mat& x,
arma::mat& y,
arma::vec& w,
arma::uvec& x_rows,
arma::uvec& x_cols,
arma::vec& beta){

std::unique_ptr<Data> data = std::make_unique<Data>(x, y, w);

vec out = data->x_submat_mult_beta(x_rows, x_cols, beta);

return(out);

}

// [[Rcpp::export]]
List cph_scale(arma::mat& x, arma::vec& w){

Expand Down
24 changes: 24 additions & 0 deletions tests/testthat/test-DataCpp.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@


x_rows <- sample(nrow(pbc_mats$x), size = 100)
x_cols <- sample(ncol(pbc_mats$x), size = 10)

beta <- runif(n = length(x_cols))

test_that(
desc = "submatrix multiplication is correct",
code = {

data_cpp_answer <- x_submat_mult_beta_exported(x = pbc_mats$x,
y = pbc_mats$y,
w = pbc_mats$w,
x_rows = x_rows - 1,
x_cols = x_cols - 1,
beta = beta)

target <- pbc_mats$x[x_rows, x_cols] %*% beta

expect_equal(data_cpp_answer, target)

})

2 changes: 1 addition & 1 deletion tests/testthat/test-performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# rfsrc = predict(fit_rfsrc, newdata = pbc),
# times = 50
# )
#
# #
# fit_orsf <- orsf(flc, time + status ~ ., n_thread = 0, leaf_min_obs = 10)
#
# fit_rfsrc <- randomForestSRC::rfsrc(Surv(time, status) ~ .,
Expand Down

0 comments on commit db74186

Please sign in to comment.