From 870c741627e850d932e466ab7db821324dbbd496 Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Sun, 8 Oct 2023 14:31:42 -0400 Subject: [PATCH] trying to optimize the more precise mat mult --- src/Tree.cpp | 12 ++-------- tests/testthat/test-performance.R | 39 +++++++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/src/Tree.cpp b/src/Tree.cpp index beeb1b6d..14c7b467 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -1095,21 +1095,13 @@ if(obs_in_node.size() > 0){ - // lincomb = prediction_data->x_submat(obs_in_node, coef_indices[i]) * coef_values[i]; - x_node = prediction_data->x_submat(obs_in_node, coef_indices[i]); - lincomb.set_size(x_node.n_rows); - - for(uword k = 0; k < lincomb.size(); k++){ - vec new_lincomb_value = (x_node.row(k) * coef_values[i]); - lincomb[k] = new_lincomb_value(0); - } it = obs_in_node.begin(); - for(uword j = 0; j < lincomb.size(); ++j, ++it){ + for(uword j = 0; j < obs_in_node.size(); ++j, ++it){ - if(lincomb[j] <= cutpoint[i]) { + if(dot(x_node.row(j), coef_values[i]) <= cutpoint[i]) { pred_leaf[*it] = child_left[i]; diff --git a/tests/testthat/test-performance.R b/tests/testthat/test-performance.R index 327f7d9b..47d0be28 100644 --- a/tests/testthat/test-performance.R +++ b/tests/testthat/test-performance.R @@ -1,7 +1,7 @@ - - - -# flc %>% dim() +# +# +# +# # grow ---- # # microbenchmark::microbenchmark( # aorsf = orsf(flc, time + status ~ ., @@ -31,3 +31,34 @@ # min.node.size = 10), # times = 10 # ) +# +# # predict ---- +# source("tests/testthat/setup.R") +# +# fit_orsf <- orsf(pbc, time + status ~ ., n_thread = 0, leaf_min_obs = 10) +# +# fit_rfsrc <- randomForestSRC::rfsrc(Surv(time, status) ~ ., +# data = pbc, +# nodesize = 10, +# samptype = 'swr') +# +# +# microbenchmark::microbenchmark( +# orsf = predict(fit_orsf, new_data = pbc), +# rfsrc = predict(fit_rfsrc, newdata = pbc), +# times = 50 +# ) +# +# fit_orsf <- orsf(flc, time + status ~ ., n_thread = 0, leaf_min_obs = 10) +# +# fit_rfsrc <- randomForestSRC::rfsrc(Surv(time, status) ~ ., +# data = flc, +# nodesize = 10, +# samptype = 'swr') +# +# +# microbenchmark::microbenchmark( +# orsf = predict(fit_orsf, new_data = flc), +# rfsrc = predict(fit_rfsrc, newdata = flc), +# times = 3 +# )