Skip to content

Commit

Permalink
smooth over pd checks
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Nov 9, 2023
1 parent 008fe92 commit 2aa2497
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 153 deletions.
209 changes: 180 additions & 29 deletions R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -377,17 +377,44 @@ ObliqueForest <- R6::R6Class(
oobag,
type_output){

public_state <- list(data = self$data,
na_action = self$na_action)
public_state <- list(data = self$data,
na_action = self$na_action,
pred_horizon = self$pred_horizon)

private_state <- list(data_rows_complete = private$data_rows_complete)

private$check_boundary_checks(boundary_checks)
private$check_pred_spec(pred_spec, boundary_checks)
private$check_n_thread(n_thread)
private$check_expand_grid(expand_grid)
private$check_oobag_pred_mode(oobag, label = 'oobag')

prob_values <- prob_values %||% c(0.025, 0.50, 0.975)
prob_labels <- prob_labels %||% c('lwr', 'medn', 'upr')

private$check_prob_values(prob_values)
private$check_prob_labels(prob_labels)

if(length(prob_values) != length(prob_labels)){
stop("prob_values and prob_labels must have the same length.",
call. = FALSE)
}

# oobag=FALSE to match the format of arg in orsf_pd().
private$check_pred_type(pred_type, oobag = FALSE)

pred_type <- pred_type %||% self$pred_type

private$check_pred_horizon(pred_horizon, boundary_checks, pred_type)

if(is.null(pred_horizon)) pred_horizon <- 1
pred_horizon_order <- order(pred_horizon)
pred_horizon_ordered <- pred_horizon[pred_horizon_order]

# run checks before you assign new values to object.
# otherwise, if a check throws an error, the object will
# not be restored to its normal state.

private$check_n_thread(n_thread)
private$check_pred_spec(pred_spec, boundary_checks)

if(!oobag){
private$check_data(new = TRUE, data = pd_data)
Expand All @@ -398,7 +425,7 @@ ObliqueForest <- R6::R6Class(
self$data <- pd_data
}


self$pred_horizon <- pred_horizon
self$na_action <- na_action

# make a visible binding for CRAN
Expand Down Expand Up @@ -492,8 +519,7 @@ ObliqueForest <- R6::R6Class(

}

pred_horizon_order <- order(pred_horizon)
pred_horizon_ordered <- pred_horizon[pred_horizon_order]


cpp_args <- private$prep_cpp_args(x = private$x,
y = private$y,
Expand All @@ -511,11 +537,11 @@ ObliqueForest <- R6::R6Class(
verbosity = 0)


results <- list()
pd_vals <- list()

for(i in seq_along(pred_spec_new)){

results_i <- list()
pd_vals_i <- list()

x_pd <- private$x

Expand All @@ -525,13 +551,13 @@ ObliqueForest <- R6::R6Class(

cpp_args$x <- x_pd

results_i[[j]] <- do.call(orsf_cpp, cpp_args)$pred_new
pd_vals_i[[j]] <- do.call(orsf_cpp, cpp_args)$pred_new

}

if(type_output == 'smry'){
results_i <- lapply(
results_i,
pd_vals_i <- lapply(
pd_vals_i,
function(x) {
apply(x, 2, function(x_col){
as.numeric(
Expand All @@ -544,29 +570,43 @@ ObliqueForest <- R6::R6Class(
}


results[[i]] <- results_i
pd_vals[[i]] <- pd_vals_i

}

pd_vals <- results

for(i in seq_along(pd_vals)){

pd_bind[[i]]$id_variable <- seq(nrow(pd_bind[[i]]))

for(j in seq_along(pd_vals[[i]])){

pd_vals[[i]][[j]] <- matrix(pd_vals[[i]][[j]],
nrow=length(pred_horizon),
byrow = T)

rownames(pd_vals[[i]][[j]]) <- pred_horizon
pd_vals[[i]][[j]]

if(self$tree_type == 'survival'){

pd_vals[[i]][[j]] <- matrix(pd_vals[[i]][[j]],
nrow=length(pred_horizon),
byrow = T)

rownames(pd_vals[[i]][[j]]) <- pred_horizon

} else {

pd_vals[[i]][[j]] <- t(pd_vals[[i]][[j]])

if(self$tree_type == 'classification'){
rownames(pd_vals[[i]][[j]]) <- self$class_levels
}

}

if(type_output=='smry')
colnames(pd_vals[[i]][[j]]) <- c('mean', prob_labels)
else
colnames(pd_vals[[i]][[j]]) <- c(paste(1:nrow(private$x)))

# this will be null for non-survival objects
ph <- rownames(pd_vals[[i]][[j]])

pd_vals[[i]][[j]] <- as.data.frame(pd_vals[[i]][[j]])
Expand Down Expand Up @@ -618,16 +658,21 @@ ObliqueForest <- R6::R6Class(

setcolorder(out, neworder = c(ids, mid, end))

out[, pred_horizon := as.numeric(pred_horizon)]
if(self$tree_type == 'classification'){
setnames(out, old = 'pred_horizon', new = 'class')
}

if(self$tree_type == 'survival' && pred_type != 'mort')
out[, pred_horizon := as.numeric(pred_horizon)]

if(pred_type == 'mort'){
out[, pred_horizon := NULL]
}

# not needed for summary
if(type_output == 'smry')
out[, id_variable := NULL]

# not needed for mort
if(pred_type == 'mort')
out[, pred_horizon := NULL]

# put data back into original scale
for(j in intersect(names(means), names(pred_spec))){

Expand Down Expand Up @@ -744,7 +789,11 @@ ObliqueForest <- R6::R6Class(

# check incoming values if they were specified.
private$check_n_variables(n_variables)
private$check_pred_horizon(pred_horizon, boundary_checks = TRUE)

if(!is.null(pred_horizon)){
private$check_pred_horizon(pred_horizon, boundary_checks = TRUE)
}

private$check_pred_type(pred_type, oobag = FALSE)
private$check_importance_type(importance_type)

Expand Down Expand Up @@ -861,6 +910,15 @@ ObliqueForest <- R6::R6Class(
return(private$data_names$y)
},

get_var_bounds = function(.name){

  # Numeric predictors: hand back the column called .name from the
  # two-dimensional object stored in private$data_bounds.
  if(.name %in% private$data_names$x_numeric){
    return(private$data_bounds[, .name])
  }

  # Every other variable: look .name up in the stored factor levels.
  private$data_fctrs$lvls[[.name]]

},

get_var_type = function(.name){

  # class() may return several entries; only the first one is reported.
  var_classes <- class(self$data[[.name]])
  var_classes[1]

},
Expand Down Expand Up @@ -2081,6 +2139,72 @@ ObliqueForest <- R6::R6Class(



},

check_boundary_checks = function(boundary_checks){

  # boundary_checks is a plain argument rather than a field of the object,
  # so a NULL value is never expected here.
  # NOTE(review): check_arg_type() / check_arg_length() are defined
  # elsewhere; presumably they signal an error when a check fails.
  label <- 'boundary_checks'

  # must be logical ...
  check_arg_type(arg_value = boundary_checks, arg_name = label,
                 expected_type = 'logical')

  # ... and of length 1
  check_arg_length(arg_value = boundary_checks, arg_name = label,
                   expected_length = 1)

},

check_expand_grid = function(expand_grid){

  # expand_grid is a plain argument rather than a field of the object,
  # so a NULL value is never expected here.
  # NOTE(review): check_arg_type() / check_arg_length() are defined
  # elsewhere; presumably they signal an error when a check fails.
  label <- 'expand_grid'

  # must be logical ...
  check_arg_type(arg_value = expand_grid, arg_name = label,
                 expected_type = 'logical')

  # ... and of length 1
  check_arg_length(arg_value = expand_grid, arg_name = label,
                   expected_length = 1)

},

check_prob_values = function(prob_values){

  # Validate prob_values in three steps: numeric type, then a lower
  # bound of 0, then an upper bound of 1.
  # NOTE(review): the check_arg_*() helpers are defined elsewhere;
  # presumably they signal an error when a check fails.
  label <- 'prob_values'

  check_arg_type(arg_value = prob_values, arg_name = label,
                 expected_type = 'numeric')

  check_arg_gteq(arg_value = prob_values, arg_name = label, bound = 0)

  check_arg_lteq(arg_value = prob_values, arg_name = label, bound = 1)

},

check_prob_labels = function(prob_labels){

  # The only requirement checked here is that prob_labels is character.
  # NOTE(review): check_arg_type() is defined elsewhere; presumably it
  # signals an error when the type does not match.
  check_arg_type(arg_value = prob_labels, arg_name = 'prob_labels',
                 expected_type = 'character')

},

check_oobag_pred_mode = function(oobag_pred_mode, label){

  # `label` is the argument name forwarded to the checkers, so the caller
  # decides what name the value is reported under.
  # NOTE(review): check_arg_type() / check_arg_length() are defined
  # elsewhere; presumably they signal an error when a check fails.

  # must be logical ...
  check_arg_type(arg_value = oobag_pred_mode, arg_name = label,
                 expected_type = 'logical')

  # ... and of length 1
  check_arg_length(arg_value = oobag_pred_mode, arg_name = label,
                   expected_length = 1)

},

# computers
Expand All @@ -2092,6 +2216,7 @@ ObliqueForest <- R6::R6Class(

},


compute_modes = function(){

private$data_modes <- vapply(
Expand Down Expand Up @@ -2463,10 +2588,21 @@ ObliqueForestSurvival <- R6::R6Class(

},

check_pred_horizon = function(pred_horizon = NULL, boundary_checks = TRUE){
check_pred_horizon = function(pred_horizon = NULL,
boundary_checks = TRUE,
pred_type = NULL){

pred_type <- pred_type %||% self$pred_type
input <- pred_horizon %||% self$pred_horizon

if(is.null(input) && pred_type %in% c('risk', 'surv', 'chf')){

stop("pred_horizon must be specified for ",
pred_type, " predictions.", call. = FALSE)

}


if(self$oobag_pred_mode)
arg_name <- 'oobag_pred_horizon'
else
Expand Down Expand Up @@ -2804,7 +2940,9 @@ ObliqueForestClassification <- R6::R6Class(
cloneable = FALSE,
public = list(

n_class = NULL
n_class = NULL,

class_levels = NULL

),
private = list(
Expand All @@ -2828,6 +2966,15 @@ ObliqueForestClassification <- R6::R6Class(

},

check_pred_horizon = function(pred_horizon = NULL, boundary_checks = TRUE, pred_type = NULL){

  # Deliberate no-op: this override validates nothing and always yields
  # NULL. All three arguments are accepted but never used.
  NULL

},

init_internal = function(){

self$tree_type <- "classification"
Expand All @@ -2845,11 +2992,15 @@ ObliqueForestClassification <- R6::R6Class(
y <- self$data[[private$data_names$y]]

if(is.factor(y)){
self$n_class <- length(levels(y))
self$class_levels <- levels(y)
self$n_class <- length(self$class_levels)
} else {
self$n_class <- length(unique(y))
self$class_levels <- unique(y)
self$n_class <- length(self$class_levels)
}



},

prep_y_internal = function(){
Expand Down
Loading

0 comments on commit 2aa2497

Please sign in to comment.