diff --git a/R/collapse_misc.R b/R/collapse_misc.R index ac0cce73..e234a5cd 100644 --- a/R/collapse_misc.R +++ b/R/collapse_misc.R @@ -37,15 +37,15 @@ data_impute <- function(data, cols, values){ } -data_impute_nocheck <- function(data, cols, values){ - - for(col in cols) - data <- collapse::replace_NA(data, - cols = col, - value = values[[col]]) - - data - -} +# data_impute_nocheck <- function(data, cols, values){ +# +# for(col in cols) +# data <- collapse::replace_NA(data, +# cols = col, +# value = values[[col]]) +# +# data +# +# } diff --git a/R/infer.R b/R/infer.R index 20fe0ac9..727865b2 100644 --- a/R/infer.R +++ b/R/infer.R @@ -26,10 +26,8 @@ infer_pred_horizon <- function(object, pred_type, pred_horizon){ # throw error if pred_type requires pred_horizon if(is.null(pred_horizon)){ - stop("pred_horizon was not specified and could not be found in object.", call. = FALSE) - } diff --git a/src/Forest.cpp b/src/Forest.cpp index 664354d0..28afd964 100644 --- a/src/Forest.cpp +++ b/src/Forest.cpp @@ -90,6 +90,7 @@ void Forest::init(std::unique_ptr input_data, // oobag denominator tracks the number of times an obs is oobag oobag_denom.zeros(data->get_n_rows()); + // # nocov start if(verbosity > 1){ Rcout << "------------ input data dimensions ------------" << std::endl; @@ -100,6 +101,7 @@ void Forest::init(std::unique_ptr input_data, Rcout << std::endl; } + // # nocov end } diff --git a/src/Tree.cpp b/src/Tree.cpp index 98528d4a..64b9d6c1 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -316,12 +316,12 @@ } if(verbosity > 4){ - + // # nocov start mat x_print = x_inbag.rows(rows_node); Rcout << " -- Column " << j << " was sampled but "; Rcout << "its unique values are " << unique(x_print.col(j)); Rcout << std::endl; - + // # nocov end } return(false); @@ -385,10 +385,12 @@ if(n_obs >= leaf_min_obs) { if(verbosity > 3){ + // # nocov start Rcout << std::endl; Rcout << " -- lower cutpoint: " << lincomb(*it) << std::endl; Rcout << " - n_obs, left node: " << n_obs << std::endl; Rcout << std::endl; + // # nocov end } break; @@ -404,7 +406,9 @@ if(it == lincomb_sort.end()-1) { if(verbosity > 3){ + // # nocov start Rcout << " -- Could not find a valid cut-point" << std::endl; + // # nocov end } return; @@ -437,10 +441,12 @@ --it; if(verbosity > 3){ + // # nocov start Rcout << std::endl; Rcout << " -- upper cutpoint: " << lincomb(*it) << std::endl; Rcout << " - n_obs, right node: " << n_obs << std::endl; Rcout << std::endl; + // # nocov end } break; @@ -459,7 +465,9 @@ if(j > k){ if(verbosity > 2) { + // # nocov start Rcout << " -- Could not find valid cut-points" << std::endl; + // # nocov end } return; @@ -502,7 +510,9 @@ double stat, stat_best = 0; if(verbosity > 3){ + // # nocov start Rcout << " -- cutpoint (score)" << std::endl; + // # nocov end } for(it = cuts_sampled.begin(); it != cuts_sampled.end(); ++it){ @@ -517,21 +527,25 @@ it_start = *it; if(verbosity > 3){ + // # nocov start Rcout << " --- "; Rcout << lincomb.at(lincomb_sort(*it)); Rcout << " (" << stat << "), "; Rcout << "N = " << sum(g_node % w_node) << " moving right"; Rcout << std::endl; + // # nocov end } } if(verbosity > 3){ + // # nocov start Rcout << std::endl; Rcout << " -- best stat: " << stat_best; Rcout << ", min to split: " << split_min_stat; Rcout << std::endl; Rcout << std::endl; + // # nocov end } // do not split if best stat < minimum stat @@ -587,10 +601,12 @@ void Tree::sprout_leaf(uword node_id){ if(verbosity > 2){ + // # nocov start Rcout << "-- sprouting node " << node_id << " into a leaf"; Rcout << " (N = " << sum(w_node) << ")"; Rcout << std::endl; Rcout << std::endl; + // # nocov end } leaf_summary[node_id] = mean(y_node.col(0)); @@ -648,11 +664,13 @@ double accuracy_normal = compute_prediction_accuracy(pred_values); if(verbosity > 1){ + // # nocov start Rcout << " -- prediction accuracy before noising: "; Rcout << accuracy_normal << std::endl; Rcout << " -- mean leaf pred: "; Rcout << mean(conv_to::from(pred_leaf)); Rcout << std::endl << std::endl; + // # nocov end } random_number_generator.seed(seed); @@ -691,11 +709,13 @@ double accuracy_permuted = compute_prediction_accuracy(pred_values); if(verbosity > 3){ + // # nocov start Rcout << " -- prediction accuracy after noising " << pred_col << ": "; Rcout << accuracy_permuted << std::endl; Rcout << " - mean leaf pred: "; Rcout << mean(conv_to::from(pred_leaf)); Rcout << std::endl << std::endl; + // # nocov end } double accuracy_difference = accuracy_normal - accuracy_permuted; @@ -733,7 +753,7 @@ this->max_nodes = (2 * max_leaves) - 1; if(verbosity > 2){ - + // # nocov start Rcout << "- N obs inbag: " << n_obs_inbag; Rcout << std::endl; Rcout << "- N row inbag: " << n_rows_inbag; @@ -743,8 +763,7 @@ Rcout << "- max leaves: " << max_leaves; Rcout << std::endl; Rcout << std::endl; - - + // # nocov end } // reserve memory for outputs (likely more than we need) @@ -798,12 +817,13 @@ n_retry++; if(verbosity > 3){ - + // # nocov start Rcout << "-- attempting to split node " << *node; Rcout << " (N = " << sum(w_node) << ","; Rcout << " try number " << n_retry << ")"; Rcout << std::endl; Rcout << std::endl; + // # nocov end } sample_cols(); @@ -813,7 +833,9 @@ x_node = x_inbag(rows_node, cols_node); if(verbosity > 3) { + // # nocov start print_uvec(cols_node, "columns sampled (showing up to 5)", 5); + // # nocov end } // beta holds estimates (first item) and variance (second) @@ -892,7 +914,9 @@ vec beta_est = beta.unsafe_col(0); if(verbosity > 3) { + // # nocov start print_vec(beta_est, "linear combo weights (showing up to 5)", 5); + // # nocov end } bool beta_all_zeros = find(beta_est != 0).is_empty(); @@ -908,10 +932,10 @@ find_all_cuts(); if(verbosity > 3 && cuts_all.is_empty()){ - + // # nocov start Rcout << " -- no cutpoints identified"; Rcout << std::endl; - + // # nocov end } // empty cuts_all => no valid cutpoints => make leaf or retry @@ -930,7 +954,9 @@ // 2. the method used for lincombs allows it if(verbosity > 3){ + // # nocov start Rcout << " -- p-values:" << std::endl; + // # nocov end } vec beta_var = beta.unsafe_col(1); @@ -946,7 +972,7 @@ pvalue = R::pchisq(pow(beta_est[i],2)/beta_var[i], 1, false, false); if(verbosity > 3){ - + // # nocov start Rcout << " --- column " << cols_node[i] << ": "; Rcout << pvalue; if(pvalue < 0.05) Rcout << "*"; @@ -954,7 +980,7 @@ if(pvalue < 0.001) Rcout << "*"; if(pvalue < vi_max_pvalue) Rcout << " [+1 to VI numerator]"; Rcout << std::endl; - + // # nocov end } if(pvalue < vi_max_pvalue){ (*vi_numer)[cols_node[i]]++; } @@ -963,7 +989,11 @@ } - if(verbosity > 3){ Rcout << std::endl; } + if(verbosity > 3){ + // # nocov start + Rcout << std::endl; + // # nocov end + } } @@ -981,11 +1011,13 @@ node_assignments.elem(rows_node) = node_left + g_node; if(verbosity > 2){ + // # nocov start Rcout << "-- node " << *node << " was split into "; Rcout << "node " << node_left << " (left) and "; Rcout << node_left+1 << " (right)"; Rcout << std::endl; Rcout << std::endl; + // # nocov end } nodes_queued.push_back(node_left); @@ -1035,7 +1067,9 @@ if(coef_values.size() == 0) return; if(verbosity > 2){ + // # nocov start Rcout << " -- computing leaf predictions" << std::endl; + // # nocov end } uvec obs_in_node; @@ -1080,12 +1114,14 @@ } if(verbosity > 4){ + // # nocov start uvec in_left = find(pred_leaf == child_left[i]); uvec in_right = find(pred_leaf == child_left[i]+1); Rcout << "No. to node " << child_left[i] << ": "; Rcout << in_left.size() << "; " << std::endl; Rcout << "No. to node " << child_left[i]+1 << ": "; Rcout << in_right.size() << std::endl << std::endl; + // # nocov end } } diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp index 1c9e875e..bcf13c11 100644 --- a/src/TreeSurvival.cpp +++ b/src/TreeSurvival.cpp @@ -120,7 +120,7 @@ } if(verbosity > 3){ - + // # nocov start mat x_print = x_inbag.rows(rows_node); mat y_print = y_inbag.rows(rows_node); @@ -130,7 +130,7 @@ Rcout << " --- Column " << j << " was sampled but "; Rcout << " unique values of column " << j << " are "; Rcout << unique(x_print.col(j)) << std::endl; - + // # nocov end } return(false); @@ -178,11 +178,13 @@ n_risk >= leaf_min_obs ) { if(verbosity > 2){ + // # nocov start Rcout << std::endl; Rcout << " -- lower cutpoint: " << lincomb(*it) << std::endl; Rcout << " - n_events, left node: " << n_events << std::endl; Rcout << " - n_risk, left node: " << n_risk << std::endl; Rcout << std::endl; + // # nocov end } break; @@ -198,7 +200,9 @@ if(it == lincomb_sort.end()-1) { if(verbosity > 2){ + // # nocov start Rcout << " -- Could not find a valid cut-point" << std::endl; + // # nocov end } return; @@ -233,11 +237,13 @@ --it; if(verbosity > 2){ + // # nocov start Rcout << std::endl; Rcout << " -- upper cutpoint: " << lincomb(*it) << std::endl; Rcout << " - n_events, right node: " << n_events << std::endl; Rcout << " - n_risk, right node: " << n_risk << std::endl; Rcout << std::endl; + // # nocov end } break; @@ -256,7 +262,9 @@ if(j > k){ if(verbosity > 2) { + // # nocov start Rcout << "Could not find valid cut-points" << std::endl; + // # nocov end } return; @@ -316,10 +324,12 @@ void TreeSurvival::sprout_leaf(uword node_id){ if(verbosity > 2){ + // # nocov start Rcout << "-- sprouting node " << node_id << " into a leaf"; Rcout << " (N = " << sum(w_node) << ")"; Rcout << std::endl; Rcout << std::endl; + // # nocov end } @@ -415,9 +425,11 @@ if(verbosity > 3){ + // # nocov start mat tmp_mat = join_horiz(y_node, w_node); print_mat(tmp_mat, "time & status & weights in this node", 10, 10); print_mat(leaf_data, "leaf_data (showing up to 5 rows)", 5, 5); + // # nocov end } leaf_pred_indx[node_id] = leaf_data.col(0); @@ -476,6 +488,7 @@ uvec::iterator it = pred_leaf_sort.begin(); if(verbosity > 2){ + // # nocov start uvec tmp_uvec = find(pred_leaf < max_nodes); if(tmp_uvec.size() == 0){ @@ -484,6 +497,7 @@ } Rcout << " -- N preds expected: " << tmp_uvec.size() << std::endl; + // # nocov end } uword leaf_id = pred_leaf[*it]; @@ -657,9 +671,11 @@ } if(verbosity > 2){ + // # nocov start Rcout << " -- N preds made: " << n_preds_made; Rcout << std::endl; Rcout << std::endl; + // # nocov end } diff --git a/src/utility.cpp b/src/utility.cpp index 13ebec6b..5db7d0f9 100644 --- a/src/utility.cpp +++ b/src/utility.cpp @@ -13,6 +13,7 @@ namespace aorsf { + // # nocov start void print_mat(arma::mat& x, std::string label, arma::uword max_cols, @@ -69,6 +70,46 @@ } + std::string uintToString(uint number) { + return std::to_string(number); + } + + std::string beautifyTime(uint seconds) { + + std::string result; + + // Add seconds, minutes, hours, days if larger than zero + uint out_seconds = (uint) seconds % 60; + result = uintToString(out_seconds) + " seconds"; + uint out_minutes = (seconds / 60) % 60; + if (seconds / 60 == 0) { + return result; + } else if (out_minutes == 1) { + result = "1 minute, " + result; + } else { + result = uintToString(out_minutes) + " minutes, " + result; + } + uint out_hours = (seconds / 3600) % 24; + if (seconds / 3600 == 0) { + return result; + } else if (out_hours == 1) { + result = "1 hour, " + result; + } else { + result = uintToString(out_hours) + " hours, " + result; + } + uint out_days = (seconds / 86400); + if (out_days == 0) { + return result; + } else if (out_days == 1) { + result = "1 day, " + result; + } else { + result = uintToString(out_days) + " days, " + result; + } + return result; + } + + // # nocov end + void equalSplit(std::vector& result, uint start, uint end, uint num_parts) { result.reserve(num_parts + 1); @@ -136,43 +177,6 @@ } - std::string uintToString(uint number) { - return std::to_string(number); - } - - std::string beautifyTime(uint seconds) { - - std::string result; - - // Add seconds, minutes, hours, days if larger than zero - uint out_seconds = (uint) seconds % 60; - result = uintToString(out_seconds) + " seconds"; - uint out_minutes = (seconds / 60) % 60; - if (seconds / 60 == 0) { - return result; - } else if (out_minutes == 1) { - result = "1 minute, " + result; - } else { - result = uintToString(out_minutes) + " minutes, " + result; - } - uint out_hours = (seconds / 3600) % 24; - if (seconds / 3600 == 0) { - return result; - } else if (out_hours == 1) { - result = "1 hour, " + result; - } else { - result = uintToString(out_hours) + " hours, " + result; - } - uint out_days = (seconds / 86400); - if (out_days == 0) { - return result; - } else if (out_days == 1) { - result = "1 day, " + result; - } else { - result = uintToString(out_days) + " days, " + result; - } - return result; - } double compute_logrank(arma::mat& y, arma::vec& w, diff --git a/tests/testthat/_snaps/verbosity.md b/tests/testthat/_snaps/verbosity.md new file mode 100644 index 00000000..01ec66aa --- /dev/null +++ b/tests/testthat/_snaps/verbosity.md @@ -0,0 +1,20 @@ +# verbosity prints grow, predict, and importance notes + + Code + fit_verbose <- orsf(pbc, time + status ~ ., verbose_progress = TRUE, n_tree = n_tree_test, + importance = "negate") + Output + Growing trees: 100%. + Computing predictions: 100%. + Computing importance: 100%. + +--- + + Code + fit_verbose <- orsf(pbc, time + status ~ ., verbose_progress = TRUE, n_tree = n_tree_test, + importance = "negate", n_thread = 5) + Output + Growing trees: 100%. + Computing predictions: 100%. + Computing importance: 100%. + diff --git a/tests/testthat/test-infer.R b/tests/testthat/test-infer.R new file mode 100644 index 00000000..0b6e76dd --- /dev/null +++ b/tests/testthat/test-infer.R @@ -0,0 +1,41 @@ + +test_that( + desc = 'inferred pred horizon is correct', + code = { + + expect_equal( + infer_pred_horizon(fit_standard_pbc$fast, + pred_type = 'risk', + pred_horizon = NULL), + get_oobag_pred_horizon(fit_standard_pbc$fast) + ) + + expect_equal( + infer_pred_horizon(fit_standard_pbc$fast, + pred_type = 'risk', + pred_horizon = 100), + 100 + ) + + expect_equal( + infer_pred_horizon(fit_standard_pbc$fast, + pred_type = 'mort', + pred_horizon = 100), + 1 + ) + + fit_renegade <- fit_standard_pbc$fast + + fit_renegade$pred_horizon <- NULL + attr(fit_renegade, 'oobag_pred_horizon') <- NULL + + expect_error(infer_pred_horizon(fit_renegade, + pred_type = 'risk', + pred_horizon = NULL), + regexp = 'could not be found') + + } +) + + + diff --git a/tests/testthat/test-oobag.R b/tests/testthat/test-oobag.R index 2d876c21..6ac24f22 100644 --- a/tests/testthat/test-oobag.R +++ b/tests/testthat/test-oobag.R @@ -1,27 +1,23 @@ -fit_custom_oobag <- orsf(pbc_orsf, - formula = Surv(time, status) ~ . - id, - n_tree = 100, - oobag_fun = oobag_c_survival, - tree_seeds = 1:100) - -fit_standard_oobag <- orsf(pbc_orsf, - formula = Surv(time, status) ~ . - id, - n_tree = 100, - tree_seeds = 1:100) - -testthat::expect_equal( - fit_custom_oobag$forest$rows_oobag, - fit_standard_oobag$forest$rows_oobag -) test_that( - desc = 'tree seeds show that a custom oobag fun matches the internal one', + desc = 'oobag error works w/oobag_eval_every & custom oobag fun works', code = { + + fit_custom_oobag <- orsf(pbc, + formula = Surv(time, status) ~ ., + n_tree = n_tree_test, + oobag_eval_every = 1, + oobag_fun = oobag_c_survival, + tree_seeds = seeds_standard) + + expect_equal_leaf_summary(fit_custom_oobag, fit_standard_pbc$fast) + expect_equal( - fit_standard_oobag$eval_oobag$stat_values, - fit_custom_oobag$eval_oobag$stat_values + get_last_oob_stat_value(fit_standard_pbc$fast), + get_last_oob_stat_value(fit_custom_oobag) ) + } ) diff --git a/tests/testthat/test-orsf.R b/tests/testthat/test-orsf.R index 6159c79d..4746ce78 100644 --- a/tests/testthat/test-orsf.R +++ b/tests/testthat/test-orsf.R @@ -44,8 +44,6 @@ test_that( ) -#' @srrstats {G5.8b, G5.8b} *Data of unsupported types trigger an error* - test_that( desc = "blank and non-standard names trigger an error", code = { @@ -77,8 +75,6 @@ test_that( ) -#' @srrstats {G2.11} *testing allowance and accounting for units class* - test_that( 'orsf tracks meta data for units class variables', code = { @@ -117,91 +113,6 @@ test_that( ) -data_fit <- copy(pbc_orsf) - -fit_with_vi <- orsf(data = data_fit, - formula = Surv(time, status) ~ . - id, - importance = 'negate', - n_tree = 50) - -test_that("data are not unintentionally modified by reference", - code = {expect_identical(data_fit, pbc_orsf)}) - - -fit_no_vi <- orsf(data = pbc_orsf, - formula = Surv(time, status) ~ . - id, - importance = 'none', - n_tree = 50) - -# I'm making the difference in data size very big because I don't want this -# test to fail on some operating systems. -pbc_small <- pbc_orsf[1:50, ] - -pbc_large <- rbind(pbc_orsf, pbc_orsf, pbc_orsf, pbc_orsf, pbc_orsf) -pbc_large <- rbind(pbc_large, pbc_large, pbc_large, pbc_large, pbc_large) - -test_that( - desc = "algorithm runs slower as data size increases", - code = { - time_small <- system.time(orsf(pbc_small, - Surv(time, status) ~ . -id, - n_tree=50)) - - time_large <- system.time(orsf(pbc_large, - Surv(time, status) ~ . -id, - n_tree=50)) - - expect_true(time_small['elapsed'] < time_large['elapsed']) - } -) - - -test_that( - desc = "algorithm runs faster with lower convergence tolerance", - code = { - - time_small <- system.time( - orsf(pbc_orsf, - control = orsf_control_fast(), - Surv(time, status) ~ . -id, - n_tree = 500) - ) - - time_large <- system.time( - orsf(pbc_orsf, - control = orsf_control_cph(iter_max = 50, eps = 1e-10), - Surv(time, status) ~ . -id, - n_tree = 500) - ) - - expect_true(time_small['elapsed'] < time_large['elapsed']) - - } -) - -test_that( - desc = "algorithm runs faster with lower number of iterations", - code = { - - time_small <- system.time( - orsf(pbc_orsf, - Surv(time, status) ~ . -id, - n_tree = 5) - ) - - time_large <- system.time( - orsf(pbc_orsf, - Surv(time, status) ~ . -id, - n_tree = 1000) # big difference prevents unneeded failure - ) - - expect_true(time_small['elapsed'] < time_large['elapsed']) - - } -) - - -#' @srrstats {ML7.11} *OOB C-statistic is monitored by this test. As the number of trees in the forest increases, the C-statistic should also increase* test_that( desc = "algorithm grows more accurate with higher number of iterations", @@ -219,8 +130,6 @@ test_that( ) -#' @srrstats {G5.8, G5.8a} **Edge condition tests** *Zero-length data produce expected behaviour* - test_that( desc = 'Boundary case: empty training data throw an error', code = { @@ -620,6 +529,7 @@ test_that( n_split = 1, n_retry = 0, mtry = 3, + sample_with_replacement = c(TRUE, FALSE), leaf_min_events = 5, leaf_min_obs = c(10), split_rule = c("logrank", "cstat"), @@ -650,9 +560,17 @@ test_that( 'net' = orsf_control_net(), 'custom' = orsf_control_custom(beta_fun = f_pca)) + if(inputs$sample_with_replacement[i]){ + sample_fraction <- 0.632 + } else { + sample_fraction <- runif(n = 1, min = .25, max = .75) + } + fit <- orsf(data = data_fun(pbc_orsf), formula = time + status ~ . - id, control = control, + sample_with_replacement = inputs$sample_with_replacement[i], + sample_fraction = sample_fraction, n_tree = inputs$n_tree[i], n_split = inputs$n_split[i], n_retry = inputs$n_retry[i], @@ -667,6 +585,10 @@ test_that( expect_s3_class(fit, class = 'orsf_fit') + # data are not unintentionally modified by reference, + expect_identical(data_fun(pbc_orsf), fit$data) + + expect_no_missing(fit$forest) expect_no_missing(fit$importance) expect_no_missing(fit$pred_horizon) @@ -688,6 +610,14 @@ test_that( expect_length(fit$forest$coef_values, n = get_n_tree(fit)) expect_length(fit$forest$leaf_summary, n = get_n_tree(fit)) + if(!inputs$sample_with_replacement[i]){ + expect_equal( + 1 - length(fit$forest$rows_oobag[[1]]) / get_n_obs(fit), + sample_fraction, + tolerance = 0.025 + ) + } + if(inputs$oobag_pred_type[i] != 'none'){ if(inputs$oobag_pred_type[i] %in% c("chf","surv","risk")){ diff --git a/tests/testthat/test-orsf_pd.R b/tests/testthat/test-orsf_pd.R index 2a8393c4..d82c9ffe 100644 --- a/tests/testthat/test-orsf_pd.R +++ b/tests/testthat/test-orsf_pd.R @@ -60,6 +60,60 @@ test_that( } ) +funs <- list( + ice_new = orsf_ice_new, + ice_inb = orsf_ice_inb, + ice_oob = orsf_ice_oob, + pd_new = orsf_pd_new, + pd_inb = orsf_pd_inb, + pd_oob = orsf_pd_oob +) + +args_loop <- args_grid <- list( + object = fit, + pred_spec = list(bili = 1:4, sex = c("m", "f")), + new_data = pbc_test, + pred_horizon = 1000, + pred_type = 'risk', + na_action = 'fail', + expand_grid = TRUE, + prob_values = c(0.025, 0.50, 0.975), + prob_labels = c("lwr", "medn", "upr"), + boundary_checks = TRUE, + n_thread = 3 +) + +args_loop$expand_grid <- FALSE + +for(i in seq_along(funs)){ + + f_name <- names(funs)[i] + + formals <- setdiff(names(formals(funs[[i]])), '...') + + pd_object_grid <- do.call(funs[[i]], args = args_grid[formals]) + pd_object_loop <- do.call(funs[[i]], args = args_loop[formals]) + + test_that( + desc = paste('pred_spec data are returned on the original scale', + ' for orsf_', f_name, sep = ''), + code = { + expect_equal(unique(pd_object_grid$bili), 1:4) + expect_equal(unique(pd_object_loop[variable == 'bili', value]), 1:4) + } + ) + + test_that( + desc = paste(f_name, 'returns a data.table'), + code = { + expect_s3_class(pd_object_grid, 'data.table') + expect_s3_class(pd_object_loop, 'data.table') + } + ) + + +} + pd_vals_ice <- orsf_ice_new( fit, new_data = pbc_orsf, @@ -74,18 +128,6 @@ pd_vals_smry <- orsf_pd_new( pred_horizon = 1000 ) -test_that( - 'pred_spec data are returned on the original scale', - - code = { - - expect_equal(unique(pd_vals_ice$bili), 1:4) - expect_equal(unique(pd_vals_smry$bili), 1:4) - - } - -) - test_that( 'ice values summarized are the same as pd values', code = { diff --git a/tests/testthat/test-orsf_predict.R b/tests/testthat/test-orsf_predict.R index 8dd61742..6df94d8c 100644 --- a/tests/testthat/test-orsf_predict.R +++ b/tests/testthat/test-orsf_predict.R @@ -355,25 +355,25 @@ test_that( ) test_that( - desc = 'predictions do not depend on other observations in the data', + desc = 'leaf predictions do not depend on other observations in the data', code = { - p_risk <- predict(fit, new_data = new_data) + p_leaf <- predict(fit, new_data = new_data, pred_type = 'leaf') for(i in seq(nrow(new_data))){ - p_1row <- predict(fit, new_data = new_data[i,]) - expect_equal(p_1row, p_risk[i], ignore_attr = TRUE, tolerance = 1e-9) + p_1row <- predict(fit, new_data = new_data[i,], pred_type = 'leaf') + expect_equal(p_1row, p_leaf[i, , drop=FALSE]) } } ) test_that( - 'predictions do not depend on order of the data', + 'leaf predictions do not depend on order of the data', code = { - for(pred_type in pred_types_surv){ + for(pred_type in c('leaf')){ p_before <- predict(fit, new_data = new_data, diff --git a/tests/testthat/test-verbosity.R b/tests/testthat/test-verbosity.R new file mode 100644 index 00000000..824d9632 --- /dev/null +++ b/tests/testthat/test-verbosity.R @@ -0,0 +1,26 @@ + + +test_that( + desc = 'verbosity prints grow, predict, and importance notes', + code = { + + expect_snapshot( + fit_verbose <- orsf(pbc, time + status ~., + verbose_progress = TRUE, + n_tree = n_tree_test, + importance = 'negate') + ) + + expect_snapshot( + fit_verbose <- orsf(pbc, time + status ~., + verbose_progress = TRUE, + n_tree = n_tree_test, + importance = 'negate', + n_thread = 5) + ) + + } +) + + +