Skip to content

Commit

Permalink
deal with constant cols
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Nov 14, 2023
1 parent 3a2a060 commit 6ebb48a
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 10 deletions.
32 changes: 26 additions & 6 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ ObliqueForest <- R6::R6Class(
private$check_data(new = TRUE, data = new_data)
private$check_na_action(new = TRUE, na_action = na_action)
private$check_var_missing(new = TRUE, data = new_data, na_action)
private$check_var_values(new = TRUE, data = new_data)
private$check_units(data = new_data)
private$check_boundary_checks(boundary_checks)
private$check_n_thread(n_thread)
Expand Down Expand Up @@ -751,20 +752,24 @@ ObliqueForest <- R6::R6Class(
# allow re-training.
self$forest <- list()

cpp_args <- private$prep_cpp_args(pred_type = 'mort',
pred_type <- switch(self$tree_type,
'survival' = 'mort',
'classification' = 'prob',
'regression' = 'mean')

cpp_args <- private$prep_cpp_args(pred_type = pred_type,
oobag_pred = TRUE,
importance_group_factors = TRUE,
write_forest = FALSE)

mtry_safe <- self$mtry


while(n_predictors >= n_predictor_min){

if(mtry_safe >= n_predictors){
mtry_safe <- max(mtry_safe - 1, 1)
if(verbose_progress){
cat("Current number of predictors:", n_predictors, "\r")
}

mtry_safe <- ceiling(sqrt(n_predictors))

if(self$control$lincomb_df_target > mtry_safe){
self$control$lincomb_df_target <- mtry_safe
}
Expand Down Expand Up @@ -1076,6 +1081,7 @@ ObliqueForest <- R6::R6Class(
no = nrow(self$data))

private$check_var_missing()
private$check_var_values()

unit_names <- c(names_y_data[types_y_data == 'units'],
names_x_data[types_x_data == 'units'])
Expand Down Expand Up @@ -1374,6 +1380,12 @@ ObliqueForest <- R6::R6Class(

}

},

check_var_values = function(data = NULL, new = FALSE){

input <- data %||% self$data

for(i in private$data_names$x_original){

if(collapse::allNA(input[[i]])){
Expand All @@ -1386,9 +1398,17 @@ ObliqueForest <- R6::R6Class(
call. = FALSE)
}

if(!new){
if(collapse::fnunique(collapse::na_omit(input[[i]])) == 1L){
stop("column ", i, " is constant.",
call. = FALSE)
}
}

}

},

check_formula = function(formula = NULL){

input <- formula %||% self$formula
Expand Down
5 changes: 4 additions & 1 deletion src/TreeClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

TreeClassification::TreeClassification(arma::uword n_class){
this->n_class = n_class;
this->binary = n_class == 2;
}

TreeClassification::TreeClassification(arma::uword n_obs,
Expand Down Expand Up @@ -196,7 +197,9 @@

double cstat_sum = 0;

for(uword i = 0; i < y_oobag.n_cols; i++){
uword start = 0; if(binary) start = 1;

for(uword i = start; i < y_oobag.n_cols; i++){
vec y_i = y_oobag.unsafe_col(i);
vec p_i = preds.unsafe_col(i);
cstat_sum += compute_cstat_clsf(y_i, w_oobag, p_i);
Expand Down
1 change: 1 addition & 0 deletions src/TreeClassification.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
}

arma::uword n_class;
bool binary;

arma::uvec splittable_y_cols;
arma::uword y_col_split;
Expand Down
6 changes: 3 additions & 3 deletions src/utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@
solve_opts::no_approx);

if(!nonsingular){
mat result(beta.size(), 2, fill::zeros);
mat result(x_node.n_cols, 2, fill::zeros);
return(result);
}

Expand All @@ -638,7 +638,7 @@
bool invertible = inv(xtx_inverse, X.t() * diagmat(w_node) * X);

if(!invertible) {
mat result(beta.size(), 2, fill::zeros);
mat result(x_node.n_cols, 2, fill::zeros);
return(result);
}

Expand Down Expand Up @@ -721,7 +721,7 @@
double beta_trace = arma::accu(arma::abs(beta));

if(beta_trace < std::numeric_limits<double>::epsilon()){
mat result(beta.size(), 2, fill::zeros);
mat result(x_node.n_cols, 2, fill::zeros);
return(result);
}

Expand Down
33 changes: 33 additions & 0 deletions tests/testthat/test-lincomb_linreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,36 @@ test_that(

}
)

test_that(
desc = 'constant columns dont break linreg',
code = {

nrows <- 100
ncols <- 2

X <- matrix(data = rnorm(nrows*ncols), nrow = nrows, ncol = ncols)

# X <- cbind(1, X)

colnames(X) <- c(
# "intercept",
paste0("x", seq(ncols))
)

Y <- matrix(rnorm(nrows), ncol = 1)

W <- rep(1, nrow(X))

X[,1] <- 1

zeros <- matrix(0, ncol = 2, nrow = ncols)

results <- linreg_fit_exported(x_node = X, y_node = Y, w_node = W,
do_scale = T, epsilon = 1e-9, iter_max = 20)

expect_equal(zeros, results)


}
)

0 comments on commit 6ebb48a

Please sign in to comment.