Resolves ropensci/aorsf#29 (#30)
Conversation
This is awesome, thank you @ciaran-evans! I tinkered with the compute_var_reduction function a bit. After checking the two functions, we can finish this PR with a few remaining steps.
I would be happy to do these steps myself. Just let me know if you'd like to take the lead on them. The reprex below compares the current compute_var_reduction (v1) with a loop-based rewrite (v2):

library(microbenchmark)
Rcpp::cppFunction(
depends = 'RcppArmadillo',
code = "
double compute_var_reduction(arma::vec& y_node,
arma::vec& w_node,
arma::uvec& g_node){
arma::vec w_left = w_node % (1 - g_node);
arma::vec w_right = w_node % g_node;
double root_mean = sum(y_node % w_node)/sum(w_node);
double left_mean = sum(y_node % w_left)/sum(w_left);
double right_mean = sum(y_node % w_right)/sum(w_right);
return (sum(w_node % pow(y_node - root_mean, 2)) - sum(w_left % pow(y_node - left_mean, 2)) -
sum(w_right % pow(y_node - right_mean, 2)))/sum(w_node);
}
"
)
Rcpp::cppFunction(
depends = 'RcppArmadillo',
code = "
double compute_var_reduction_2(arma::vec& y_node,
arma::vec& w_node,
arma::uvec& g_node){
double root_mean = 0, left_mean = 0, right_mean = 0;
double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;
for(arma::uword i = 0; i < y_node.n_rows; ++i){
double w_i = w_node[i];
double y_i = y_node[i] * w_i;
root_w_sum += w_i;
root_mean += y_i;
if(g_node[i] == 1){
right_w_sum += w_i;
right_mean += y_i;
} else {
left_w_sum += w_i;
left_mean += y_i;
}
}
root_mean /= root_w_sum;
left_mean /= left_w_sum;
right_mean /= right_w_sum;
double ans = 0;
for(arma::uword i = 0; i < y_node.n_rows; ++i){
double w_i = w_node[i];
double y_i = y_node[i];
ans += w_i * pow(y_i - root_mean, 2);
if(g_node[i] == 1){
ans -= w_i * pow(y_i - right_mean, 2);
} else {
ans -= w_i * pow(y_i - left_mean, 2);
}
}
ans /= root_w_sum;
return(ans);
}
"
)
n <- 10000
y <- rnorm(n)
w <- runif(n, 0, 2)
g <- rbinom(n, 1, 0.5)
microbenchmark(
v1 = compute_var_reduction(y, w, g),
v2 = compute_var_reduction_2(y, w, g),
times = 5000
)
#> Unit: microseconds
#> expr min lq mean median uq max neval cld
#> v1 65.001 80.601 111.15622 93.301 109.402 8312.402 5000 a
#> v2 72.400 74.302 84.17134 75.301 87.601 9133.501 5000 b
v1 = compute_var_reduction(y, w, g)
v2 = compute_var_reduction_2(y, w, g)
testthat::expect_true(v1 - v2 < .Machine$double.eps)

Created on 2023-10-29 with reprex v2.0.2
One last thing. I think the second for loop in compute_var_reduction_2 could be rewritten to avoid the if...else statement.
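A tiny illustration of that branch-free idea, written in R with made-up values rather than taken from the PR; it only shows that blending the two node means with the 0/1 group indicator reproduces the if/else choice (the v3 function in the next comment is the actual implementation).

# Illustration only: pick the left or right node mean without an if/else.
left_mean  <- 1.5            # made-up left-node mean
right_mean <- 2.5            # made-up right-node mean
g <- c(0, 1, 1, 0)           # group indicator: 0 = left, 1 = right

blended  <- g * right_mean + (1 - g) * left_mean   # branch-free selection
branched <- ifelse(g == 1, right_mean, left_mean)  # if/else equivalent
all.equal(blended, branched)                       # TRUE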
Thanks so much @bcjaeger! Using lower-level operations does run faster on my machine too. I have modified the second for loop, as you suggested, to avoid the if...else statement. The modified version (v3 in the code below) does appear to run faster. I have just committed a revised version with the v3 code.

library(microbenchmark)
Rcpp::cppFunction(
depends = 'RcppArmadillo',
code = "
double compute_var_reduction(arma::vec& y_node,
arma::vec& w_node,
arma::uvec& g_node){
arma::vec w_left = w_node % (1 - g_node);
arma::vec w_right = w_node % g_node;
double root_mean = sum(y_node % w_node)/sum(w_node);
double left_mean = sum(y_node % w_left)/sum(w_left);
double right_mean = sum(y_node % w_right)/sum(w_right);
return (sum(w_node % pow(y_node - root_mean, 2)) - sum(w_left % pow(y_node - left_mean, 2)) -
sum(w_right % pow(y_node - right_mean, 2)))/sum(w_node);
}
"
)
Rcpp::cppFunction(
depends = 'RcppArmadillo',
code = "
double compute_var_reduction_2(arma::vec& y_node,
arma::vec& w_node,
arma::uvec& g_node){
double root_mean = 0, left_mean = 0, right_mean = 0;
double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;
for(arma::uword i = 0; i < y_node.n_rows; ++i){
double w_i = w_node[i];
double y_i = y_node[i] * w_i;
root_w_sum += w_i;
root_mean += y_i;
if(g_node[i] == 1){
right_w_sum += w_i;
right_mean += y_i;
} else {
left_w_sum += w_i;
left_mean += y_i;
}
}
root_mean /= root_w_sum;
left_mean /= left_w_sum;
right_mean /= right_w_sum;
double ans = 0;
for(arma::uword i = 0; i < y_node.n_rows; ++i){
double w_i = w_node[i];
double y_i = y_node[i];
ans += w_i * pow(y_i - root_mean, 2);
if(g_node[i] == 1){
ans -= w_i * pow(y_i - right_mean, 2);
} else {
ans -= w_i * pow(y_i - left_mean, 2);
}
}
ans /= root_w_sum;
return(ans);
}
"
)
Rcpp::cppFunction(
depends = 'RcppArmadillo',
code = "
double compute_var_reduction_3(arma::vec& y_node,
arma::vec& w_node,
arma::uvec& g_node){
double root_mean = 0, left_mean = 0, right_mean = 0;
double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;
for(arma::uword i = 0; i < y_node.n_rows; ++i){
double w_i = w_node[i];
double y_i = y_node[i] * w_i;
root_w_sum += w_i;
root_mean += y_i;
if(g_node[i] == 1){
right_w_sum += w_i;
right_mean += y_i;
} else {
left_w_sum += w_i;
left_mean += y_i;
}
}
root_mean /= root_w_sum;
left_mean /= left_w_sum;
right_mean /= right_w_sum;
double ans = 0;
for(arma::uword i = 0; i < y_node.n_rows; ++i){
double w_i = w_node[i];
double y_i = y_node[i];
double g_i = g_node[i];
double obs_mean = g_i*right_mean + (1 - g_i)*left_mean;
ans += w_i * pow(y_i - root_mean, 2) - w_i * pow(y_i - obs_mean, 2);
}
ans /= root_w_sum;
return(ans);
}
"
)
n <- 10000
y <- rnorm(n)
w <- runif(n, 0, 2)
g <- rbinom(n, 1, 0.5)
microbenchmark(
v1 = compute_var_reduction(y, w, g),
v2 = compute_var_reduction_2(y, w, g),
v3 = compute_var_reduction_3(y, w, g),
times = 5000
)
# Unit: microseconds
# expr min lq mean median uq max neval
# v1 80.334 82.376 106.76682 85.542 89.417 9912.667 5000
# v2 42.584 53.167 74.96333 56.459 60.959 10168.834 5000
# v3 34.500 43.167 63.38960 47.000 51.959 13631.459 5000
v1 = compute_var_reduction(y, w, g)
v3 = compute_var_reduction_3(y, w, g)
testthat::expect_true(v1 - v3 < .Machine$double.eps)
Fantastic =] Awesome work. I am looking forward to reviewing!
I'm very happy with this! No changes to request on my end. This function will be called hundreds if not thousands of times for a single forest, so every bit of efficiency really counts. =] Thank you SO much. If you'd like to work together on this more, I should be focusing a lot more on regression trees soon.
Wonderful, thanks! I would love to contribute more; feel free to tag me on any issues where I would be useful.
This adds a compute_var_reduction function in orsf_oop.cpp to compute the variance reduction after a possible split, allowing for weights. The function is evaluated for correctness and speed in test-compute_var_reduction.R.

Given vectors $y$ (the response values), $w$ (the weights), and $g$ (the group assignments; 0 = left, 1 = right), the weighted reduction in variance is

$$
\Delta V = \frac{1}{\sum_i w_i}\left[\sum_i w_i (y_i - \bar{y})^2 - \sum_{i:\,g_i=0} w_i (y_i - \bar{y}_L)^2 - \sum_{i:\,g_i=1} w_i (y_i - \bar{y}_R)^2\right],
$$

where $\bar{y} = \sum_i w_i y_i / \sum_i w_i$ is the weighted mean of all responses, and $\bar{y}_L$ and $\bar{y}_R$ are the corresponding weighted means of the left ($g_i = 0$) and right ($g_i = 1$) groups.
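As a quick sanity check of the formula (not part of the PR itself), the same quantity can be computed directly in base R and compared with the compiled function. The helper name var_reduction_r is hypothetical, and compute_var_reduction is assumed to have been compiled from one of the reprexes above.

# Minimal base-R sketch of the weighted variance reduction defined above.
# var_reduction_r() is a hypothetical name used only for illustration.
var_reduction_r <- function(y, w, g) {
  y_bar   <- sum(w * y) / sum(w)                            # weighted mean, all observations
  y_bar_l <- sum(w[g == 0] * y[g == 0]) / sum(w[g == 0])    # weighted mean, left group
  y_bar_r <- sum(w[g == 1] * y[g == 1]) / sum(w[g == 1])    # weighted mean, right group
  (sum(w * (y - y_bar)^2) -
     sum(w[g == 0] * (y[g == 0] - y_bar_l)^2) -
     sum(w[g == 1] * (y[g == 1] - y_bar_r)^2)) / sum(w)
}

set.seed(1)
n <- 100
y <- rnorm(n)
w <- runif(n, 0, 2)
g <- rbinom(n, 1, 0.5)

# Both values should agree up to floating-point error.
all.equal(var_reduction_r(y, w, g), compute_var_reduction(y, w, g))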