Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Dec 20, 2023
2 parents f9423a9 + 1ec02cc commit 461a1e0
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 52 deletions.
38 changes: 28 additions & 10 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@

# TODO:
# - add nocov to cpp
# - automatic bounds for pd (better interface)
# - tests for check_oobag_eval_function
# - tests for survival forest w/no censored
# - tests for check_oobag_eval_function_internal


# ObliqueForest class ----
Expand Down Expand Up @@ -592,6 +589,8 @@ ObliqueForest <- R6::R6Class(
oobag,
type_output){

na_action <- na_action %||% self$na_action

public_state <- list(data = self$data,
na_action = self$na_action,
pred_horizon = self$pred_horizon,
Expand All @@ -614,7 +613,9 @@ ObliqueForest <- R6::R6Class(
pred_spec <- list_init(pred_spec)

for(i in names(pred_spec)){
pred_spec[[i]] <- self$get_var_bounds(i)

pred_spec[[i]] <- unique(self$get_var_bounds(i))

}

} else if (inherits(pred_spec, 'pspec_intr')){
Expand Down Expand Up @@ -847,9 +848,10 @@ ObliqueForest <- R6::R6Class(

if(i %in% colnames(bounds)){

pred_spec[[i]] <- unique(
as.numeric(bounds[c('25%','50%','75%'), i])
)
pred_spec[[i]] <- self$get_var_bounds(i)
# unique(
# as.numeric(bounds[c('25%','50%','75%'), i])
# )

} else if (i %in% fctrs$cols) {

Expand All @@ -865,6 +867,7 @@ ObliqueForest <- R6::R6Class(
pred_type = pred_type,
prob_values = c(0.25, 0.50, 0.75),
pred_horizon = pred_horizon,
boundary_checks = FALSE,
verbose_progress = verbose_progress)

fctrs_unordered <- c()
Expand Down Expand Up @@ -978,11 +981,26 @@ ObliqueForest <- R6::R6Class(

# Representative values of one predictor, used as its default grid.
#
# .name: name of a single predictor (character scalar).
#
# Returns:
# - if .name is listed in private$data_names$x_numeric: the unique values
#   stored for it in private$data_bounds (presumably quantiles of the
#   training data -- confirm where data_bounds is built). When fewer than
#   5 unique values are stored, up to 5 of the most frequently observed
#   values in self$data are returned instead, sorted in increasing order.
# - otherwise: the entry for .name in private$data_fctrs$lvls.
get_var_bounds = function(.name){

  if(.name %in% private$data_names$x_numeric){

    out <- unique(as.numeric(private$data_bounds[, .name]))

    if(length(out) < 5){
      # too few unique values to use quantiles,
      # so use the most common unique values instead.
      unis <- sort(table(self$data[[.name]]), decreasing = TRUE)
      n_items <- min(5, length(unis))
      # seq_len() keeps the result empty when nothing was tabulated;
      # seq(0) would give c(1, 0) and introduce an NA.
      out <- sort(as.numeric(names(unis)[seq_len(n_items)]))
    }

    return(out)

  }

  private$data_fctrs$lvls[[.name]]

},

get_var_type = function(.name){
Expand Down
2 changes: 1 addition & 1 deletion R/orsf_pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ orsf_dependence <- function(object,
pred_spec,
pred_horizon,
pred_type,
na_action = 'fail',
na_action = NULL,
expand_grid,
prob_values = NULL,
prob_labels = NULL,
Expand Down
2 changes: 1 addition & 1 deletion R/orsf_vint.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ orsf_vint <- function(object,
pred_spec = pspec,
pred_horizon = NULL,
pred_type = ptype,
na_action = 'fail',
na_action = object$na_action,
expand_grid = FALSE,
prob_values = NULL,
prob_labels = NULL,
Expand Down
36 changes: 18 additions & 18 deletions man/orsf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

83 changes: 69 additions & 14 deletions man/orsf_pd_oob.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions tests/testthat/test-impute_meanmode.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ fit_miss <- orsf(pbc_miss,
formula = time + status ~ .,
na_action = 'impute_meanmode')

test_that(
  desc = "missingness rule is passed to pd / vi functions and retained in object",
  code = {

    # fit_miss was fit with na_action = 'impute_meanmode'. The pd and vi
    # calls must run on it without error; regexp = NA asserts that no
    # error is raised, so the check is explicit and no unused results
    # are left behind.
    expect_error(
      orsf_pd_oob(fit_miss, pred_spec_auto(ascites)),
      regexp = NA
    )

    expect_error(
      orsf_vi(fit_miss, importance = 'permute'),
      regexp = NA
    )

    # the rule given at fit time is retained on the object
    expect_equal(fit_miss$na_action, 'impute_meanmode')

  }
)


# Values returned by the forest's get_means() followed by its get_modes().
impute_values <- c(
  fit_miss$get_means(),
  fit_miss$get_modes()
)

Expand Down
41 changes: 41 additions & 0 deletions tests/testthat/test-orsf_vint.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

# Shared fixtures for the orsf_vint tests below.
# The outcome is simulated with an x2:x3 interaction (coefficient 1),
# a weaker x1:x2 interaction (coefficient 1/2), and no x1:x3 interaction.
# The order of the rnorm() calls is kept fixed so the seed reproduces
# the same data.
set.seed(329)

n_obs <- 500

data <- data.frame(
  x1 = rnorm(n_obs),
  x2 = rnorm(n_obs),
  x3 = rnorm(n_obs)
)

data$y <- with(
  data,
  expr = x1 + x2 + x3 + 1/2*x1*x2 + x2*x3 + rnorm(n_obs)
)

forest <- orsf(data, y ~ ., n_tree = 5)

# interaction scores over all predictors, and over x1 and x3 only
vints_1 <- orsf_vint(forest)
vints_2 <- orsf_vint(forest, predictors = c("x1", "x3"))

test_that(
  desc = "orsf_vint orders interactions correctly",
  code = {

    # expected ranking of the leading rows of vints_1, strongest first
    expected_order <- c("x2..x3", "x1..x2", "x1..x3")

    for (rank in seq_along(expected_order)) {
      expect_equal(vints_1$interaction[rank], expected_order[rank])
    }

  }
)

test_that(
  desc = "orsf_vint uses only predictors requested",
  code = {

    # two predictors were requested, so exactly one pair is expected
    n_pairs <- nrow(vints_2)
    expect_equal(n_pairs, 1)

  }
)

test_that(
  desc = "interaction score does not depend on unused predictors",
  code = {

    # score of the x1..x3 pair when all predictors were considered
    is_x1_x3 <- vints_1$interaction == 'x1..x3'
    score_from_full <- vints_1$score[is_x1_x3]

    expect_equal(vints_2$score, score_from_full)

  }
)
Loading

0 comments on commit 461a1e0

Please sign in to comment.