Skip to content

Commit

Permalink
pushing test-orsf components to specific contexts; WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 4, 2023
1 parent 5ea7b3c commit e53cfab
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 113 deletions.
17 changes: 9 additions & 8 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -826,19 +826,21 @@ orsf <- function(data,
rev(orsf_out$importance[order(orsf_out$importance), , drop=TRUE])
}

if(oobag_pred && !no_fit){

# put the oob predictions into the same order as the training data.
unsorted <- collapse::radixorder(sorted)

# makes labels for oobag evaluation type
if(oobag_pred){

orsf_out$eval_oobag$stat_type <-
switch(EXPR = as.character(orsf_out$eval_oobag$stat_type),
"0" = "None",
"1" = "Harrell's C-statistic",
"2" = "User-specified function")

if(!no_fit){

# put the oob predictions into the same order as the training data.
unsorted <- collapse::radixorder(sorted)

# makes labels for oobag evaluation type

if(oobag_pred_type == 'leaf'){
all_rows <- seq(nrow(data))
for(i in seq(n_tree)){
Expand All @@ -852,8 +854,7 @@ orsf <- function(data,

orsf_out$pred_oobag[is.nan(orsf_out$pred_oobag)] <- NA_real_



}

}

Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/helper-orsf.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@

library(survival)

# misc functions used for tests

no_miss_list <- function(l){
Expand Down Expand Up @@ -36,3 +38,15 @@ f_pca <- function(x_node, y_node, w_node) {
pca$rotation[, 2, drop = FALSE]

}

# Expect that two fits carry equal leaf summaries.
#
# x, y: objects with a `forest$leaf_summary` element (in the tests these are
#   fits returned by orsf() or orsf_train()).
# tolerance: numeric tolerance forwarded to expect_equal(). The default of
#   1e-9 is the value this helper has always used.
expect_equal_leaf_summary <- function(x, y, tolerance = 1e-9){
  expect_equal(x$forest$leaf_summary,
               y$forest$leaf_summary,
               tolerance = tolerance)
}

# Expect that two fits carry equal out-of-bag evaluation statistics.
#
# x, y: objects with an `eval_oobag$stat_values` element.
# tolerance: numeric tolerance forwarded to expect_equal(). The default of
#   1e-9 is the value this helper has always used.
expect_equal_oobag_eval <- function(x, y, tolerance = 1e-9){
  expect_equal(x$eval_oobag$stat_values,
               y$eval_oobag$stat_values,
               tolerance = tolerance)
}
10 changes: 10 additions & 0 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@


# standard fit object used to check validity of other fits

# ten seeds, matching n_tree = 10 in the fit below; the other test files pass
# these same seeds to orsf() and compare their results with fit_standard.
seeds_standard <- c(5, 20, 1000, 30, 50, 98, 22, 100, 329, 10)

# the argument name `tree_seeds` is written out in full so that the call does
# not depend on partial argument matching.
fit_standard <- orsf(pbc_orsf,
                     formula = time + status ~ . - id,
                     n_tree = 10,
                     tree_seeds = seeds_standard)
52 changes: 4 additions & 48 deletions tests/testthat/test-orsf.R
Original file line number Diff line number Diff line change
@@ -1,52 +1,7 @@

library(survival) # for Surv

#' @srrstats {G5.0} *tests use the PBC data, a standard set that has been widely studied and disseminated in other R package (e.g., survival and randomForestSRC)*

# catch bad inputs, give informative error

pbc_temp <- pbc_orsf
pbc_temp$id <- factor(pbc_temp$id)
pbc_temp$status <- pbc_temp$status+1

# should get the same forest, whether status is 1/2 or 0/1 or a surv object

pbc_surv <- Surv(pbc_temp$time, pbc_temp$status)
pbc_surv_data <- cbind(pbc_temp, surv_object=pbc_surv)

fit_surv <- orsf(pbc_surv_data,
formula = surv_object ~ . - id - time - status,
n_tree = 10,
tree_seed = 1:10)

fit_surv_untrained <- orsf(pbc_surv_data,
formula = surv_object ~ . - id - time - status,
n_tree = 10,
tree_seed = 1:10,
no_fit = TRUE)

fit_surv_trained <- orsf_train(fit_surv_untrained)

fit_12 <- orsf(pbc_temp,
formula = Surv(time, status) ~ . -id,
n_tree = 10,
tree_seeds = 1:10)

fit_01 <- orsf(pbc_orsf,
formula = time + status ~ . -id,
n_tree = 10,
tree_seeds = 1:10)


test_that(
desc = 'New status, same forest',
code = {
expect_identical(fit_12$forest, fit_01$forest)
expect_identical(fit_surv$forest, fit_01$forest)
expect_identical(fit_surv_trained$forest, fit_01$forest)
}
)

#' @srrstats {G5.0} *tests use the PBC data, a standard set that has been widely studied and disseminated in other R package (e.g., survival and randomForestSRC)*

f <- time + status ~ . - id

Expand All @@ -62,8 +17,9 @@ test_that(
expect_error(orsf(pbc_orsf, f, attachData = TRUE), 'attach_data?')
expect_error(orsf(pbc_orsf, f, Control = 0), 'control?')

pbc_temp$date_var <- Sys.Date()
expect_error(orsf(pbc_temp, f), 'unsupported type')
pbc_orsf$date_var <- Sys.Date()
expect_error(orsf(pbc_orsf, f), 'unsupported type')
pbc_orsf$date_var <- NULL

}
)
Expand Down
147 changes: 90 additions & 57 deletions tests/testthat/test-orsf_formula.R
Original file line number Diff line number Diff line change
@@ -1,48 +1,45 @@

# library() stops with an error if survival cannot be loaded, whereas
# require() would only return FALSE and let later Surv() calls fail.
library(survival)

# set id to a factor so that it can trigger the id error
pbc_orsf$id <- factor(pbc_orsf$id)
pbc_orsf$status <- pbc_orsf$status+1

f1 <- Surv(time, status) ~ unknown_variable + bili
# dropped test - see https://github.com/mlr-org/mlr3extralearners/issues/259
# f2 <- Surv(time, status) ~ bili
f3 <- Surv(time, status) ~ bili + factor(hepato)
f4 <- Surv(time, status) ~ bili * ascites
f5 <- Surv(time, status) ~ bili + id
f6 <- Surv(time, not_right) ~ .
f7 <- Surv(not_right, status) ~ .
f8 <- Surv(start, time, status) ~ .
f9 <- Surv(status, time) ~ . - id
f10 <- Surv(time, time) ~ . - id
f11 <- Surv(time, id) ~ . -id
f12 <- Surv(time, status) ~ . -id
f13 <- ~ .
f14 <- status + time ~ . - id
f15 <- time + status ~ id + bili

#' @srrstats {G5.2} *Appropriate error behaviour is explicitly demonstrated through tests.*
#' @srrstats {G5.2b} *Tests demonstrate conditions which trigger error messages.*
test_that(
desc = 'formula inputs are vetted',
code = {

expect_error(orsf(pbc_orsf, f1), 'not found in data')
# # dropped - see https://github.com/mlr-org/mlr3extralearners/issues/259
# expect_warning(orsf(pbc_orsf, f2), 'at least 2 predictors')
expect_error(orsf(pbc_orsf, f3), 'unrecognized')
expect_error(orsf(pbc_orsf, f4), 'unrecognized')
expect_error(orsf(pbc_orsf, f5), 'id variable?')
expect_error(orsf(pbc_orsf, f6), 'not_right')
expect_error(orsf(pbc_orsf, f7), 'not_right')
expect_error(orsf(pbc_orsf, f8), 'must have two variables')
expect_error(orsf(pbc_orsf, f9), 'Did you enter')
expect_error(orsf(pbc_orsf, f10), 'must have two variables')
expect_error(orsf(pbc_orsf, f11), 'detected >1 event type')
expect_error(orsf(pbc_orsf, f13), 'must be two sided')
expect_error(orsf(pbc_orsf, f14), 'Did you enter')
expect_error(orsf(pbc_orsf, f15), "as many levels as there are rows")
# set id to a factor so that it can trigger the id error
pbc_orsf$id <- factor(pbc_orsf$id)

expect_error(orsf(pbc_orsf, Surv(time, status) ~ unknown_variable + bili),
'not found in data')

expect_error(orsf(pbc_orsf, Surv(time, status) ~ bili + factor(hepato)),
'unrecognized')

expect_error(orsf(pbc_orsf, Surv(time, status) ~ bili * ascites),
'unrecognized')

expect_error(orsf(pbc_orsf, Surv(time, status) ~ bili + id),
'id variable?')

expect_error(orsf(pbc_orsf, Surv(time, not_right) ~ .),
'not_right')

expect_error(orsf(pbc_orsf, Surv(not_right, status) ~ .),
'not_right')

expect_error(orsf(pbc_orsf, Surv(start, time, status) ~ .),
'must have two variables')

expect_error(orsf(pbc_orsf, Surv(time, time) ~ . - id),
'must have two variables')

expect_error(orsf(pbc_orsf, Surv(time, id) ~ . -id),
'detected >1 event type')

expect_error(orsf(pbc_orsf, ~ .), 'must be two sided')

expect_error(orsf(pbc_orsf, time + status ~ id + bili),
"as many levels as there are rows")

}
)
Expand All @@ -51,25 +48,7 @@ test_that(
desc = 'long formulas with repetition are allowed',
code = {

x_vars <- c(
"trt",
"age",
"sex",
"ascites",
"hepato",
"spiders",
"edema",
"bili",
"chol",
"albumin",
"copper",
"alk.phos",
"ast",
"trig",
"platelet",
"protime",
"stage"
)
x_vars <- c(setdiff(names(pbc_orsf), c('time', 'status', 'id')))

long_rhs <- paste(x_vars, collapse = ' + ')

Expand All @@ -79,12 +58,66 @@ test_that(

f_long <- as.formula(paste("time + status ~", long_rhs))

fit_long <- orsf(formula = f_long, pbc_orsf, n_tree = 10)
fit_long <- orsf(pbc_orsf,
formula = f_long,
n_tree = 10,
tree_seeds = seeds_standard)

# fits the orsf as expected
expect_s3_class(fit_long, 'orsf_fit')
# keeps unique names
expect_equal(x_vars, get_names_x(fit_long))
# is the same forest as standard
expect_equal_leaf_summary(fit_long, fit_standard)

}
)

# A Surv object stored as a column of the data can be used as the left-hand
# side of the formula.
test_that(
  desc = "Surv objects in formula are used correctly",
  code = {

    pbc_surv <- Surv(pbc_orsf$time, pbc_orsf$status)

    pbc_surv_data <- cbind(pbc_orsf, surv_object = pbc_surv)

    # time and status were used to build surv_object, so they are removed
    # from the right-hand side together with id. `tree_seeds` is spelled out
    # in full to avoid relying on partial argument matching.
    fit_surv <- orsf(
      pbc_surv_data,
      formula = surv_object ~ . - id - time - status,
      n_tree = 10,
      tree_seeds = seeds_standard
    )

    # name of surv object is correctly stored, values can be reproduced
    expect_equal(
      pbc_surv_data[[get_names_y(fit_surv)]],
      pbc_surv
    )

    # different formula but same as standard forest
    expect_equal_leaf_summary(fit_surv, fit_standard)
  }
)

# Shifting both status codes by a constant should leave the forest unchanged.
test_that(
  desc = "Status can be 0/1 or 1/2, or generally x/x+1",
  code = {
    # status is incremented on every pass, so the i-th iteration fits with
    # status codes shifted by i relative to their starting values.
    for(i in seq_len(5)){

      pbc_orsf$status <- pbc_orsf$status+1

      fit_status_modified <- orsf(pbc_orsf,
                                  time + status ~ . - id,
                                  n_tree = 10,
                                  tree_seeds = seeds_standard)

      expect_equal_leaf_summary(fit_status_modified, fit_standard)

      # time and status given in the wrong order should still be caught
      # after the status codes have been shifted
      expect_error(
        orsf(pbc_orsf, Surv(status, time) ~ . - id),
        'Did you enter'
      )

    }
  }
)
26 changes: 26 additions & 0 deletions tests/testthat/test-orsf_train.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@


# A fit created with no_fit = TRUE and then passed to orsf_train() should
# match the standard fit that was trained directly.
test_that(
  desc = "untrained forest acts the same as a trained one",
  code = {

    # `tree_seeds` is spelled out in full to avoid relying on partial
    # argument matching.
    fit_untrained <- orsf(pbc_orsf,
                          formula = time + status ~ . - id,
                          n_tree = 10,
                          tree_seeds = seeds_standard,
                          no_fit = TRUE)

    # no out-of-bag statistics should be present before training
    expect_true(is_empty(fit_untrained$eval_oobag$stat_values))

    expect_false(attr(fit_untrained, 'trained'))

    fit_trained <- orsf_train(fit_untrained)

    # same comparison helper as the other tests that check against
    # fit_standard
    expect_equal_leaf_summary(fit_trained, fit_standard)

  }
)

0 comments on commit e53cfab

Please sign in to comment.