diff --git a/R/check.R b/R/check.R index c30a273b..7a8d9f37 100644 --- a/R/check.R +++ b/R/check.R @@ -1287,7 +1287,8 @@ check_pd_inputs <- function(object, new_data = new_data, pred_horizon = pred_horizon, pred_type = pred_type, - na_action = na_action) + na_action = na_action, + valid_pred_types = c("risk", "surv", "chf", "mort")) } @@ -1564,7 +1565,8 @@ check_predict <- function(object, pred_horizon = NULL, pred_type = NULL, na_action = NULL, - boundary_checks = TRUE){ + boundary_checks = TRUE, + valid_pred_types = c("risk", "surv", "chf", "mort", "leaf")){ if(!is.null(new_data)){ @@ -1630,11 +1632,7 @@ check_predict <- function(object, check_arg_is_valid(arg_value = pred_type, arg_name = 'pred_type', - valid_options = c("risk", - "surv", - "chf", - "mort", - "leaf")) + valid_options = valid_pred_types) } diff --git a/R/orsf_pd.R b/R/orsf_pd.R index 0a539ba3..715d7e09 100644 --- a/R/orsf_pd.R +++ b/R/orsf_pd.R @@ -556,6 +556,10 @@ orsf_pred_dependence <- function(object, if(type_output == 'smry') out[, id_variable := NULL] + # not needed for mort + if(pred_type == 'mort') + out[, pred_horizon := NULL] + # put data back into original scale for(j in intersect(names(means), names(pred_spec))){ diff --git a/tests/testthat/test-orsf_pd.R b/tests/testthat/test-orsf_pd.R index d82c9ffe..98103fe7 100644 --- a/tests/testthat/test-orsf_pd.R +++ b/tests/testthat/test-orsf_pd.R @@ -9,6 +9,7 @@ test_that( fit_nodat <- orsf(formula = Surv(time, status) ~ ., data = pbc_orsf, + n_tree = 1, attach_data = FALSE) expect_error( @@ -91,25 +92,60 @@ for(i in seq_along(funs)){ formals <- setdiff(names(formals(funs[[i]])), '...') - pd_object_grid <- do.call(funs[[i]], args = args_grid[formals]) - pd_object_loop <- do.call(funs[[i]], args = args_loop[formals]) - - test_that( - desc = paste('pred_spec data are returned on the original scale', - ' for orsf_', f_name, sep = ''), - code = { - expect_equal(unique(pd_object_grid$bili), 1:4) - expect_equal(unique(pd_object_loop[variable == 'bili', value]), 1:4) - } - ) - - test_that( - desc = paste(f_name, 'returns a data.table'), - code = { - expect_s3_class(pd_object_grid, 'data.table') - expect_s3_class(pd_object_loop, 'data.table') - } - ) + for(pred_type in setdiff(pred_types_surv, 'leaf')){ + + args_grid$pred_type = pred_type + args_loop$pred_type = pred_type + + pd_object_grid <- do.call(funs[[i]], args = args_grid[formals]) + pd_object_loop <- do.call(funs[[i]], args = args_loop[formals]) + + test_that( + desc = paste('pred_spec data are returned on the original scale', + ' for orsf_', f_name, sep = ''), + code = { + expect_equal(unique(pd_object_grid$bili), 1:4) + expect_equal(unique(pd_object_loop[variable == 'bili', value]), 1:4) + } + ) + + test_that( + desc = paste(f_name, 'returns a data.table'), + code = { + expect_s3_class(pd_object_grid, 'data.table') + expect_s3_class(pd_object_loop, 'data.table') + } + ) + + test_that( + desc = 'output is named correctly', + code = { + + if(f_name %in% c("ice_new", "ice_inb", "ice_oob")){ + expect_true('id_variable' %in% names(pd_object_grid)) + expect_true('id_variable' %in% names(pd_object_loop)) + expect_true('id_row' %in% names(pd_object_grid)) + expect_true('id_row' %in% names(pd_object_loop)) + } + + expect_true('variable' %in% names(pd_object_loop)) + expect_true('value' %in% names(pd_object_loop)) + + vars <- names(args_loop$pred_spec) + expect_true(all(vars %in% names(pd_object_grid))) + expect_true(all(vars %in% unique(pd_object_loop$variable))) + + if(pred_type == 'mort'){ + expect_false('pred_horizon' %in% names(pd_object_grid)) + expect_false('pred_horizon' %in% names(pd_object_loop)) + } + + } + ) + + + + } } diff --git a/vignettes/fast.Rmd b/vignettes/fast.Rmd new file mode 100644 index 00000000..fc180611 --- /dev/null +++ b/vignettes/fast.Rmd @@ -0,0 +1,144 @@ +--- +title: "Tips to speed up computation" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Tips to speed up computation} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +```{r setup} +library(aorsf) +``` + +## Go faster + +Analyses can slow to a crawl when models need hours to run. In this article you will find a few tricks to prevent this bottleneck when using `orsf()`. We'll use the `flchain` data from `survival` to demonstrate. + +```{r} + +data("flchain", package = 'survival') + +flc <- flchain +# do this to avoid orsf() throwing an error about time to event = 0 +flc <- flc[flc$futime > 0, ] +# modify names +names(flc)[names(flc) == 'futime'] <- 'time' +names(flc)[names(flc) == 'death'] <- 'status' + +``` + +Our `flc` data has `r nrow(flc)` rows and `r ncol(flc)` columns: + +```{r} +head(flc) +``` + + +## Use `orsf_control_fast()` + +This is the default `control` value for `orsf()` and its run-time compared to other approaches can be striking. For example: + + +```{r} + +time_fast <- system.time( + expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode', + control = orsf_control_fast(), n_tree = 10) +) + +time_net <- system.time( + expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode', + control = orsf_control_net(), n_tree = 10) +) + +# control_fast() is much faster +time_net['elapsed'] / time_fast['elapsed'] + +``` + +## Use `n_thread` + +The `n_thread` argument uses multi-threading to run `aorsf` functions in parallel when possible. If you know how many threads you want, e.g. you want exactly 5, just say `n_thread = 5`. If you aren't sure how many threads you have available but want to use as many as you can, say `n_thread = 0` and `aorsf` will figure out the number for you. + +```{r} + +time_1_thread <- system.time( + expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode', + n_thread = 1, n_tree = 500) +) + +time_5_thread <- system.time( + expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode', + n_thread = 5, n_tree = 500) +) + +time_auto_thread <- system.time( + expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode', + n_thread = 0, n_tree = 500) +) + +# 5 threads and auto thread are both about 3 times faster than one thread + +time_1_thread['elapsed'] / time_5_thread['elapsed'] +time_1_thread['elapsed'] / time_auto_thread['elapsed'] + +``` + +Because R is a single threaded language, multi-threading cannot be applied when `orsf()` needs to call R functions from C++, which occurs when a customized R function is used to find linear combination of variables or compute prediction accuracy. + +## Do less + +There are some defaults in `orsf()` that can be adjusted to make it run faster: + +- set `n_retry` to 0 instead of 3 (the default) + +- set `oobag_pred_type` to 'none' instead of 'surv' (the default) + +- set 'importance' to 'none' instead of 'anova' (the default) + +- increase `split_min_events`, `split_min_obs`, `leaf_min_events`, or `leaf_min_obs` to make trees stop growing sooner + +- increase `split_min_stat` to make trees stop growing sooner + +Applying these tips: + +```{r} + +time_lightweight <- system.time( + expr = orsf(flc, time+status~., na_action = 'na_impute_meanmode', + n_thread = 0, n_tree = 500, n_retry = 0, + oobag_pred_type = 'none', importance = 'none', + split_min_events = 20, leaf_min_events = 10, + split_min_stat = 10) +) + +# about two times faster than auto thread with defaults +time_auto_thread['elapsed'] / time_lightweight['elapsed'] + +``` + +While these default values do make `orsf()` run slower, they also usually make its predictions more accurate or make the fit easier to interpret. + +## Show progress + +Setting `verbose_progress = TRUE` doesn't make anything run faster, but it can help make it *feel* like things are running less slow. + +```{r} + +verbose_fit <- orsf(flc, time+status~., + na_action = 'na_impute_meanmode', + n_thread = 0, + n_tree = 500, + verbose_progress = TRUE) + +``` + +