Skip to content

Commit

Permalink
more careful checks for time to train
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Jan 12, 2024
1 parent 8cdb854 commit 63cc203
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
4 changes: 2 additions & 2 deletions CRAN-SUBMISSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Version: 0.1.2
Date: 2024-01-11 14:20:18 UTC
SHA: 48469549668880b518502d598e4f56765bf4ef1d
Date: 2024-01-12 02:20:10 UTC
SHA: 8cdb854819420d52a14b685653b47d2a55d2c7a8
24 changes: 19 additions & 5 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,9 @@ orsf_train <- function(object, attach_data = TRUE){
#' @param object an untrained `aorsf` object
#'
#' @param n_tree_subset (*integer*) how many trees should be fit in order
#' to estimate the time needed to train `object`. The default value is 50,
#' as this usually gives a good enough approximation.
#' to estimate the time needed to train `object`. The default value is 10%
#' of the trees specified in `object`, rounded up. I.e., if `object` has
#' `n_tree` of 500, then the default value of `n_tree_subset` is 50.
#' `n_tree_subset` must be less than the `n_tree` of `object`.
#'
#' @return a [difftime] object.
#'
Expand All @@ -444,8 +445,8 @@ orsf_train <- function(object, attach_data = TRUE){
#' object <- orsf(pbc_orsf, Surv(time, status) ~ . - id,
#' n_tree = 10, no_fit = TRUE)
#'
#' # approximate the time it will take to grow 500 trees
#' time_estimated <- orsf_time_to_train(object)
#' # approximate the time it will take to grow 10 trees
#' time_estimated <- orsf_time_to_train(object, n_tree_subset=1)
#'
#' print(time_estimated)
#'
Expand All @@ -462,10 +463,23 @@ orsf_train <- function(object, attach_data = TRUE){
#' abs(time_true - time_estimated)
#'

orsf_time_to_train <- function(object, n_tree_subset = 50){
orsf_time_to_train <- function(object, n_tree_subset = NULL){

n_tree_original <- object$n_tree

if(n_tree_original == 1){
stop("Cannot estimate time to train for a forest with 1 tree.",
call. = FALSE)
}

n_tree_subset <- n_tree_subset %||% ceiling(n_tree_original * 0.10)

if (n_tree_subset >= n_tree_original){
msg <- paste0("n_tree_subset (", n_tree_subset, ") ",
"must be < n_tree_original (", n_tree_original, ").")
stop(msg, call. = FALSE)
}

time_train_start <- Sys.time()

object$train(n_tree = n_tree_subset)
Expand Down
2 changes: 2 additions & 0 deletions R/orsf_update.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
#'
#' @examples
#'
#' \dontrun{
#' # initial fit has mtry of 5
#' fit <- orsf(pbc_orsf, time + status ~ . -id)
#'
Expand All @@ -75,6 +76,7 @@
#'
#' # prevent dynamic updates by specifying inputs you want to freeze.
#' fit_newer <- orsf_update(fit_new, mtry = 2)
#' }
#'
#'
orsf_update <- function(object,
Expand Down
1 change: 1 addition & 0 deletions src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
this->data = data;
this->n_cols_total = data->n_cols_x;
this->n_rows_total = data->n_rows;

this->seed = seed;
this->mtry = mtry;
this->sample_with_replacement = sample_with_replacement;
Expand Down

0 comments on commit 63cc203

Please sign in to comment.