Skip to content

Commit

Permalink
trying to optimize the more precise mat mult
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 8, 2023
1 parent fc9c909 commit 870c741
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
12 changes: 2 additions & 10 deletions src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1095,21 +1095,13 @@

if(obs_in_node.size() > 0){

// lincomb = prediction_data->x_submat(obs_in_node, coef_indices[i]) * coef_values[i];

x_node = prediction_data->x_submat(obs_in_node, coef_indices[i]);
lincomb.set_size(x_node.n_rows);

for(uword k = 0; k < lincomb.size(); k++){
vec new_lincomb_value = (x_node.row(k) * coef_values[i]);
lincomb[k] = new_lincomb_value(0);
}

it = obs_in_node.begin();

for(uword j = 0; j < lincomb.size(); ++j, ++it){
for(uword j = 0; j < obs_in_node.size(); ++j, ++it){

if(lincomb[j] <= cutpoint[i]) {
if(dot(x_node.row(j), coef_values[i]) <= cutpoint[i]) {

pred_leaf[*it] = child_left[i];

Expand Down
39 changes: 35 additions & 4 deletions tests/testthat/test-performance.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@



# flc %>% dim()
#
#
#
# # grow ----
#
# microbenchmark::microbenchmark(
# aorsf = orsf(flc, time + status ~ .,
Expand Down Expand Up @@ -31,3 +31,34 @@
# min.node.size = 10),
# times = 10
# )
#
# # predict ----
# source("tests/testthat/setup.R")
#
# fit_orsf <- orsf(pbc, time + status ~ ., n_thread = 0, leaf_min_obs = 10)
#
# fit_rfsrc <- randomForestSRC::rfsrc(Surv(time, status) ~ .,
# data = pbc,
# nodesize = 10,
# samptype = 'swr')
#
#
# microbenchmark::microbenchmark(
# orsf = predict(fit_orsf, new_data = pbc),
# rfsrc = predict(fit_rfsrc, newdata = pbc),
# times = 50
# )
#
# fit_orsf <- orsf(flc, time + status ~ ., n_thread = 0, leaf_min_obs = 10)
#
# fit_rfsrc <- randomForestSRC::rfsrc(Surv(time, status) ~ .,
# data = flc,
# nodesize = 10,
# samptype = 'swr')
#
#
# microbenchmark::microbenchmark(
# orsf = predict(fit_orsf, new_data = flc),
# rfsrc = predict(fit_rfsrc, newdata = flc),
# times = 3
# )

0 comments on commit 870c741

Please sign in to comment.