Skip to content

Commit

Permalink
refactoring tests, minor bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 7, 2023
1 parent 0823fbf commit e0b4ff8
Show file tree
Hide file tree
Showing 29 changed files with 981 additions and 751 deletions.
8 changes: 6 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@ is_col_splittable_exported <- function(x, y, r, j) {
.Call(`_aorsf_is_col_splittable_exported`, x, y, r, j)
}

find_cutpoints_survival_exported <- function(y, w, lincomb, leaf_min_events, leaf_min_obs) {
.Call(`_aorsf_find_cutpoints_survival_exported`, y, w, lincomb, leaf_min_events, leaf_min_obs)
find_cuts_survival_exported <- function(y, w, lincomb, leaf_min_events, leaf_min_obs, split_rule_R) {
.Call(`_aorsf_find_cuts_survival_exported`, y, w, lincomb, leaf_min_events, leaf_min_obs, split_rule_R)
}

sprout_node_survival_exported <- function(y, w) {
.Call(`_aorsf_sprout_node_survival_exported`, y, w)
}

find_rows_inbag_exported <- function(rows_oobag, n_obs) {
.Call(`_aorsf_find_rows_inbag_exported`, rows_oobag, n_obs)
}

cph_scale <- function(x, w) {
.Call(`_aorsf_cph_scale`, x, w)
}
Expand Down
15 changes: 13 additions & 2 deletions R/infer.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,27 @@
#'
#' @noRd

infer_pred_horizon <- function(object, pred_horizon){
infer_pred_horizon <- function(object, pred_type, pred_horizon){

check_arg_is(object, 'object', 'orsf_fit')

if(pred_type %in% c("mort", "leaf")){
# value of pred_horizon does not matter for these types of prediction
pred_horizon <- 1
}

# see if it was previously specified
if(is.null(pred_horizon)) pred_horizon <- object$pred_horizon

if(is.null(pred_horizon))
# throw error if pred_type requires pred_horizon
if(is.null(pred_horizon)){

stop("pred_horizon was not specified and could not be found in object.",
call. = FALSE)

}


pred_horizon

}
25 changes: 21 additions & 4 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ orsf <- function(data,

if(oobag_pred){

# makes labels for oobag evaluation type
orsf_out$eval_oobag$stat_type <-
switch(EXPR = as.character(orsf_out$eval_oobag$stat_type),
"0" = "None",
Expand All @@ -840,8 +841,6 @@ orsf <- function(data,
# put the oob predictions into the same order as the training data.
unsorted <- collapse::radixorder(sorted)

# makes labels for oobag evaluation type

if(oobag_pred_type == 'leaf'){
all_rows <- seq(nrow(data))
for(i in seq(n_tree)){
Expand All @@ -850,11 +849,19 @@ orsf <- function(data,
}
}

#' @srrstats {G2.10} *drop = FALSE for type consistency*
orsf_out$pred_oobag <- orsf_out$pred_oobag[unsorted, , drop = FALSE]

orsf_out$pred_oobag[is.nan(orsf_out$pred_oobag)] <- NA_real_

# mortality predictions should always be 1 column
# b/c they do not depend on the prediction horizon
if(oobag_pred_type == 'mort'){
orsf_out$pred_oobag <-
orsf_out$pred_oobag[, 1L, drop = FALSE]

orsf_out$eval_oobag$stat_values <-
orsf_out$eval_oobag$stat_values[, 1, drop = FALSE]
}

}

}
Expand Down Expand Up @@ -1234,6 +1241,16 @@ orsf_train_ <- function(object,

object$pred_oobag <- object$pred_oobag[unsorted, , drop = FALSE]

# mortality predictions should always be 1 column
# b/c they do not depend on the prediction horizon
if(get_oobag_pred_type(object) == 'mort'){
object$pred_oobag <-
object$pred_oobag[, 1L, drop = FALSE]

object$eval_oobag$stat_values <-
object$eval_oobag$stat_values[, 1, drop = FALSE]
}

}

attr(object, "n_leaves_mean") <- compute_mean_leaves(orsf_out$forest)
Expand Down
2 changes: 1 addition & 1 deletion R/orsf_pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ orsf_pred_dependence <- function(object,
oobag,
type_output){

pred_horizon <- infer_pred_horizon(object, pred_horizon)
pred_horizon <- infer_pred_horizon(object, pred_type, pred_horizon)

# make a visible binding for CRAN
id_variable = NULL
Expand Down
7 changes: 3 additions & 4 deletions R/orsf_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,10 @@ predict.orsf_fit <- function(object,
warning("pred_horizon does not impact predictions",
" when pred_type is '", pred_type, "'.",
extra_text, call. = FALSE)
# avoid copies of predictions and copies of this warning.
pred_horizon <- pred_horizon[1]

}

pred_horizon <- infer_pred_horizon(object, pred_horizon)
pred_horizon <- infer_pred_horizon(object, pred_type, pred_horizon)

check_predict(object = object,
new_data = new_data,
Expand Down Expand Up @@ -167,7 +166,7 @@ predict.orsf_fit <- function(object,

}

if(is.null(pred_horizon) && pred_type != 'mort'){
if(is.null(pred_horizon) && !(pred_type %in% c('mort', 'leaf'))){
stop("pred_horizon must be specified for ",
pred_type, " predictions.", call. = FALSE)
}
Expand Down
9 changes: 4 additions & 5 deletions inst/CITATION
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ citEntry(
as.person("Matthew W. Segar"),
as.person("Ambarish Pandey"),
as.person("Nicholas M. Pajewski")),
journal = "arXiv",
year = "2022",
url = "https://arxiv.org/abs/2208.01129",
journal = "Journal of Computational and Graphical Statistics",
year = "2023",
url = "https://doi.org/10.1080/10618600.2023.2231048",
textVersion = paste(
"Jaeger BC, Welden S, Lenoir K, Speiser JL, Segar MW, Pandey A, Pajewski NM.",
"Accelerated and interpretable oblique random survival forests.",
"arXiv e-prints.",
"2022 Aug 3:arXiv-2208."
"Journal of Computational and Graphical Statistics. 2023 Aug 3:1-6."
)
)

Expand Down
10 changes: 5 additions & 5 deletions man/orsf_control_custom.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 10 additions & 10 deletions man/orsf_ice_oob.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 17 additions & 17 deletions man/orsf_pd_oob.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 43 additions & 45 deletions man/orsf_vi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit e0b4ff8

Please sign in to comment.