Skip to content

Commit

Permalink
handling nans for PD when n_tree is small
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 7, 2023
1 parent 5bb7376 commit 386f66e
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 42 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: aorsf
Title: Accelerated Oblique Random Survival Forests
Version: 0.0.7.9000
Version: 0.1.0
Authors@R: c(
person(given = "Byron",
family = "Jaeger",
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

* Re-worked internal C++ routines following the design of `ranger`.

* re-worked how progress is printed to console when `verbose_progress` is `TRUE`, following the design of `ranger`. Messages now indicate the action being taken, the % complete, and the approximate time until finishing the action.
* Re-worked how progress is printed to console when `verbose_progress` is `TRUE`, following the design of `ranger`. Messages now indicate the action being taken, the % complete, and the approximate time until finishing the action.

* Allowed multi-threading to be performed in `orsf()`, `predict.orsf_fit()`, and functions in the `orsf_vi()` and `orsf_pd()` family.

Expand Down
15 changes: 14 additions & 1 deletion R/check.R
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,8 @@ check_orsf_inputs <- function(data = NULL,
oobag_eval_every = NULL,
importance = NULL,
tree_seeds = NULL,
attach_data = NULL){
attach_data = NULL,
verbose_progress = NULL){

if(!is.null(data)){

Expand Down Expand Up @@ -1092,6 +1093,18 @@ check_orsf_inputs <- function(data = NULL,

}

if(!is.null(verbose_progress)){

check_arg_type(arg_value = verbose_progress,
arg_name = 'verbose_progress',
expected_type = 'logical')

check_arg_length(arg_value = verbose_progress,
arg_name = 'verbose_progress',
expected_length = 1)

}

}

#' Check inputs for orsf_pd()
Expand Down
3 changes: 2 additions & 1 deletion R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,8 @@ orsf <- function(data,
oobag_eval_every = oobag_eval_every,
importance = importance,
tree_seeds = tree_seeds,
attach_data = attach_data
attach_data = attach_data,
verbose_progress = verbose_progress
)

#TODO: more polish
Expand Down
3 changes: 3 additions & 0 deletions R/orsf_pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,9 @@ orsf_pred_dependence <- function(object,

out <- rbindlist(pd_vals)

# missings may occur when oobag=TRUE and n_tree is small
out <- collapse::na_omit(out)

ids <- c('id_variable', if(type_output == 'ice') 'id_row')

mid <- setdiff(names(out), c(ids, 'mean', prob_labels, 'pred'))
Expand Down
14 changes: 10 additions & 4 deletions R/orsf_summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,15 @@ print.orsf_summary_uni <- function(x, n_variables = NULL, ...){

if(is.null(n_variables)) n_variables <- length(unique(x$dt$variable))

risk_or_surv <- if(x$pred_type == 'risk') "risk" else "survival"
pred_label <- switch(
x$pred_type,
'risk' = 'Risk',
'surv' = 'Survival',
'chf' = 'Cumulative hazard',
'mort' = 'Mortality',
)

msg_btm <- paste("Predicted", risk_or_surv,
msg_btm <- paste("Predicted", tolower(pred_label),
"at time t =", x$pred_horizon,
"for top", n_variables,
"predictors")
Expand Down Expand Up @@ -324,15 +330,15 @@ print.orsf_summary_uni <- function(x, n_variables = NULL, ...){
banner_value_length <- banner_value_length + 1

header_length <-
(banner_input_length - banner_value_length - nchar(risk_or_surv)) / 2
(banner_input_length - banner_value_length - nchar(pred_label)) / 2

header_length <- header_length - 1.5

header_row <- paste(
paste(rep(" ", times = banner_value_length), collapse = ''),
paste(c("|",rep("-", times = header_length)), collapse = ''),
" ",
risk_or_surv,
pred_label,
" ",
paste(c(rep("-", times = header_length), "|"), collapse = ''),
collapse = '',
Expand Down
11 changes: 10 additions & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,21 @@ The importance of individual variables can be estimated in three ways using `aor
```

You can supply your own R function to estimate out-of-bag error when using negation or permutation importance. This feature is experimental and may be changed in the future (see [oob vignette](https://docs.ropensci.org/aorsf/articles/oobag.html))
You can supply your own R function to estimate out-of-bag error when using negation or permutation importance (see [oob vignette](https://docs.ropensci.org/aorsf/articles/oobag.html))

### Partial dependence (PD)

`r aorsf:::roxy_pd_explain()`

The summary function, `orsf_summarize_uni()`, computes PD for as many variables as you ask it to, using sensible values.

```{r}
orsf_summarize_uni(fit, n_variables = 2)
```


For more on PD, see the [vignette](https://docs.ropensci.org/aorsf/articles/pd.html)

### Individual conditional expectations (ICE)
Expand Down
68 changes: 48 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,14 @@ using `aorsf`:
``` r

orsf_vi_negate(fit)
#> bili age sex ast ascites
#> 0.0959635932 0.0162247725 0.0136525524 0.0085081124 0.0059358924
#> edema stage copper hepato chol
#> 0.0051286110 0.0019786308 0.0015829046 0.0007914523 -0.0003957262
#> alk.phos albumin spiders trt platelet
#> -0.0021764939 -0.0023743569 -0.0043529877 -0.0045508508 -0.0059358924
#> bili sex copper ast age
#> 0.1190578208 0.0619364315 0.0290605798 0.0260108174 0.0251162396
#> stage protime edema ascites hepato
#> 0.0237810058 0.0158443269 0.0117270641 0.0105685230 0.0092028195
#> albumin chol trt alk.phos spiders
#> 0.0082647861 0.0041510636 0.0036548364 0.0010239241 -0.0003298163
#> trig platelet
#> -0.0011111508 -0.0045314656
```

- **permutation**: Each variable is assessed separately by randomly
Expand All @@ -169,12 +171,14 @@ using `aorsf`:
``` r

orsf_vi_permute(fit)
#> bili ascites sex age edema
#> 0.0096952909 0.0073209339 0.0067273447 0.0065294816 0.0037989711
#> albumin stage protime hepato chol
#> 0.0031658093 0.0029679462 0.0023743569 0.0019786308 0.0007914523
#> ast spiders copper trt trig
#> 0.0003957262 -0.0019786308 -0.0027700831 -0.0049465770 -0.0055401662
#> bili copper ast age sex
#> 0.0514084384 0.0170611427 0.0142227933 0.0140274813 0.0131527430
#> stage protime ascites edema albumin
#> 0.0119752045 0.0102865556 0.0098067817 0.0081730899 0.0080568255
#> hepato chol alk.phos trig spiders
#> 0.0069734562 0.0032811220 0.0015862128 0.0014909643 0.0007811902
#> trt platelet
#> -0.0007067631 -0.0022135241
```

- **analysis of variance (ANOVA)**<sup>3</sup>: A p-value is computed
Expand All @@ -190,17 +194,16 @@ using `aorsf`:
``` r

orsf_vi_anova(fit)
#> ascites bili edema sex age copper stage
#> 0.35231788 0.33216374 0.31401592 0.22045995 0.19044776 0.18155620 0.16907605
#> ast hepato albumin chol trig protime spiders
#> 0.14183124 0.13736655 0.12611012 0.11461988 0.10847044 0.10697115 0.08802817
#> alk.phos platelet trt
#> 0.07943094 0.06150342 0.04411765
#> ascites bili edema sex copper age ast
#> 0.39107612 0.36316990 0.36316238 0.24720893 0.20547180 0.19213732 0.19029233
#> albumin stage hepato trig chol protime alk.phos
#> 0.17219680 0.17068758 0.16126761 0.13379872 0.12964021 0.12659698 0.12352611
#> spiders platelet trt
#> 0.11728395 0.08997135 0.07305095
```

You can supply your own R function to estimate out-of-bag error when
using negation or permutation importance. This feature is experimental
and may be changed in the future (see [oob
using negation or permutation importance (see [oob
vignette](https://docs.ropensci.org/aorsf/articles/oobag.html))

### Partial dependence (PD)
Expand All @@ -211,6 +214,31 @@ is marginalized over the values of all other predictors, giving
something like a multivariable adjusted estimate of the model’s
prediction.

The summary function, `orsf_summarize_uni()`, computes PD for as many
variables as you ask it to, using sensible values.

``` r

orsf_summarize_uni(fit, n_variables = 2)
#>
#> -- bili (VI Rank: 1) ----------------------------
#>
#> |----------------- risk -----------------|
#> Value Mean Median 25th % 75th %
#> 0.70 0.2074286 0.09039332 0.03827337 0.3146957
#> 1.3 0.2261739 0.10784929 0.04915971 0.3425934
#> 3.2 0.3071951 0.21242141 0.11889617 0.4358309
#>
#> -- sex (VI Rank: 2) -----------------------------
#>
#> |----------------- risk -----------------|
#> Value Mean Median 25th % 75th %
#> m 0.3648659 0.2572239 0.15554270 0.5735661
#> f 0.2479179 0.1021787 0.04161796 0.3591612
#>
#> Predicted risk at time t = 1826.25 for top 2 predictors
```

For more on PD, see the
[vignette](https://docs.ropensci.org/aorsf/articles/pd.html)

Expand Down
7 changes: 6 additions & 1 deletion src/Forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,14 @@ std::vector<std::vector<arma::mat>> Forest::compute_dependence(bool oobag){

if(pd_type == PD_SUMMARY){

mat preds_summary = mean(preds, 0);
if(preds.has_nonfinite()){
uvec is_finite = find_finite(preds.col(0));
preds = preds.rows(is_finite);
}

mat preds_summary = mean(preds, 0);
mat preds_quant = quantile(preds, pd_probs, 0);

result_k.push_back(join_vert(preds_summary, preds_quant));

} else if(pd_type == PD_ICE) {
Expand Down
2 changes: 2 additions & 0 deletions tests/testthat/test-orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,8 @@ test_that(
}
)



# high pred horizon
# TODO: move this to test file for summarize
# test_that(
Expand Down
23 changes: 11 additions & 12 deletions tests/testthat/test-orsf_pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,28 @@
#' @srrstats {G5.3} *Test that fits returned contain no missing (`NA`) or undefined (`NaN`, `Inf`) values.*
#' @srrstats {G5.8, G5.8d} **Edge condition tests** * an error is thrown when partial dependence functions are asked to predict estimates outside of boundaries determined by the aorsf model's training data*

fit <- orsf(formula = Surv(time, status) ~ .,
data = pbc_orsf)

fit_nodat <- orsf(formula = Surv(time, status) ~ .,
data = pbc_orsf,
attach_data = FALSE)

test_that(
desc = "oob stops if there are no data",
code = {

fit_nodat <- orsf(formula = Surv(time, status) ~ .,
data = pbc_orsf,
attach_data = FALSE)

expect_error(
orsf_pd_oob(fit_nodat, pred_spec = list(bili = c(0.8))),
regexp = 'no data'
)
}
)

fit <- fit_standard_pbc$fast

test_that(
"user cant supply empty pred_spec",
code = {
expect_error(
orsf_ice_oob(fit,
pred_spec = list()),
orsf_ice_oob(fit, pred_spec = list()),
regexp = 'pred_spec is empty'
)
}
Expand All @@ -37,8 +36,8 @@ test_that(
expect_error(
orsf_ice_oob(fit,
pred_spec = list(bili = 1:5,
nope = c(1,2),
no_sir = 1),
nope = c(1,2),
no_sir = 1),
pred_horizon = 1000),
regexp = 'nope and no_sir'
)
Expand Down Expand Up @@ -105,6 +104,7 @@ test_that(
test_that(
'No missing values in output',
code = {

expect_false(any(is.na(pd_vals_ice)))
expect_false(any(is.nan(as.matrix(pd_vals_ice))))
expect_false(any(is.infinite(as.matrix(pd_vals_ice))))
Expand All @@ -117,7 +117,6 @@ test_that(

test_that(
'multi-valued horizon inputs are allowed',
# as a bonus, repeat the test for equality with oob
code = {

pd_smry_multi_horiz <- orsf_pd_oob(
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/test-orsf_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ test_preds_surv <- function(pred_type){


if(pred_type %in% c("risk", "surv")){

test_that(
desc = paste("predictions of type", pred_type, "are bounded"),
code = {
Expand All @@ -67,6 +68,25 @@ test_preds_surv <- function(pred_type){
}
)

}

if(pred_type == 'mort'){

test_that(
desc = "predictions are accurate",
code = {

surv_concord <- survival::concordance(
survival::Surv(time, status) ~ prd_agg,
data = pbc_test
)

mort_cstat <- 1 - surv_concord$concordance

expect_true(mort_cstat > 0.60)

}
)

}

Expand Down

0 comments on commit 386f66e

Please sign in to comment.