Skip to content

Commit

Permalink
new vignette
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 9, 2023
1 parent e883301 commit f3b6ae5
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 26 deletions.
12 changes: 5 additions & 7 deletions R/check.R
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,8 @@ check_pd_inputs <- function(object,
new_data = new_data,
pred_horizon = pred_horizon,
pred_type = pred_type,
na_action = na_action)
na_action = na_action,
valid_pred_types = c("risk", "surv", "chf", "mort"))

}

Expand Down Expand Up @@ -1564,7 +1565,8 @@ check_predict <- function(object,
pred_horizon = NULL,
pred_type = NULL,
na_action = NULL,
boundary_checks = TRUE){
boundary_checks = TRUE,
valid_pred_types = c("risk", "surv", "chf", "mort", "leaf")){

if(!is.null(new_data)){

Expand Down Expand Up @@ -1630,11 +1632,7 @@ check_predict <- function(object,

check_arg_is_valid(arg_value = pred_type,
arg_name = 'pred_type',
valid_options = c("risk",
"surv",
"chf",
"mort",
"leaf"))
valid_options = valid_pred_types)

}

Expand Down
4 changes: 4 additions & 0 deletions R/orsf_pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,10 @@ orsf_pred_dependence <- function(object,
if(type_output == 'smry')
out[, id_variable := NULL]

# not needed for mort
if(pred_type == 'mort')
out[, pred_horizon := NULL]

# put data back into original scale
for(j in intersect(names(means), names(pred_spec))){

Expand Down
74 changes: 55 additions & 19 deletions tests/testthat/test-orsf_pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ test_that(

fit_nodat <- orsf(formula = Surv(time, status) ~ .,
data = pbc_orsf,
n_tree = 1,
attach_data = FALSE)

expect_error(
Expand Down Expand Up @@ -91,25 +92,60 @@ for(i in seq_along(funs)){

formals <- setdiff(names(formals(funs[[i]])), '...')

pd_object_grid <- do.call(funs[[i]], args = args_grid[formals])
pd_object_loop <- do.call(funs[[i]], args = args_loop[formals])

test_that(
desc = paste('pred_spec data are returned on the original scale',
' for orsf_', f_name, sep = ''),
code = {
expect_equal(unique(pd_object_grid$bili), 1:4)
expect_equal(unique(pd_object_loop[variable == 'bili', value]), 1:4)
}
)

test_that(
desc = paste(f_name, 'returns a data.table'),
code = {
expect_s3_class(pd_object_grid, 'data.table')
expect_s3_class(pd_object_loop, 'data.table')
}
)
# Run the partial dependence function under test once per prediction type
# and check the scale, class, and column names of what it returns.
# 'leaf' is left out of the loop.
for(pred_type in setdiff(pred_types_surv, 'leaf')){

  args_grid$pred_type <- pred_type
  args_loop$pred_type <- pred_type

  pd_object_grid <- do.call(funs[[i]], args = args_grid[formals])
  pd_object_loop <- do.call(funs[[i]], args = args_loop[formals])

  # Label appended to every test description below. Without it, each
  # iteration would reuse the same description and a failure could not
  # be traced back to a specific function / pred_type combination.
  test_case <- paste0('orsf_', f_name, ' (pred_type = ', pred_type, ')')

  test_that(
    desc = paste('pred_spec data are returned on the original scale for',
                 test_case),
    code = {
      expect_equal(unique(pd_object_grid$bili), 1:4)
      expect_equal(unique(pd_object_loop[variable == 'bili', value]), 1:4)
    }
  )

  test_that(
    desc = paste(test_case, 'returns a data.table'),
    code = {
      expect_s3_class(pd_object_grid, 'data.table')
      expect_s3_class(pd_object_loop, 'data.table')
    }
  )

  test_that(
    desc = paste('output is named correctly for', test_case),
    code = {

      # id columns are only checked for the ice_* functions
      if(f_name %in% c("ice_new", "ice_inb", "ice_oob")){
        expect_true('id_variable' %in% names(pd_object_grid))
        expect_true('id_variable' %in% names(pd_object_loop))
        expect_true('id_row' %in% names(pd_object_grid))
        expect_true('id_row' %in% names(pd_object_loop))
      }

      # loop output is long: one 'variable' / 'value' pair per row
      expect_true('variable' %in% names(pd_object_loop))
      expect_true('value' %in% names(pd_object_loop))

      # grid output is wide: one column per variable in pred_spec
      vars <- names(args_loop$pred_spec)
      expect_true(all(vars %in% names(pd_object_grid)))
      expect_true(all(vars %in% unique(pd_object_loop$variable)))

      # pred_horizon must not appear in the output when pred_type is 'mort'
      if(pred_type == 'mort'){
        expect_false('pred_horizon' %in% names(pd_object_grid))
        expect_false('pred_horizon' %in% names(pd_object_loop))
      }

    }
  )

}


}
Expand Down
144 changes: 144 additions & 0 deletions vignettes/fast.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
---
title: "Tips to speed up computation"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{Tips to speed up computation}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>"
)
```

```{r setup}
library(aorsf)
```

## Go faster

Analyses can slow to a crawl when models need hours to run. In this article you will find a few tricks to prevent this bottleneck when using `orsf()`. We'll use the `flchain` data from `survival` to demonstrate.

```{r}
# load the flchain data that ships with the survival package
data("flchain", package = 'survival')
# keep only rows with a positive follow-up time; this avoids orsf()
# throwing an error about time to event = 0
flc <- flchain[flchain$futime > 0, ]
# modify names: futime becomes time, death becomes status
names(flc) <- replace(names(flc), names(flc) == 'futime', 'time')
names(flc) <- replace(names(flc), names(flc) == 'death', 'status')
```

Our `flc` data has `r nrow(flc)` rows and `r ncol(flc)` columns:

```{r}
head(flc)
```


## Use `orsf_control_fast()`

`orsf_control_fast()` is the default `control` value for `orsf()`, and the difference between its run-time and that of other approaches can be striking. For example:


```{r}
# time a 10-tree fit that uses orsf_control_fast()
time_fast <- system.time(
  expr = orsf(
    data = flc,
    formula = time + status ~ .,
    na_action = 'na_impute_meanmode',
    control = orsf_control_fast(),
    n_tree = 10
  )
)
# time the same fit with orsf_control_net() instead
time_net <- system.time(
  expr = orsf(
    data = flc,
    formula = time + status ~ .,
    na_action = 'na_impute_meanmode',
    control = orsf_control_net(),
    n_tree = 10
  )
)
# ratio of elapsed times; orsf_control_fast() is much faster
time_net['elapsed'] / time_fast['elapsed']
```

## Use `n_thread`

The `n_thread` argument uses multi-threading to run `aorsf` functions in parallel when possible. If you know how many threads you want, e.g. you want exactly 5, just say `n_thread = 5`. If you aren't sure how many threads you have available but want to use as many as you can, say `n_thread = 0` and `aorsf` will figure out the number for you.

```{r}
# single thread
time_1_thread <- system.time(
  expr = orsf(
    data = flc,
    formula = time + status ~ .,
    na_action = 'na_impute_meanmode',
    n_thread = 1,
    n_tree = 500
  )
)
# exactly five threads
time_5_thread <- system.time(
  expr = orsf(
    data = flc,
    formula = time + status ~ .,
    na_action = 'na_impute_meanmode',
    n_thread = 5,
    n_tree = 500
  )
)
# n_thread = 0 lets aorsf figure out the number of threads
time_auto_thread <- system.time(
  expr = orsf(
    data = flc,
    formula = time + status ~ .,
    na_action = 'na_impute_meanmode',
    n_thread = 0,
    n_tree = 500
  )
)
# 5 threads and auto thread are both about 3 times faster than one thread
time_1_thread['elapsed'] / time_5_thread['elapsed']
time_1_thread['elapsed'] / time_auto_thread['elapsed']
```

Because R is a single threaded language, multi-threading cannot be applied when `orsf()` needs to call R functions from C++, which occurs when a customized R function is used to find linear combinations of variables or compute prediction accuracy.

## Do less

There are some defaults in `orsf()` that can be adjusted to make it run faster:

- set `n_retry` to 0 instead of 3 (the default)

- set `oobag_pred_type` to 'none' instead of 'surv' (the default)

- set `importance` to 'none' instead of 'anova' (the default)

- increase `split_min_events`, `split_min_obs`, `leaf_min_events`, or `leaf_min_obs` to make trees stop growing sooner

- increase `split_min_stat` to make trees stop growing sooner

Applying these tips:

```{r}
# same forest size as the auto-thread fit, with the optional work
# switched off and stricter stopping rules for tree growth
time_lightweight <- system.time(
  expr = orsf(
    data = flc,
    formula = time + status ~ .,
    na_action = 'na_impute_meanmode',
    n_thread = 0,
    n_tree = 500,
    n_retry = 0,
    oobag_pred_type = 'none',
    importance = 'none',
    split_min_events = 20,
    leaf_min_events = 10,
    split_min_stat = 10
  )
)
# about two times faster than auto thread with defaults
time_auto_thread['elapsed'] / time_lightweight['elapsed']
```

While these default values do make `orsf()` run slower, they also usually make its predictions more accurate or make the fit easier to interpret.

## Show progress

Setting `verbose_progress = TRUE` doesn't make anything run faster, but it can help make it *feel* like things are running less slow.

```{r}
# fit with verbose_progress switched on
verbose_fit <- orsf(
  data = flc,
  formula = time + status ~ .,
  na_action = 'na_impute_meanmode',
  n_thread = 0,
  n_tree = 500,
  verbose_progress = TRUE
)
```


0 comments on commit f3b6ae5

Please sign in to comment.