From 63cc2033ecca6caa9a7ca3595a2ab7bf15bb030d Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Fri, 12 Jan 2024 08:53:22 -0500 Subject: [PATCH] more careful checks for time to train --- CRAN-SUBMISSION | 4 ++-- R/orsf.R | 24 +++++++++++++++++++----- R/orsf_update.R | 2 ++ src/Tree.cpp | 1 + 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/CRAN-SUBMISSION b/CRAN-SUBMISSION index a835f801..bba7e17f 100644 --- a/CRAN-SUBMISSION +++ b/CRAN-SUBMISSION @@ -1,3 +1,3 @@ Version: 0.1.2 -Date: 2024-01-11 14:20:18 UTC -SHA: 48469549668880b518502d598e4f56765bf4ef1d +Date: 2024-01-12 02:20:10 UTC +SHA: 8cdb854819420d52a14b685653b47d2a55d2c7a8 diff --git a/R/orsf.R b/R/orsf.R index c003f86a..4b073abf 100644 --- a/R/orsf.R +++ b/R/orsf.R @@ -431,8 +431,9 @@ orsf_train <- function(object, attach_data = TRUE){ #' @param object an untrained `aorsf` object #' #' @param n_tree_subset (*integer*) how many trees should be fit in order -#' to estimate the time needed to train `object`. The default value is 50, -#' as this usually gives a good enough approximation. +#' to estimate the time needed to train `object`. The default value is 10% +#' of the trees specified in `object`. I.e., if `object` has `n_tree` of +#' 500, then the default value `n_tree_subset` is 50. #' #' @return a [difftime] object. #' @@ -444,8 +445,8 @@ orsf_train <- function(object, attach_data = TRUE){ #' object <- orsf(pbc_orsf, Surv(time, status) ~ . - id, #' n_tree = 10, no_fit = TRUE) #' -#' # approximate the time it will take to grow 500 trees -#' time_estimated <- orsf_time_to_train(object) +#' # approximate the time it will take to grow 10 trees +#' time_estimated <- orsf_time_to_train(object, n_tree_subset=1) #' #' print(time_estimated) #' @@ -462,10 +463,23 @@ orsf_train <- function(object, attach_data = TRUE){ #' abs(time_true - time_estimated) #' -orsf_time_to_train <- function(object, n_tree_subset = 50){ +orsf_time_to_train <- function(object, n_tree_subset = NULL){ n_tree_original <- object$n_tree + if(n_tree_original == 1){ + stop("Cannot estimate time to train for a forest with 1 tree.", + call. = FALSE) + } + + n_tree_subset <- n_tree_subset %||% ceiling(n_tree_original * 0.10) + + if (n_tree_subset >= n_tree_original){ + msg <- paste0("n_tree_subset (", n_tree_subset, ")", + "must be < n_tree_original (", n_tree_original, ").") + stop(msg, call. = FALSE) + } + time_train_start <- Sys.time() object$train(n_tree = n_tree_subset) diff --git a/R/orsf_update.R b/R/orsf_update.R index 4b27a131..f135e947 100644 --- a/R/orsf_update.R +++ b/R/orsf_update.R @@ -67,6 +67,7 @@ #' #' @examples #' +#' \dontrun{ #' # initial fit has mtry of 5 #' fit <- orsf(pbc_orsf, time + status ~ . -id) #' @@ -75,6 +76,7 @@ #' #' # prevent dynamic updates by specifying inputs you want to freeze. #' fit_newer <- orsf_update(fit_new, mtry = 2) +#' } #' #' orsf_update <- function(object, diff --git a/src/Tree.cpp b/src/Tree.cpp index 4a296214..d0577695 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -121,6 +121,7 @@ this->data = data; this->n_cols_total = data->n_cols_x; this->n_rows_total = data->n_rows; + this->seed = seed; this->mtry = mtry; this->sample_with_replacement = sample_with_replacement;