Skip to content

Commit

Permalink
Merge pull request #30 from ciaran-evans/variance-reduction
Browse files Browse the repository at this point in the history
Resolves #29
  • Loading branch information
bcjaeger authored Nov 8, 2023
2 parents af8de1b + f27435d commit a946fc0
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
48 changes: 48 additions & 0 deletions src/orsf_oop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,51 @@
return(result);

}


// Weighted reduction in variance obtained by splitting a node into a
// left and a right child.
//
// y_node: outcome values of the observations in the node.
// w_node: observation weights; indexed in step with y_node
//         (assumed to have the same length as y_node).
// g_node: child indicator; an observation goes to the right child when
//         its value is 1 and to the left child otherwise
//         (assumed to have the same length as y_node).
//
// Returns
//   sum_i w_i * [ (y_i - root_mean)^2 - (y_i - child_mean_i)^2 ] / sum_i w_i
// where child_mean_i is the weighted mean of the child that observation i
// is assigned to. Returns 0 when the total weight is 0 (including an empty
// node), and evaluates to 0 rather than NaN when one child receives no
// weight, since such a split leaves the node unchanged.
// [[Rcpp::export]]
double compute_var_reduction(arma::vec& y_node,
                             arma::vec& w_node,
                             arma::uvec& g_node){

 const arma::uword n_obs = y_node.n_rows;

 // the *_mean variables hold weighted sums until they are divided below
 double root_mean = 0, left_mean = 0, right_mean = 0;
 double root_w_sum = 0, left_w_sum = 0, right_w_sum = 0;

 for(arma::uword i = 0; i < n_obs; ++i){

  const double w_i = w_node[i];
  const double wy_i = y_node[i] * w_i;

  root_w_sum += w_i;
  root_mean += wy_i;

  if(g_node[i] == 1){
   right_w_sum += w_i;
   right_mean += wy_i;
  } else {
   left_w_sum += w_i;
   left_mean += wy_i;
  }

 }

 // nothing to split: avoid 0/0 and report no reduction
 if(root_w_sum == 0) return(0);

 root_mean /= root_w_sum;

 // a child with no weight keeps a mean of 0 instead of becoming NaN;
 // its value never contributes because every term below is scaled by w_i
 if(left_w_sum != 0) left_mean /= left_w_sum;
 if(right_w_sum != 0) right_mean /= right_w_sum;

 double ans = 0;

 for(arma::uword i = 0; i < n_obs; ++i){

  const double w_i = w_node[i];
  const double y_i = y_node[i];

  // select the child mean with the same rule used to accumulate it, so
  // the unused child's mean is never mixed into the result
  const double obs_mean = (g_node[i] == 1) ? right_mean : left_mean;

  const double d_root = y_i - root_mean;
  const double d_child = y_i - obs_mean;

  ans += w_i * (d_root * d_root) - w_i * (d_child * d_child);

 }

 ans /= root_w_sum;

 return(ans);

}

38 changes: 38 additions & 0 deletions tests/testthat/test-compute_var_reduction.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

# Reference implementation of the variance reduction in plain R, built on
# matrixStats::weightedVar. The result is the rescaled weighted variance of
# the whole node minus the rescaled weighted variances of the g == 1 and
# g == 0 subsets.
# NOTE(review): the (sum - 1) / sum(w) factors presumably undo the
# denominator used inside weightedVar -- confirm against matrixStats docs.
var_reduction_R <- function(y, w, g){

  w_total <- sum(w)

  in_right <- which(g == 1)
  in_left <- which(g == 0)

  root_term <- (w_total - 1) / w_total *
    matrixStats::weightedVar(y, w = w)

  right_term <- (sum(w * g) - 1) / (w_total) *
    matrixStats::weightedVar(y, w = w, idxs = in_right)

  left_term <- (sum(w * (1 - g)) - 1) / (w_total) *
    matrixStats::weightedVar(y, w = w, idxs = in_left)

  root_term - right_term - left_term

}

# Compares the compiled compute_var_reduction against the R reference
# var_reduction_R on repeated random draws of outcomes, weights and
# group indicators.
test_that(
  desc = 'computed variance reduction close to matrixStats::weightedVar',
  code = {

    # fixed seed so that any failure can be reproduced exactly
    set.seed(329)

    n_runs <- 100

    diffs_vec <- numeric(n_runs)

    for(i in seq_len(n_runs)){

      y <- rnorm(100)
      w <- runif(100, 0, 2)
      g <- rbinom(100, 1, 0.5)

      diffs_vec[i] <- abs(compute_var_reduction(y, w, g) -
                            var_reduction_R(y, w, g))

    }

    # check the worst run rather than the average, so a single large
    # disagreement cannot be hidden by many small ones
    expect_lt(max(diffs_vec), 1e-6)

  }
)


# # The cpp implementation is 80+ times faster than the implementation using
# # matrixStats::weightedVar
# microbenchmark::microbenchmark(
# cpp = compute_var_reduction(y, w, g),
# r = var_reduction_R(y, w, g),
# times = 10000
# )

0 comments on commit a946fc0

Please sign in to comment.