From 2aa2497e1e193c9f9a6c2fbe8b29864d58e6398f Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Thu, 9 Nov 2023 08:57:15 -0500 Subject: [PATCH] smooth over pd checks --- R/orsf_R6.R | 209 +++++++++++++++++++++++++++++----- R/orsf_pd.R | 83 +------------- R/orsf_summary.R | 1 + man/orsf_ice_oob.Rd | 30 ++--- man/orsf_pd_oob.Rd | 42 +++---- man/orsf_summarize_uni.Rd | 14 +-- src/Forest.cpp | 2 +- tests/testthat/test-orsf_pd.R | 5 +- 8 files changed, 233 insertions(+), 153 deletions(-) diff --git a/R/orsf_R6.R b/R/orsf_R6.R index 1e98a03c..159f3182 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -377,17 +377,44 @@ ObliqueForest <- R6::R6Class( oobag, type_output){ - public_state <- list(data = self$data, - na_action = self$na_action) + public_state <- list(data = self$data, + na_action = self$na_action, + pred_horizon = self$pred_horizon) private_state <- list(data_rows_complete = private$data_rows_complete) + private$check_boundary_checks(boundary_checks) + private$check_pred_spec(pred_spec, boundary_checks) + private$check_n_thread(n_thread) + private$check_expand_grid(expand_grid) + private$check_oobag_pred_mode(oobag, label = 'oobag') + + prob_values <- prob_values %||% c(0.025, 0.50, 0.975) + prob_labels <- prob_labels %||% c('lwr', 'medn', 'upr') + + private$check_prob_values(prob_values) + private$check_prob_labels(prob_labels) + + if(length(prob_values) != length(prob_labels)){ + stop("prob_values and prob_labels must have the same length.", + call. = FALSE) + } + + # oobag=FALSE to match the format of arg in orsf_pd(). + private$check_pred_type(pred_type, oobag = FALSE) + + pred_type <- pred_type %||% self$pred_type + + private$check_pred_horizon(pred_horizon, boundary_checks, pred_type) + + if(is.null(pred_horizon)) pred_horizon <- 1 + pred_horizon_order <- order(pred_horizon) + pred_horizon_ordered <- pred_horizon[pred_horizon_order] + # run checks before you assign new values to object. 
# otherwise, if a check throws an error, the object will # not be restored to its normal state. - private$check_n_thread(n_thread) - private$check_pred_spec(pred_spec, boundary_checks) if(!oobag){ private$check_data(new = TRUE, data = pd_data) @@ -398,7 +425,7 @@ ObliqueForest <- R6::R6Class( self$data <- pd_data } - + self$pred_horizon <- pred_horizon self$na_action <- na_action # make a visible binding for CRAN @@ -492,8 +519,7 @@ ObliqueForest <- R6::R6Class( } - pred_horizon_order <- order(pred_horizon) - pred_horizon_ordered <- pred_horizon[pred_horizon_order] + cpp_args <- private$prep_cpp_args(x = private$x, y = private$y, @@ -511,11 +537,11 @@ ObliqueForest <- R6::R6Class( verbosity = 0) - results <- list() + pd_vals <- list() for(i in seq_along(pred_spec_new)){ - results_i <- list() + pd_vals_i <- list() x_pd <- private$x @@ -525,13 +551,13 @@ ObliqueForest <- R6::R6Class( cpp_args$x <- x_pd - results_i[[j]] <- do.call(orsf_cpp, cpp_args)$pred_new + pd_vals_i[[j]] <- do.call(orsf_cpp, cpp_args)$pred_new } if(type_output == 'smry'){ - results_i <- lapply( - results_i, + pd_vals_i <- lapply( + pd_vals_i, function(x) { apply(x, 2, function(x_col){ as.numeric( @@ -544,29 +570,43 @@ ObliqueForest <- R6::R6Class( } - results[[i]] <- results_i + pd_vals[[i]] <- pd_vals_i } - pd_vals <- results - for(i in seq_along(pd_vals)){ pd_bind[[i]]$id_variable <- seq(nrow(pd_bind[[i]])) for(j in seq_along(pd_vals[[i]])){ - pd_vals[[i]][[j]] <- matrix(pd_vals[[i]][[j]], - nrow=length(pred_horizon), - byrow = T) - rownames(pd_vals[[i]][[j]]) <- pred_horizon + pd_vals[[i]][[j]] + + if(self$tree_type == 'survival'){ + + pd_vals[[i]][[j]] <- matrix(pd_vals[[i]][[j]], + nrow=length(pred_horizon), + byrow = T) + + rownames(pd_vals[[i]][[j]]) <- pred_horizon + + } else { + + pd_vals[[i]][[j]] <- t(pd_vals[[i]][[j]]) + + if(self$tree_type == 'classification'){ + rownames(pd_vals[[i]][[j]]) <- self$class_levels + } + + } if(type_output=='smry') colnames(pd_vals[[i]][[j]]) <- 
c('mean', prob_labels) else colnames(pd_vals[[i]][[j]]) <- c(paste(1:nrow(private$x))) + # this will be null for non-survival objects ph <- rownames(pd_vals[[i]][[j]]) pd_vals[[i]][[j]] <- as.data.frame(pd_vals[[i]][[j]]) @@ -618,16 +658,21 @@ ObliqueForest <- R6::R6Class( setcolorder(out, neworder = c(ids, mid, end)) - out[, pred_horizon := as.numeric(pred_horizon)] + if(self$tree_type == 'classification'){ + setnames(out, old = 'pred_horizon', new = 'class') + } + + if(self$tree_type == 'survival' && pred_type != 'mort') + out[, pred_horizon := as.numeric(pred_horizon)] + + if(pred_type == 'mort'){ + out[, pred_horizon := NULL] + } # not needed for summary if(type_output == 'smry') out[, id_variable := NULL] - # not needed for mort - if(pred_type == 'mort') - out[, pred_horizon := NULL] - # put data back into original scale for(j in intersect(names(means), names(pred_spec))){ @@ -744,7 +789,11 @@ ObliqueForest <- R6::R6Class( # check incoming values if they were specified. private$check_n_variables(n_variables) - private$check_pred_horizon(pred_horizon, boundary_checks = TRUE) + + if(!is.null(pred_horizon)){ + private$check_pred_horizon(pred_horizon, boundary_checks = TRUE) + } + private$check_pred_type(pred_type, oobag = FALSE) private$check_importance_type(importance_type) @@ -861,6 +910,15 @@ ObliqueForest <- R6::R6Class( return(private$data_names$y) }, + get_var_bounds = function(.name){ + + if(.name %in% private$data_names$x_numeric) + return(private$data_bounds[, .name]) + else + return(private$data_fctrs$lvls[[.name]]) + + }, + get_var_type = function(.name){ return(class(self$data[[.name]])[1]) }, @@ -2081,6 +2139,72 @@ ObliqueForest <- R6::R6Class( + }, + + check_boundary_checks = function(boundary_checks){ + + # not a field so boundary_checks should never be null + + check_arg_type(arg_value = boundary_checks, + arg_name = 'boundary_checks', + expected_type = 'logical') + + check_arg_length(arg_value = boundary_checks, + arg_name = 'boundary_checks', + 
expected_length = 1) + + + }, + + check_expand_grid = function(expand_grid){ + + # not a field so boundary_checks should never be null + check_arg_type(arg_value = expand_grid, + arg_name = 'expand_grid', + expected_type = 'logical') + + check_arg_length(arg_value = expand_grid, + arg_name = 'expand_grid', + expected_length = 1) + + }, + + check_prob_values = function(prob_values){ + + check_arg_type(arg_value = prob_values, + arg_name = 'prob_values', + expected_type = 'numeric') + + check_arg_gteq(arg_value = prob_values, + arg_name = 'prob_values', + bound = 0) + + check_arg_lteq(arg_value = prob_values, + arg_name = 'prob_values', + bound = 1) + + }, + + check_prob_labels = function(prob_labels){ + + + check_arg_type(arg_value = prob_labels, + arg_name = 'prob_labels', + expected_type = 'character') + + }, + + check_oobag_pred_mode = function(oobag_pred_mode, label){ + + check_arg_type(arg_value = oobag_pred_mode, + arg_name = label, + expected_type = 'logical') + + check_arg_length(arg_value = oobag_pred_mode, + arg_name = label, + expected_length = 1) + + }, # computers @@ -2092,6 +2216,7 @@ ObliqueForest <- R6::R6Class( }, + compute_modes = function(){ private$data_modes <- vapply( @@ -2463,10 +2588,21 @@ ObliqueForestSurvival <- R6::R6Class( }, - check_pred_horizon = function(pred_horizon = NULL, boundary_checks = TRUE){ + check_pred_horizon = function(pred_horizon = NULL, + boundary_checks = TRUE, + pred_type = NULL){ + pred_type <- pred_type %||% self$pred_type input <- pred_horizon %||% self$pred_horizon + if(is.null(input) && pred_type %in% c('risk', 'surv', 'chf')){ + + stop("pred_horizon must be specified for ", + pred_type, " predictions.", call. 
= FALSE) + + } + + if(self$oobag_pred_mode) arg_name <- 'oobag_pred_horizon' else @@ -2804,7 +2940,9 @@ ObliqueForestClassification <- R6::R6Class( cloneable = FALSE, public = list( - n_class = NULL + n_class = NULL, + + class_levels = NULL ), private = list( @@ -2828,6 +2966,15 @@ ObliqueForestClassification <- R6::R6Class( }, + check_pred_horizon = function(pred_horizon = NULL, + boundary_checks = TRUE, + pred_type = NULL){ + + # nothing to check + NULL + + }, + init_internal = function(){ self$tree_type <- "classification" @@ -2845,11 +2992,15 @@ ObliqueForestClassification <- R6::R6Class( y <- self$data[[private$data_names$y]] if(is.factor(y)){ - self$n_class <- length(levels(y)) + self$class_levels <- levels(y) + self$n_class <- length(self$class_levels) } else { - self$n_class <- length(unique(y)) + self$class_levels <- unique(y) + self$n_class <- length(self$class_levels) } + + }, prep_y_internal = function(){ diff --git a/R/orsf_pd.R b/R/orsf_pd.R index 68acf46b..ab910996 100644 --- a/R/orsf_pd.R +++ b/R/orsf_pd.R @@ -80,7 +80,7 @@ orsf_pd_oob <- function(object, pred_spec, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, expand_grid = TRUE, prob_values = c(0.025, 0.50, 0.975), prob_labels = c('lwr', 'medn', 'upr'), @@ -110,7 +110,7 @@ orsf_pd_oob <- function(object, orsf_pd_inb <- function(object, pred_spec, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, expand_grid = TRUE, prob_values = c(0.025, 0.50, 0.975), prob_labels = c('lwr', 'medn', 'upr'), @@ -146,7 +146,7 @@ orsf_pd_new <- function(object, pred_spec, new_data, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, na_action = 'fail', expand_grid = TRUE, prob_values = c(0.025, 0.50, 0.975), @@ -194,7 +194,7 @@ orsf_pd_new <- function(object, orsf_ice_oob <- function(object, pred_spec, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, expand_grid = TRUE, boundary_checks = TRUE, n_thread = 1, @@ -220,7 +220,7 @@ orsf_ice_oob <- function(object, 
orsf_ice_inb <- function(object, pred_spec, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, expand_grid = TRUE, boundary_checks = TRUE, n_thread = 1, @@ -252,7 +252,7 @@ orsf_ice_new <- function(object, pred_spec, new_data, pred_horizon = NULL, - pred_type = 'risk', + pred_type = NULL, na_action = 'fail', expand_grid = TRUE, boundary_checks = TRUE, @@ -308,82 +308,11 @@ orsf_pred_dependence <- function(object, check_arg_is(object, arg_name = 'object', expected_class = 'ObliqueForest') - pred_horizon <- infer_pred_horizon(object, pred_type, pred_horizon) - - if(is.null(prob_values)) prob_values <- c(0.025, 0.50, 0.975) - if(is.null(prob_labels)) prob_labels <- c('lwr', 'medn', 'upr') - if(oobag && is.null(object$data)) stop("no data were found in object. ", "did you use attach_data = FALSE when ", "running orsf()?", call. = FALSE) - if(is.null(pred_horizon) && pred_type %in% c('risk', 'surv', 'chf')){ - stop("pred_horizon must be specified for ", - pred_type, " predictions.", call. = FALSE) - } - - check_arg_type(arg_value = boundary_checks, - arg_name = 'boundary_checks', - expected_type = 'logical') - - check_arg_length(arg_value = boundary_checks, - arg_name = 'boundary_checks', - expected_length = 1) - - check_arg_type(arg_value = expand_grid, - arg_name = 'expand_grid', - expected_type = 'logical') - - check_arg_length(arg_value = expand_grid, - arg_name = 'expand_grid', - expected_length = 1) - - check_arg_type(arg_value = prob_values, - arg_name = 'prob_values', - expected_type = 'numeric') - - check_arg_gteq(arg_value = prob_values, - arg_name = 'prob_values', - bound = 0) - - check_arg_lteq(arg_value = prob_values, - arg_name = 'prob_values', - bound = 1) - - check_arg_type(arg_value = prob_labels, - arg_name = 'prob_labels', - expected_type = 'character') - - if(length(prob_values) != length(prob_labels)){ - stop("prob_values and prob_labels must have the same length.", - call. 
= FALSE) - } - - check_arg_type(arg_value = oobag, - arg_name = 'oobag', - expected_type = 'logical') - - check_arg_length(arg_value = oobag, - arg_name = 'oobag', - expected_length = 1) - - check_arg_type(arg_value = pred_type, - arg_name = "pred_type", - expected_type = 'character') - - check_arg_length(arg_value = pred_type, - arg_name = "pred_type", - expected_length = 1) - - check_arg_is_valid(arg_value = pred_type, - arg_name = "pred_type", - valid_options = c("risk", - "surv", - "chf", - "mort", - "prob")) - object$compute_dependence(pd_data = pd_data, pred_spec = pred_spec, pred_horizon = pred_horizon, diff --git a/R/orsf_summary.R b/R/orsf_summary.R index 6426cbe0..ed60d6df 100644 --- a/R/orsf_summary.R +++ b/R/orsf_summary.R @@ -106,6 +106,7 @@ print.orsf_summary_uni <- function(x, n_variables = NULL, ...){ 'surv' = 'Survival', 'chf' = 'Cumulative hazard', 'mort' = 'Mortality', + 'prob' = "Probability" ) msg_btm <- paste("Predicted", tolower(pred_label), diff --git a/man/orsf_ice_oob.Rd b/man/orsf_ice_oob.Rd index f29b6aa0..ad2a1968 100644 --- a/man/orsf_ice_oob.Rd +++ b/man/orsf_ice_oob.Rd @@ -10,7 +10,7 @@ orsf_ice_oob( object, pred_spec, pred_horizon = NULL, - pred_type = "risk", + pred_type = NULL, expand_grid = TRUE, boundary_checks = TRUE, n_thread = 1, @@ -21,7 +21,7 @@ orsf_ice_inb( object, pred_spec, pred_horizon = NULL, - pred_type = "risk", + pred_type = NULL, expand_grid = TRUE, boundary_checks = TRUE, n_thread = 1, @@ -33,7 +33,7 @@ orsf_ice_new( pred_spec, new_data, pred_horizon = NULL, - pred_type = "risk", + pred_type = NULL, na_action = "fail", expand_grid = TRUE, boundary_checks = TRUE, @@ -151,18 +151,18 @@ ice_oob <- orsf_ice_oob(fit, pred_spec, boundary_checks = FALSE) ice_oob }\if{html}{\out{}} -\if{html}{\out{
<div class="sourceCode">}}\preformatted{## id_variable id_row pred_horizon bili pred -## 1: 1 1 1788 1 0.8735558 -## 2: 1 2 1788 1 0.8260273 -## 3: 1 3 1788 1 0.6095483 -## 4: 1 4 1788 1 0.7125248 -## 5: 1 5 1788 1 0.5602638 -## --- -## 6896: 25 272 1788 10 0.6923029 -## 6897: 25 273 1788 10 0.5057890 -## 6898: 25 274 1788 10 0.3592502 -## 6899: 25 275 1788 10 0.6128702 -## 6900: 25 276 1788 10 0.3520821 +\if{html}{\out{<div class="sourceCode">}}\preformatted{## id_variable id_row pred_horizon bili pred +## 1: 1 1 1 1 1 +## 2: 1 2 1 1 1 +## 3: 1 3 1 1 1 +## 4: 1 4 1 1 1 +## 5: 1 5 1 1 1 +## --- +## 6896: 25 272 1 10 1 +## 6897: 25 273 1 10 1 +## 6898: 25 274 1 10 1 +## 6899: 25 275 1 10 1 +## 6900: 25 276 1 10 1 }\if{html}{\out{</div>
}} Much more detailed examples are given in the diff --git a/man/orsf_pd_oob.Rd b/man/orsf_pd_oob.Rd index abab3ddf..854b7bfb 100644 --- a/man/orsf_pd_oob.Rd +++ b/man/orsf_pd_oob.Rd @@ -10,7 +10,7 @@ orsf_pd_oob( object, pred_spec, pred_horizon = NULL, - pred_type = "risk", + pred_type = NULL, expand_grid = TRUE, prob_values = c(0.025, 0.5, 0.975), prob_labels = c("lwr", "medn", "upr"), @@ -23,7 +23,7 @@ orsf_pd_inb( object, pred_spec, pred_horizon = NULL, - pred_type = "risk", + pred_type = NULL, expand_grid = TRUE, prob_values = c(0.025, 0.5, 0.975), prob_labels = c("lwr", "medn", "upr"), @@ -37,7 +37,7 @@ orsf_pd_new( pred_spec, new_data, pred_horizon = NULL, - pred_type = "risk", + pred_type = NULL, na_action = "fail", expand_grid = TRUE, prob_values = c(0.025, 0.5, 0.975), @@ -160,12 +160,12 @@ You can compute partial dependence and ICE three ways with \code{aorsf}: pd_train }\if{html}{\out{
</div>}} -\if{html}{\out{<div class="sourceCode">}}\preformatted{## pred_horizon bili mean lwr medn upr -## 1: 1826.25 1 0.2082186 0.01838469 0.09586667 0.7829076 -## 2: 1826.25 2 0.2373819 0.02960212 0.13077561 0.8018503 -## 3: 1826.25 3 0.2775926 0.05276240 0.17198237 0.8236415 -## 4: 1826.25 4 0.3271510 0.09134466 0.24272453 0.8391211 -## 5: 1826.25 5 0.3702403 0.13418606 0.28854721 0.8455809 +\if{html}{\out{<div class="sourceCode">}}\preformatted{## pred_horizon bili mean lwr medn upr +## 1: 1 1 1 1 1 1 +## 2: 1 2 1 1 1 1 +## 3: 1 3 1 1 1 1 +## 4: 1 4 1 1 1 1 +## 5: 1 5 1 1 1 1 }\if{html}{\out{</div>
}} \item using out-of-bag predictions for the training data @@ -174,12 +174,12 @@ pd_train pd_train }\if{html}{\out{
</div>}} -\if{html}{\out{<div class="sourceCode">}}\preformatted{## pred_horizon bili mean lwr medn upr -## 1: 1826.25 1 0.2164392 0.01900953 0.1305748 0.7272463 -## 2: 1826.25 2 0.2455743 0.03066380 0.1678195 0.7474522 -## 3: 1826.25 3 0.2846930 0.05381358 0.2121561 0.7635418 -## 4: 1826.25 4 0.3348301 0.08488471 0.2844513 0.7739593 -## 5: 1826.25 5 0.3779981 0.13183233 0.3268844 0.7928344 +\if{html}{\out{<div class="sourceCode">}}\preformatted{## pred_horizon bili mean lwr medn upr +## 1: 1 1 1 1 1 1 +## 2: 1 2 1 1 1 1 +## 3: 1 3 1 1 1 1 +## 4: 1 4 1 1 1 1 +## 5: 1 5 1 1 1 1 }\if{html}{\out{</div>
}} \item using predictions for a new set of data @@ -190,12 +190,12 @@ pd_train pd_test }\if{html}{\out{
</div>}} -\if{html}{\out{<div class="sourceCode">}}\preformatted{## pred_horizon bili mean lwr medn upr -## 1: 1826.25 1 0.2480931 0.01966183 0.1883315 0.8131231 -## 2: 1826.25 2 0.2771130 0.03467559 0.2247012 0.8240438 -## 3: 1826.25 3 0.3188062 0.05964563 0.2783269 0.8418708 -## 4: 1826.25 4 0.3665008 0.09992269 0.3401974 0.8532184 -## 5: 1826.25 5 0.4093192 0.14954233 0.3813199 0.8612124 +\if{html}{\out{<div class="sourceCode">}}\preformatted{## pred_horizon bili mean lwr medn upr +## 1: 1 1 1 1 1 1 +## 2: 1 2 1 1 1 1 +## 3: 1 3 1 1 1 1 +## 4: 1 4 1 1 1 1 +## 5: 1 5 1 1 1 1 }\if{html}{\out{</div>
}} \item in-bag partial dependence indicates relationships that the model has learned during training. This is helpful if your goal is to interpret diff --git a/man/orsf_summarize_uni.Rd b/man/orsf_summarize_uni.Rd index b86745b5..f2396583 100644 --- a/man/orsf_summarize_uni.Rd +++ b/man/orsf_summarize_uni.Rd @@ -73,15 +73,15 @@ by setting importance = 'none'. object <- orsf(pbc_orsf, Surv(time, status) ~ . - id, n_tree = 25) -# since anova importance was used to make object, we can -# safely say importance = 'none' and skip computation of -# variable importance while running orsf_summarize_uni +# since anova importance was used to make object, it is also +# used for ranking variables in the summary, unless we specify +# a different type of importance -orsf_summarize_uni(object, n_variables = 3, importance = 'none') +orsf_summarize_uni(object, n_variables = 3) -# however, if we want to summarize object according to variables -# ranked by negation importance, we can compute negation importance -# within orsf_summarize_uni() as follows: +# if we want to summarize object according to variables +# ranked by negation importance, we can compute negation +# importance within orsf_summarize_uni() as follows: orsf_summarize_uni(object, n_variables = 3, importance = 'negate') diff --git a/src/Forest.cpp b/src/Forest.cpp index c1185728..0f6f9ec6 100644 --- a/src/Forest.cpp +++ b/src/Forest.cpp @@ -502,7 +502,7 @@ void Forest::compute_prediction_accuracy(arma::mat& y, Rcpp::NumericVector w_ = Rcpp::wrap(w); for(uword i = 0; i < oobag_eval.n_cols; ++i){ - vec p = predictions.col(i); + vec p = predictions.unsafe_col(i); Rcpp::NumericVector p_ = Rcpp::wrap(p); Rcpp::NumericVector R_result = f_oobag_eval(y_, w_, p_); oobag_eval(row_fill, i) = R_result[0]; diff --git a/tests/testthat/test-orsf_pd.R b/tests/testthat/test-orsf_pd.R index 3551ddfa..e506faea 100644 --- a/tests/testthat/test-orsf_pd.R +++ b/tests/testthat/test-orsf_pd.R @@ -135,11 +135,8 @@ for(i in seq_along(funs)){ } 
) - - } - } @@ -189,6 +186,7 @@ test_that( pd_smry_multi_horiz <- orsf_pd_oob( fit, + pred_type = 'risk', pred_spec = list(bili = 1), pred_horizon = c(1000, 2000, 3000) ) @@ -201,6 +199,7 @@ test_that( pd_ice_multi_horiz <- orsf_ice_oob( fit, + pred_type = 'risk', pred_spec = list(bili = 1), pred_horizon = c(1000, 2000, 3000) )