diff --git a/R/orsf_R6.R b/R/orsf_R6.R index 1e98a03c..159f3182 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -377,17 +377,44 @@ ObliqueForest <- R6::R6Class( oobag, type_output){ - public_state <- list(data = self$data, - na_action = self$na_action) + public_state <- list(data = self$data, + na_action = self$na_action, + pred_horizon = self$pred_horizon) private_state <- list(data_rows_complete = private$data_rows_complete) + private$check_boundary_checks(boundary_checks) + private$check_pred_spec(pred_spec, boundary_checks) + private$check_n_thread(n_thread) + private$check_expand_grid(expand_grid) + private$check_oobag_pred_mode(oobag, label = 'oobag') + + prob_values <- prob_values %||% c(0.025, 0.50, 0.975) + prob_labels <- prob_labels %||% c('lwr', 'medn', 'upr') + + private$check_prob_values(prob_values) + private$check_prob_labels(prob_labels) + + if(length(prob_values) != length(prob_labels)){ + stop("prob_values and prob_labels must have the same length.", + call. = FALSE) + } + + # oobag=FALSE to match the format of arg in orsf_pd(). + private$check_pred_type(pred_type, oobag = FALSE) + + pred_type <- pred_type %||% self$pred_type + + private$check_pred_horizon(pred_horizon, boundary_checks, pred_type) + + if(is.null(pred_horizon)) pred_horizon <- 1 + pred_horizon_order <- order(pred_horizon) + pred_horizon_ordered <- pred_horizon[pred_horizon_order] + # run checks before you assign new values to object. # otherwise, if a check throws an error, the object will # not be restored to its normal state. 
- private$check_n_thread(n_thread) - private$check_pred_spec(pred_spec, boundary_checks) if(!oobag){ private$check_data(new = TRUE, data = pd_data) @@ -398,7 +425,7 @@ ObliqueForest <- R6::R6Class( self$data <- pd_data } - + self$pred_horizon <- pred_horizon self$na_action <- na_action # make a visible binding for CRAN @@ -492,8 +519,7 @@ ObliqueForest <- R6::R6Class( } - pred_horizon_order <- order(pred_horizon) - pred_horizon_ordered <- pred_horizon[pred_horizon_order] + cpp_args <- private$prep_cpp_args(x = private$x, y = private$y, @@ -511,11 +537,11 @@ ObliqueForest <- R6::R6Class( verbosity = 0) - results <- list() + pd_vals <- list() for(i in seq_along(pred_spec_new)){ - results_i <- list() + pd_vals_i <- list() x_pd <- private$x @@ -525,13 +551,13 @@ ObliqueForest <- R6::R6Class( cpp_args$x <- x_pd - results_i[[j]] <- do.call(orsf_cpp, cpp_args)$pred_new + pd_vals_i[[j]] <- do.call(orsf_cpp, cpp_args)$pred_new } if(type_output == 'smry'){ - results_i <- lapply( - results_i, + pd_vals_i <- lapply( + pd_vals_i, function(x) { apply(x, 2, function(x_col){ as.numeric( @@ -544,29 +570,43 @@ ObliqueForest <- R6::R6Class( } - results[[i]] <- results_i + pd_vals[[i]] <- pd_vals_i } - pd_vals <- results - for(i in seq_along(pd_vals)){ pd_bind[[i]]$id_variable <- seq(nrow(pd_bind[[i]])) for(j in seq_along(pd_vals[[i]])){ - pd_vals[[i]][[j]] <- matrix(pd_vals[[i]][[j]], - nrow=length(pred_horizon), - byrow = T) - rownames(pd_vals[[i]][[j]]) <- pred_horizon + pd_vals[[i]][[j]] + + if(self$tree_type == 'survival'){ + + pd_vals[[i]][[j]] <- matrix(pd_vals[[i]][[j]], + nrow=length(pred_horizon), + byrow = T) + + rownames(pd_vals[[i]][[j]]) <- pred_horizon + + } else { + + pd_vals[[i]][[j]] <- t(pd_vals[[i]][[j]]) + + if(self$tree_type == 'classification'){ + rownames(pd_vals[[i]][[j]]) <- self$class_levels + } + + } if(type_output=='smry') colnames(pd_vals[[i]][[j]]) <- c('mean', prob_labels) else colnames(pd_vals[[i]][[j]]) <- c(paste(1:nrow(private$x))) + # this will be 
null for non-survival objects ph <- rownames(pd_vals[[i]][[j]]) pd_vals[[i]][[j]] <- as.data.frame(pd_vals[[i]][[j]]) @@ -618,16 +658,21 @@ ObliqueForest <- R6::R6Class( setcolorder(out, neworder = c(ids, mid, end)) - out[, pred_horizon := as.numeric(pred_horizon)] + if(self$tree_type == 'classification'){ + setnames(out, old = 'pred_horizon', new = 'class') + } + + if(self$tree_type == 'survival' && pred_type != 'mort') + out[, pred_horizon := as.numeric(pred_horizon)] + + if(pred_type == 'mort'){ + out[, pred_horizon := NULL] + } # not needed for summary if(type_output == 'smry') out[, id_variable := NULL] - # not needed for mort - if(pred_type == 'mort') - out[, pred_horizon := NULL] - # put data back into original scale for(j in intersect(names(means), names(pred_spec))){ @@ -744,7 +789,11 @@ ObliqueForest <- R6::R6Class( # check incoming values if they were specified. private$check_n_variables(n_variables) - private$check_pred_horizon(pred_horizon, boundary_checks = TRUE) + + if(!is.null(pred_horizon)){ + private$check_pred_horizon(pred_horizon, boundary_checks = TRUE) + } + private$check_pred_type(pred_type, oobag = FALSE) private$check_importance_type(importance_type) @@ -861,6 +910,15 @@ ObliqueForest <- R6::R6Class( return(private$data_names$y) }, + get_var_bounds = function(.name){ + + if(.name %in% private$data_names$x_numeric) + return(private$data_bounds[, .name]) + else + return(private$data_fctrs$lvls[[.name]]) + + }, + get_var_type = function(.name){ return(class(self$data[[.name]])[1]) }, @@ -2081,6 +2139,72 @@ ObliqueForest <- R6::R6Class( + }, + + check_boundary_checks = function(boundary_checks){ + + # not a field so boundary_checks should never be null + + check_arg_type(arg_value = boundary_checks, + arg_name = 'boundary_checks', + expected_type = 'logical') + + check_arg_length(arg_value = boundary_checks, + arg_name = 'boundary_checks', + expected_length = 1) + + + }, + + check_expand_grid = function(expand_grid){ + + # not a field so 
expand_grid should never be null + check_arg_type(arg_value = expand_grid, + arg_name = 'expand_grid', + expected_type = 'logical') + + check_arg_length(arg_value = expand_grid, + arg_name = 'expand_grid', + expected_length = 1) + + }, + + check_prob_values = function(prob_values){ + + check_arg_type(arg_value = prob_values, + arg_name = 'prob_values', + expected_type = 'numeric') + + check_arg_gteq(arg_value = prob_values, + arg_name = 'prob_values', + bound = 0) + + check_arg_lteq(arg_value = prob_values, + arg_name = 'prob_values', + bound = 1) + + }, + + check_prob_labels = function(prob_labels){ + + + check_arg_type(arg_value = prob_labels, + arg_name = 'prob_labels', + expected_type = 'character') + + }, + + check_oobag_pred_mode = function(oobag_pred_mode, label){ + + check_arg_type(arg_value = oobag_pred_mode, + arg_name = label, + expected_type = 'logical') + + check_arg_length(arg_value = oobag_pred_mode, + arg_name = label, + expected_length = 1) + + }, # computers @@ -2092,6 +2216,7 @@ ObliqueForest <- R6::R6Class( }, + compute_modes = function(){ private$data_modes <- vapply( @@ -2463,10 +2588,21 @@ ObliqueForestSurvival <- R6::R6Class( }, - check_pred_horizon = function(pred_horizon = NULL, boundary_checks = TRUE){ + check_pred_horizon = function(pred_horizon = NULL, + boundary_checks = TRUE, + pred_type = NULL){ + pred_type <- pred_type %||% self$pred_type input <- pred_horizon %||% self$pred_horizon + if(is.null(input) && pred_type %in% c('risk', 'surv', 'chf')){ + + stop("pred_horizon must be specified for ", + pred_type, " predictions.", call. 
= FALSE) + + } + + if(self$oobag_pred_mode) arg_name <- 'oobag_pred_horizon' else @@ -2804,7 +2940,9 @@ ObliqueForestClassification <- R6::R6Class( cloneable = FALSE, public = list( - n_class = NULL + n_class = NULL, + + class_levels = NULL ), private = list( @@ -2828,6 +2966,15 @@ ObliqueForestClassification <- R6::R6Class( }, + check_pred_horizon = function(pred_horizon = NULL, + boundary_checks = TRUE, + pred_type = NULL){ + + # nothing to check + NULL + + }, + init_internal = function(){ self$tree_type <- "classification" @@ -2845,11 +2992,15 @@ ObliqueForestClassification <- R6::R6Class( y <- self$data[[private$data_names$y]] if(is.factor(y)){ - self$n_class <- length(levels(y)) + self$class_levels <- levels(y) + self$n_class <- length(self$class_levels) } else { - self$n_class <- length(unique(y)) + self$class_levels <- unique(y) + self$n_class <- length(self$class_levels) } + + }, prep_y_internal = function(){ diff --git a/R/orsf_pd.R b/R/orsf_pd.R index 68acf46b..ab910996 100644 --- a/R/orsf_pd.R +++ b/R/orsf_pd.R @@ -80,7 +80,7 @@ orsf_pd_oob <- function(object, pred_spec, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, expand_grid = TRUE, prob_values = c(0.025, 0.50, 0.975), prob_labels = c('lwr', 'medn', 'upr'), @@ -110,7 +110,7 @@ orsf_pd_oob <- function(object, orsf_pd_inb <- function(object, pred_spec, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, expand_grid = TRUE, prob_values = c(0.025, 0.50, 0.975), prob_labels = c('lwr', 'medn', 'upr'), @@ -146,7 +146,7 @@ orsf_pd_new <- function(object, pred_spec, new_data, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, na_action = 'fail', expand_grid = TRUE, prob_values = c(0.025, 0.50, 0.975), @@ -194,7 +194,7 @@ orsf_pd_new <- function(object, orsf_ice_oob <- function(object, pred_spec, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, expand_grid = TRUE, boundary_checks = TRUE, n_thread = 1, @@ -220,7 +220,7 @@ orsf_ice_oob <- function(object, 
orsf_ice_inb <- function(object, pred_spec, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, expand_grid = TRUE, boundary_checks = TRUE, n_thread = 1, @@ -252,7 +252,7 @@ orsf_ice_new <- function(object, pred_spec, new_data, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, na_action = 'fail', expand_grid = TRUE, boundary_checks = TRUE, @@ -308,82 +308,11 @@ orsf_pred_dependence <- function(object, check_arg_is(object, arg_name = 'object', expected_class = 'ObliqueForest') - pred_horizon <- infer_pred_horizon(object, pred_type, pred_horizon) - - if(is.null(prob_values)) prob_values <- c(0.025, 0.50, 0.975) - if(is.null(prob_labels)) prob_labels <- c('lwr', 'medn', 'upr') - if(oobag && is.null(object$data)) stop("no data were found in object. ", "did you use attach_data = FALSE when ", "running orsf()?", call. = FALSE) - if(is.null(pred_horizon) && pred_type %in% c('risk', 'surv', 'chf')){ - stop("pred_horizon must be specified for ", - pred_type, " predictions.", call. = FALSE) - } - - check_arg_type(arg_value = boundary_checks, - arg_name = 'boundary_checks', - expected_type = 'logical') - - check_arg_length(arg_value = boundary_checks, - arg_name = 'boundary_checks', - expected_length = 1) - - check_arg_type(arg_value = expand_grid, - arg_name = 'expand_grid', - expected_type = 'logical') - - check_arg_length(arg_value = expand_grid, - arg_name = 'expand_grid', - expected_length = 1) - - check_arg_type(arg_value = prob_values, - arg_name = 'prob_values', - expected_type = 'numeric') - - check_arg_gteq(arg_value = prob_values, - arg_name = 'prob_values', - bound = 0) - - check_arg_lteq(arg_value = prob_values, - arg_name = 'prob_values', - bound = 1) - - check_arg_type(arg_value = prob_labels, - arg_name = 'prob_labels', - expected_type = 'character') - - if(length(prob_values) != length(prob_labels)){ - stop("prob_values and prob_labels must have the same length.", - call. 
= FALSE) - } - - check_arg_type(arg_value = oobag, - arg_name = 'oobag', - expected_type = 'logical') - - check_arg_length(arg_value = oobag, - arg_name = 'oobag', - expected_length = 1) - - check_arg_type(arg_value = pred_type, - arg_name = "pred_type", - expected_type = 'character') - - check_arg_length(arg_value = pred_type, - arg_name = "pred_type", - expected_length = 1) - - check_arg_is_valid(arg_value = pred_type, - arg_name = "pred_type", - valid_options = c("risk", - "surv", - "chf", - "mort", - "prob")) - object$compute_dependence(pd_data = pd_data, pred_spec = pred_spec, pred_horizon = pred_horizon, diff --git a/R/orsf_summary.R b/R/orsf_summary.R index 6426cbe0..ed60d6df 100644 --- a/R/orsf_summary.R +++ b/R/orsf_summary.R @@ -106,6 +106,7 @@ print.orsf_summary_uni <- function(x, n_variables = NULL, ...){ 'surv' = 'Survival', 'chf' = 'Cumulative hazard', 'mort' = 'Mortality', + 'prob' = "Probability" ) msg_btm <- paste("Predicted", tolower(pred_label), diff --git a/man/orsf_ice_oob.Rd b/man/orsf_ice_oob.Rd index f29b6aa0..ad2a1968 100644 --- a/man/orsf_ice_oob.Rd +++ b/man/orsf_ice_oob.Rd @@ -10,7 +10,7 @@ orsf_ice_oob( object, pred_spec, pred_horizon = NULL, - pred_type = "risk", + pred_type = NULL, expand_grid = TRUE, boundary_checks = TRUE, n_thread = 1, @@ -21,7 +21,7 @@ orsf_ice_inb( object, pred_spec, pred_horizon = NULL, - pred_type = "risk", + pred_type = NULL, expand_grid = TRUE, boundary_checks = TRUE, n_thread = 1, @@ -33,7 +33,7 @@ orsf_ice_new( pred_spec, new_data, pred_horizon = NULL, - pred_type = "risk", + pred_type = NULL, na_action = "fail", expand_grid = TRUE, boundary_checks = TRUE, @@ -151,18 +151,18 @@ ice_oob <- orsf_ice_oob(fit, pred_spec, boundary_checks = FALSE) ice_oob }\if{html}{\out{}} -\if{html}{\out{