Skip to content

Commit

Permalink
tests for oobag predict()
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Apr 27, 2024
1 parent ed06dd3 commit 957a8e3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
12 changes: 11 additions & 1 deletion R/orsf_R6.R
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,8 @@ ObliqueForest <- R6::R6Class(
# object is restored to its original state.
if(is_error(out)) stop(out, call. = FALSE)

return(out)
# NaNs may occur with oobag = TRUE and small n_tree
coerce_nans(out, to = NA_real_)


},
Expand Down Expand Up @@ -3940,6 +3941,7 @@ ObliqueForestSurvival <- R6::R6Class(
# put the oob predictions into the same order as the training data.
unsorted <- collapse::radixorder(private$data_row_sort)
out_values <- out_values[unsorted, , drop = FALSE]

}

private$clean_pred_new(out_values)
Expand Down Expand Up @@ -4177,6 +4179,14 @@ ObliqueForestClassification <- R6::R6Class(

},

clean_pred_oobag_internal = function(){

if(self$pred_type %in% c("prob") && is.matrix(self$pred_oobag)){
colnames(self$pred_oobag) <- self$class_levels
}

},

predict_internal = function(simplify, oobag){

# resize y to have the right number of columns
Expand Down
22 changes: 22 additions & 0 deletions tests/testthat/test-orsf_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,28 @@ test_that(
}
)

test_that(
desc = "oobag predictions match predict with oobag = TRUE",
code = {

expect_equal(
fit_standard_pbc$fast$pred_oobag,
predict(fit_standard_pbc$fast, oobag = TRUE)
)

expect_equal(
fit_standard_penguin_bills$net$pred_oobag,
predict(fit_standard_penguin_bills$net, oobag = TRUE)
)

expect_equal(
fit_standard_penguin_species$custom$pred_oobag,
predict(fit_standard_penguin_species$custom, oobag = TRUE)
)

}
)

new_data <- pbc_test[1:10, ]

test_that(
Expand Down

0 comments on commit 957a8e3

Please sign in to comment.