From 523272790d50c8e4a024826009ede26dbb8b0576 Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Sat, 21 Oct 2023 16:07:39 -0400 Subject: [PATCH 1/5] delayed fill of pd values --- R/RcppExports.R | 4 ++++ src/Data.h | 35 ++++++++++++++++++++++++++++++ src/RcppExports.cpp | 19 ++++++++++++++++ src/orsf_oop.cpp | 19 ++++++++++++++++ tests/testthat/test-DataCpp.R | 41 +++++++++++++++++++++++++++++++++++ 5 files changed, 118 insertions(+) diff --git a/R/RcppExports.R b/R/RcppExports.R index 64d0177f..d306fd24 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -45,6 +45,10 @@ x_submat_mult_beta_exported <- function(x, y, w, x_rows, x_cols, beta) { .Call(`_aorsf_x_submat_mult_beta_exported`, x, y, w, x_rows, x_cols, beta) } +x_submat_mult_beta_pd_exported <- function(x, y, w, x_rows, x_cols, beta, pd_x_vals, pd_x_cols) { + .Call(`_aorsf_x_submat_mult_beta_pd_exported`, x, y, w, x_rows, x_cols, beta, pd_x_vals, pd_x_cols) +} + scale_x_exported <- function(x, w) { .Call(`_aorsf_scale_x_exported`, x, w) } diff --git a/src/Data.h b/src/Data.h index 9210701f..0a6fd41e 100644 --- a/src/Data.h +++ b/src/Data.h @@ -113,6 +113,41 @@ } + // multiply X matrix by lincomb coefficients + // without taking a sub-matrix of X + arma::vec x_submat_mult_beta(arma::uvec& x_rows, + arma::uvec& x_cols, + arma::vec& beta, + arma::vec& pd_x_vals, + arma::uvec& pd_x_cols){ + + arma::vec out(x_rows.size()); + arma::uword j = 0; + + for(auto col : x_cols){ + + arma::uword i = 0; + arma::uvec pd_col = find(pd_x_cols == col); + + if(pd_col.is_empty()){ + for(auto row : x_rows){ + out[i] += x.at(row, col) * beta[j]; + i++; + } + } else { + for(i = 0; i < out.size(); i++){ + out[i] += pd_x_vals[pd_col[0]] * beta[j]; + } + } + + j++; + + } + + return(out); + + } + void permute_col(arma::uword j, std::mt19937_64& rng){ arma::vec x_j = x.unsafe_col(j); diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index c4cd334c..f015c900 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -170,6 +170,24 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// x_submat_mult_beta_pd_exported +arma::vec x_submat_mult_beta_pd_exported(arma::mat& x, arma::mat& y, arma::vec& w, arma::uvec& x_rows, arma::uvec& x_cols, arma::vec& beta, arma::vec& pd_x_vals, arma::uvec& pd_x_cols); +RcppExport SEXP _aorsf_x_submat_mult_beta_pd_exported(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP x_rowsSEXP, SEXP x_colsSEXP, SEXP betaSEXP, SEXP pd_x_valsSEXP, SEXP pd_x_colsSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< arma::mat& >::type x(xSEXP); + Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP); + Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP); + Rcpp::traits::input_parameter< arma::uvec& >::type x_rows(x_rowsSEXP); + Rcpp::traits::input_parameter< arma::uvec& >::type x_cols(x_colsSEXP); + Rcpp::traits::input_parameter< arma::vec& >::type beta(betaSEXP); + Rcpp::traits::input_parameter< arma::vec& >::type pd_x_vals(pd_x_valsSEXP); + Rcpp::traits::input_parameter< arma::uvec& >::type pd_x_cols(pd_x_colsSEXP); + rcpp_result_gen = Rcpp::wrap(x_submat_mult_beta_pd_exported(x, y, w, x_rows, x_cols, beta, pd_x_vals, pd_x_cols)); + return rcpp_result_gen; +END_RCPP +} // scale_x_exported List scale_x_exported(arma::mat& x, arma::vec& w); RcppExport SEXP _aorsf_scale_x_exported(SEXP xSEXP, SEXP wSEXP) { @@ -261,6 +279,7 @@ static const R_CallMethodDef CallEntries[] = { {"_aorsf_sprout_node_survival_exported", (DL_FUNC) &_aorsf_sprout_node_survival_exported, 2}, {"_aorsf_find_rows_inbag_exported", (DL_FUNC) &_aorsf_find_rows_inbag_exported, 2}, {"_aorsf_x_submat_mult_beta_exported", (DL_FUNC) &_aorsf_x_submat_mult_beta_exported, 6}, + {"_aorsf_x_submat_mult_beta_pd_exported", (DL_FUNC) &_aorsf_x_submat_mult_beta_pd_exported, 8}, {"_aorsf_scale_x_exported", (DL_FUNC) &_aorsf_scale_x_exported, 2}, {"_aorsf_cph_scale", (DL_FUNC) &_aorsf_cph_scale, 2}, {"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 44}, diff --git a/src/orsf_oop.cpp b/src/orsf_oop.cpp index 1b7bbca4..75d103c1 100644 --- a/src/orsf_oop.cpp +++ b/src/orsf_oop.cpp @@ -212,6 +212,25 @@ } + // [[Rcpp::export]] + arma::vec x_submat_mult_beta_pd_exported(arma::mat& x, + arma::mat& y, + arma::vec& w, + arma::uvec& x_rows, + arma::uvec& x_cols, + arma::vec& beta, + arma::vec& pd_x_vals, + arma::uvec& pd_x_cols){ + + std::unique_ptr data = std::make_unique(x, y, w); + + vec out = data->x_submat_mult_beta(x_rows, x_cols, beta, + pd_x_vals, pd_x_cols); + + return(out); + + } + // [[Rcpp::export]] List scale_x_exported(arma::mat& x, arma::vec& w){ diff --git a/tests/testthat/test-DataCpp.R b/tests/testthat/test-DataCpp.R index f7748ef9..54f2ecb1 100644 --- a/tests/testthat/test-DataCpp.R +++ b/tests/testthat/test-DataCpp.R @@ -16,9 +16,50 @@ test_that( x_cols = x_cols - 1, beta = beta) + # won't be used + pd_x_vals <- matrix(1) + pd_x_cols <- ncol(pbc_mats$x) + + data_cpp_answer_2 <- x_submat_mult_beta_pd_exported(x = pbc_mats$x, + y = pbc_mats$y, + w = pbc_mats$w, + x_rows = x_rows - 1, + x_cols = x_cols - 1, + beta = beta, + pd_x_vals = pd_x_vals, + pd_x_cols = pd_x_cols) + target <- pbc_mats$x[x_rows, x_cols] %*% beta expect_equal(data_cpp_answer, target) + expect_equal(data_cpp_answer_2, target) }) +test_that( + desc = "submatrix multiplication with PD values is correct", + code = { + + pd_x_vals <- c(0, 0) + pd_x_cols <- x_cols[1:2] - 1 + + data_cpp_answer_2 <- x_submat_mult_beta_pd_exported(x = pbc_mats$x, + y = pbc_mats$y, + w = pbc_mats$w, + x_rows = x_rows - 1, + x_cols = x_cols - 1, + beta = beta, + pd_x_vals = pd_x_vals, + pd_x_cols = pd_x_cols) + + x_pd <- pbc_mats$x + x_pd[, x_cols[c(1,2)]] <- 0 + + target <- x_pd[x_rows, x_cols] %*% beta + + expect_equal(data_cpp_answer_2, target) + + + } +) + From 70a1c656f7f311c0a97b521697b3ed4edf31aa97 Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Sun, 22 Oct 2023 08:52:07 -0400 Subject: [PATCH 2/5] safety --- src/Data.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Data.h b/src/Data.h index 0a6fd41e..09bfa98e 100644 --- a/src/Data.h +++ b/src/Data.h @@ -117,10 +117,14 @@ // without taking a sub-matrix of X arma::vec x_submat_mult_beta(arma::uvec& x_rows, arma::uvec& x_cols, - arma::vec& beta, - arma::vec& pd_x_vals, + arma::vec& beta, + arma::vec& pd_x_vals, arma::uvec& pd_x_cols){ + if(pd_x_cols.is_empty()){ + return(x_submat_mult_beta(x_rows, x_cols, beta)); + } + arma::vec out(x_rows.size()); arma::uword j = 0; From e30e0988f359efedacff1a5b9e79dde6be63d111 Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Sun, 22 Oct 2023 08:52:32 -0400 Subject: [PATCH 3/5] general purpose orsf_cpp --- R/infer.R | 151 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) diff --git a/R/infer.R b/R/infer.R index 075c5bd7..1ab9781e 100644 --- a/R/infer.R +++ b/R/infer.R @@ -1,4 +1,11 @@ +#' null operator (copied from rlang) + +`%||%` <- function (x, y) { + if (is.null(x)) + y + else x +} #' helper for guessing pred_horizon input #' @@ -73,3 +80,147 @@ infer_outcome_type <- function(names_y_data, data){ stop("could not infer outcome type", call. = FALSE) } + + +infer_orsf_args <- function(x, + y, + w = rep(1, nrow(x)), + ..., + object = NULL){ + + .dots <- list(...) + + control <- .dots$control %||% + get_control(object) %||% + orsf_control_fast() + + n_tree = .dots$n_tree %||% + get_n_tree(object) %||% + 500L + + tree_type = .dots$tree_type %||% + get_tree_type(object) %||% + 'survival' + + split_rule <- .dots$split_rule %||% + get_split_rule(object) %||% + 'logrank' + + split_min_stat <- .dots$split_min_stat %||% + get_split_min_stat(object) %||% + switch(split_rule, "logrank" = 3.841459, "cstat" = 0.50) + + mtry <- .dots$mtry %||% + get_mtry(object) %||% + ceiling(sqrt(ncol(x))) + + oobag_pred_type <- get_oobag_pred_type(object) %||% "surv" + oobag_pred <- oobag_pred_type != 'none' + + + pred_horizon <- get_oobag_pred_horizon(object) %||% + if(tree_type == 'survival') stats::median(y[, 1]) else 1 + + type_oobag_eval <- get_type_oobag_eval(object) %||% + if(oobag_pred) 'cstat' else 'none' + + vi_type <- .dots$vi_type %||% get_importance(object) %||% "anova" + + pd_type <- .dots$pd_type %||% 'none' + + list( + x = x, + y = y, + w = w, + tree_type_R = switch(tree_type, + 'classification' = 1, + 'regression'= 2, + 'survival' = 3), + tree_seeds = .dots$tree_seeds %||% get_tree_seeds(object) %||% 329, + loaded_forest = object$forest %||% list(), + n_tree = n_tree, + mtry = mtry, + sample_with_replacement = .dots$sample_with_replacement %||% + get_sample_with_replacement(object) %||% + TRUE, + sample_fraction = .dots$sample_fraction %||% + get_sample_fraction(object) %||% + 0.632, + vi_type_R = switch(vi_type, + "none" = 0, + "negate" = 1, + "permute" = 2, + "anova" = 3), + vi_max_pvalue = .dots$vi_max_pvalue %||% + get_vi_max_pvalue(object) %||% + 0.01, + leaf_min_events = .dots$leaf_min_events %||% + get_leaf_min_events(object) %||% + 1, + leaf_min_obs = .dots$leaf_min_obs %||% + get_leaf_min_obs(object) %||% + 5, + split_rule_R = switch(split_rule, "logrank" = 1, "cstat" = 2), + split_min_events = .dots$split_min_event %||% + get_split_min_events(object) %||% + 5, + split_min_obs = .dots$split_min_obs %||% + get_split_min_obs(object) %||% + 10, + split_min_stat = .dots$split_min_stat %||% + get_split_min_stat(object) %||% + NA_real_, + split_max_cuts = .dots$split_max_cuts %||% + get_n_split(object) %||% + 5, + split_max_retry = .dots$split_max_retry %||% + get_n_retry(object) %||% + 3, + lincomb_R_function = control$lincomb_R_function, + lincomb_type_R = switch(control$lincomb_type, + 'glm' = 1, + 'random' = 2, + 'net' = 3, + 'custom' = 4), + lincomb_eps = control$lincomb_eps, + lincomb_iter_max = control$lincomb_iter_max, + lincomb_scale = control$lincomb_scale, + lincomb_alpha = control$lincomb_alpha, + lincomb_df_target = control$lincomb_df_target %||% mtry, + lincomb_ties_method = switch(tolower(control$lincomb_ties_method), + 'breslow' = 0, + 'efron' = 1), + pred_type_R = switch(oobag_pred_type, + "none" = 0, + "risk" = 1, + "surv" = 2, + "chf" = 3, + "mort" = 4, + "leaf" = 8), + pred_mode = .dots$pred_mode %||% FALSE, + pred_aggregate = .dots$pred_aggregate %||% oobag_pred_type != 'leaf', + pred_horizon = pred_horizon, + oobag = oobag_pred, + oobag_R_function = .dots$oobag_R_function %||% + get_f_oobag_eval(object) %||% + function(x) x, + oobag_eval_type_R = switch(type_oobag_eval, + 'none' = 0, + 'cstat' = 1, + 'user' = 2), + oobag_eval_every = .dots$oobag_eval_every %||% + get_oobag_eval_every(object) %||% + n_tree, + pd_type_R = switch(pd_type, "none" = 0L, "smry" = 1L, "ice" = 2L), + pd_x_vals = .dots$pd_x_vals %||% list(matrix(0, ncol=0, nrow=0)), + pd_x_cols = .dots$pd_x_cols %||% list(matrix(0, ncol=0, nrow=0)), + pd_probs = .dots$pd_probs %||% 0, + n_thread = .dots$n_thread %||% get_n_thread(object) %||% 1, + write_forest = .dots$write_forest %||% TRUE, + run_forest = .dots$run_forest %||% TRUE, + verbosity = .dots$verbosity %||% + get_verbose_progress(object) %||% + FALSE + ) + +} From 130e81340a784baee5b0480361e450fa1c9e9bfb Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Sun, 22 Oct 2023 08:52:49 -0400 Subject: [PATCH 4/5] better name --- tests/testthat/test-DataCpp.R | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-DataCpp.R b/tests/testthat/test-DataCpp.R index 54f2ecb1..7d6c7e14 100644 --- a/tests/testthat/test-DataCpp.R +++ b/tests/testthat/test-DataCpp.R @@ -20,7 +20,7 @@ test_that( pd_x_vals <- matrix(1) pd_x_cols <- ncol(pbc_mats$x) - data_cpp_answer_2 <- x_submat_mult_beta_pd_exported(x = pbc_mats$x, + data_cpp_pd_answer <- x_submat_mult_beta_pd_exported(x = pbc_mats$x, y = pbc_mats$y, w = pbc_mats$w, x_rows = x_rows - 1, @@ -29,10 +29,11 @@ test_that( pd_x_vals = pd_x_vals, pd_x_cols = pd_x_cols) + target <- pbc_mats$x[x_rows, x_cols] %*% beta expect_equal(data_cpp_answer, target) - expect_equal(data_cpp_answer_2, target) + expect_equal(data_cpp_pd_answer, target) }) @@ -43,7 +44,7 @@ test_that( pd_x_vals <- c(0, 0) pd_x_cols <- x_cols[1:2] - 1 - data_cpp_answer_2 <- x_submat_mult_beta_pd_exported(x = pbc_mats$x, + data_cpp_pd_answer <- x_submat_mult_beta_pd_exported(x = pbc_mats$x, y = pbc_mats$y, w = pbc_mats$w, x_rows = x_rows - 1, @@ -57,7 +58,7 @@ test_that( target <- x_pd[x_rows, x_cols] %*% beta - expect_equal(data_cpp_answer_2, target) + expect_equal(data_cpp_pd_answer, target) } From 3960e32c7c21a5256ff7b390589dc9c9cfbf7559 Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Sun, 22 Oct 2023 21:35:48 -0400 Subject: [PATCH 5/5] infer orsf args for cleaner orsf_cpp calls --- R/infer.R | 37 ++++++--- R/orsf.R | 2 +- R/orsf_attr.R | 3 +- R/orsf_pd.R | 32 +++++++ R/orsf_predict.R | 211 ++++++++++++++++++++++++++++++++++------------- 5 files changed, 213 insertions(+), 72 deletions(-) diff --git a/R/infer.R b/R/infer.R index 1ab9781e..a74a86aa 100644 --- a/R/infer.R +++ b/R/infer.R @@ -83,7 +83,7 @@ infer_outcome_type <- function(names_y_data, data){ infer_orsf_args <- function(x, - y, + y = matrix(1, ncol=2), w = rep(1, nrow(x)), ..., object = NULL){ @@ -114,17 +114,32 @@ infer_orsf_args <- function(x, get_mtry(object) %||% ceiling(sqrt(ncol(x))) - oobag_pred_type <- get_oobag_pred_type(object) %||% "surv" - oobag_pred <- oobag_pred_type != 'none' + oobag_pred_type <- .dots$pred_type %||% + get_oobag_pred_type(object) %||% + "surv" + oobag_pred <- .dots$oobag_pred %||% + get_oobag_pred(object) %||% + (oobag_pred_type != 'none') - pred_horizon <- get_oobag_pred_horizon(object) %||% + + pred_horizon <- .dots$pred_horizon %||% + get_oobag_pred_horizon(object) %||% if(tree_type == 'survival') stats::median(y[, 1]) else 1 - type_oobag_eval <- get_type_oobag_eval(object) %||% - if(oobag_pred) 'cstat' else 'none' + oobag_eval_type <- 'none' + + if(oobag_pred){ + + oobag_eval_type <- .dots$oobag_eval_type %||% + get_oobag_eval_type(object) %||% + "cstat" + + } - vi_type <- .dots$vi_type %||% get_importance(object) %||% "anova" + vi_type <- .dots$vi_type %||% + get_importance(object) %||% + "none" pd_type <- .dots$pd_type %||% 'none' @@ -136,7 +151,9 @@ infer_orsf_args <- function(x, 'classification' = 1, 'regression'= 2, 'survival' = 3), - tree_seeds = .dots$tree_seeds %||% get_tree_seeds(object) %||% 329, + tree_seeds = .dots$tree_seeds %||% + get_tree_seeds(object) %||% + 329, loaded_forest = object$forest %||% list(), n_tree = n_tree, mtry = mtry, @@ -198,13 +215,13 @@ infer_orsf_args <- function(x, "mort" = 4, "leaf" = 8), pred_mode = .dots$pred_mode %||% FALSE, - pred_aggregate = .dots$pred_aggregate %||% oobag_pred_type != 'leaf', + pred_aggregate = .dots$pred_aggregate %||% (oobag_pred_type != 'leaf'), pred_horizon = pred_horizon, oobag = oobag_pred, oobag_R_function = .dots$oobag_R_function %||% get_f_oobag_eval(object) %||% function(x) x, - oobag_eval_type_R = switch(type_oobag_eval, + oobag_eval_type_R = switch(oobag_eval_type, 'none' = 0, 'cstat' = 1, 'user' = 2), diff --git a/R/orsf.R b/R/orsf.R index 9a925bc7..42498d4f 100644 --- a/R/orsf.R +++ b/R/orsf.R @@ -1085,7 +1085,7 @@ orsf_train_ <- function(object, pred_aggregate = get_oobag_pred_type(object) != 'leaf', pred_horizon = get_oobag_pred_horizon(object), oobag = get_oobag_pred(object), - oobag_eval_type_R = switch(get_type_oobag_eval(object), + oobag_eval_type_R = switch(get_oobag_eval_type(object), 'none' = 0, 'cstat' = 1, 'user' = 2), diff --git a/R/orsf_attr.R b/R/orsf_attr.R index 200aca4a..eb43bc5f 100644 --- a/R/orsf_attr.R +++ b/R/orsf_attr.R @@ -34,7 +34,6 @@ get_modes <- function(object) attr(object, 'modes') get_standard_deviations<- function(object) attr(object, 'standard_deviations') get_n_retry <- function(object) attr(object, 'n_retry') get_f_oobag_eval <- function(object) attr(object, 'f_oobag_eval') -get_type_oobag_eval <- function(object) attr(object, 'type_oobag_eval') get_oobag_fun <- function(object) attr(object, 'oobag_fun') get_oobag_pred <- function(object) attr(object, 'oobag_pred') get_oobag_pred_type <- function(object) attr(object, 'oobag_pred_type') @@ -44,7 +43,7 @@ get_importance <- function(object) attr(object, 'importance') get_importance_values <- function(object) attr(object, 'importance_values') get_group_factors <- function(object) attr(object, 'group_factors') get_f_oobag_eval <- function(object) attr(object, 'f_oobag_eval') -get_type_oobag_eval <- function(object) attr(object, 'type_oobag_eval') +get_oobag_eval_type <- function(object) attr(object, 'type_oobag_eval') get_tree_seeds <- function(object) attr(object, 'tree_seeds') get_weights_user <- function(object) attr(object, 'weights_user') get_event_times <- function(object) attr(object, 'event_times') diff --git a/R/orsf_pd.R b/R/orsf_pd.R index 48caa608..ae836027 100644 --- a/R/orsf_pd.R +++ b/R/orsf_pd.R @@ -592,3 +592,35 @@ orsf_pred_dependence <- function(object, } +pd_list_split <- function(x_vals, x_cols){ + + x_vals_out <- x_cols_out <- vector(mode = 'list') + counter <- 1 + + for(i in seq_along(x_vals)){ + + x_vals_split <- split(x_vals[[i]], row(x_vals[[i]])) + + for(j in seq_along(x_vals_split)){ + + x_vals_out[[counter]] <- matrix(x_vals_split[[j]], + ncol = ncol(x_vals[[i]]), + nrow = 1) + colnames(x_vals_out[[counter]]) <- colnames(x_vals[[i]]) + + x_cols_out[[counter]] <- x_cols[[i]] + + counter <- counter + 1 + + } + + } + + list( + x_vals = x_vals_out, + x_cols = x_cols_out + ) + +} + + diff --git a/R/orsf_predict.R b/R/orsf_predict.R index 129f0cc4..dbef82a1 100644 --- a/R/orsf_predict.R +++ b/R/orsf_predict.R @@ -166,65 +166,20 @@ predict.orsf_fit <- function(object, x_new <- prep_x_from_orsf(object, data = new_data[cc, ]) - control <- get_control(object) - - orsf_out <- orsf_cpp(x = x_new, - y = matrix(1, ncol=2), - w = rep(1, nrow(x_new)), - tree_type_R = get_tree_type(object), - tree_seeds = get_tree_seeds(object), - loaded_forest = object$forest, - n_tree = get_n_tree(object), - mtry = get_mtry(object), - sample_with_replacement = get_sample_with_replacement(object), - sample_fraction = get_sample_fraction(object), - vi_type_R = 0, - vi_max_pvalue = get_vi_max_pvalue(object), - oobag_R_function = get_f_oobag_eval(object), - leaf_min_events = get_leaf_min_events(object), - leaf_min_obs = get_leaf_min_obs(object), - split_rule_R = switch(get_split_rule(object), - "logrank" = 1, - "cstat" = 2), - split_min_events = get_split_min_events(object), - split_min_obs = get_split_min_obs(object), - split_min_stat = get_split_min_stat(object), - split_max_cuts = get_n_split(object), - split_max_retry = get_n_retry(object), - lincomb_R_function = control$lincomb_R_function, - lincomb_type_R = switch(control$lincomb_type, - 'glm' = 1, - 'random' = 2, - 'net' = 3, - 'custom' = 4), - lincomb_eps = control$lincomb_eps, - lincomb_iter_max = control$lincomb_iter_max, - lincomb_scale = control$lincomb_scale, - lincomb_alpha = control$lincomb_alpha, - lincomb_df_target = control$lincomb_df_target, - lincomb_ties_method = switch(tolower(control$lincomb_ties_method), - 'breslow' = 0, - 'efron' = 1), - pred_type_R = switch(pred_type, - "risk" = 1, - "surv" = 2, - "chf" = 3, - "mort" = 4, - "leaf" = 8), - pred_mode = TRUE, - pred_aggregate = pred_aggregate, - pred_horizon = pred_horizon_ordered, - oobag = FALSE, - oobag_eval_type_R = 0, - oobag_eval_every = get_n_tree(object), - pd_type_R = 0, - pd_x_vals = list(matrix(0, ncol=1, nrow=1)), - pd_x_cols = list(matrix(1L, ncol=1, nrow=1)), - pd_probs = c(0), - n_thread = n_thread, - write_forest = FALSE, - run_forest = TRUE, - verbosity = as.integer(verbose_progress)) + # control <- get_control(object) + + args <- infer_orsf_args(x = x_new, + vi_type = 'none', + object = object, + pred_type = pred_type, + pred_aggregate = pred_aggregate, + pred_horizon = pred_horizon_ordered, + oobag_pred = FALSE, + pred_mode = TRUE, + write_forest = FALSE, + run_forest = TRUE) + + orsf_out <- do.call(orsf_cpp, args = args) out_values <- orsf_out$pred_new @@ -248,3 +203,141 @@ predict.orsf_fit <- function(object, } + + + + + + + + + + + +# old code in case the infer function fails me: +# args_tmp <- list(x = x_new, +# y = matrix(1, ncol=2), +# w = rep(1, nrow(x_new)), +# tree_type_R = get_tree_type(object), +# tree_seeds = get_tree_seeds(object), +# loaded_forest = object$forest, +# n_tree = get_n_tree(object), +# mtry = get_mtry(object), +# sample_with_replacement = get_sample_with_replacement(object), +# sample_fraction = get_sample_fraction(object), +# vi_type_R = 0, +# vi_max_pvalue = get_vi_max_pvalue(object), +# oobag_R_function = get_f_oobag_eval(object), +# leaf_min_events = get_leaf_min_events(object), +# leaf_min_obs = get_leaf_min_obs(object), +# split_rule_R = switch(get_split_rule(object), +# "logrank" = 1, +# "cstat" = 2), +# split_min_events = get_split_min_events(object), +# split_min_obs = get_split_min_obs(object), +# split_min_stat = get_split_min_stat(object), +# split_max_cuts = get_n_split(object), +# split_max_retry = get_n_retry(object), +# lincomb_R_function = control$lincomb_R_function, +# lincomb_type_R = switch(control$lincomb_type, +# 'glm' = 1, +# 'random' = 2, +# 'net' = 3, +# 'custom' = 4), +# lincomb_eps = control$lincomb_eps, +# lincomb_iter_max = control$lincomb_iter_max, +# lincomb_scale = control$lincomb_scale, +# lincomb_alpha = control$lincomb_alpha, +# lincomb_df_target = control$lincomb_df_target, +# lincomb_ties_method = switch(tolower(control$lincomb_ties_method), +# 'breslow' = 0, +# 'efron' = 1), +# pred_type_R = switch(pred_type, +# "risk" = 1, +# "surv" = 2, +# "chf" = 3, +# "mort" = 4, +# "leaf" = 8), +# pred_mode = TRUE, +# pred_aggregate = pred_aggregate, +# pred_horizon = pred_horizon_ordered, +# oobag = FALSE, +# oobag_eval_type_R = 0, +# oobag_eval_every = get_n_tree(object), +# pd_type_R = 0, +# pd_x_vals = list(matrix(0, ncol=1, nrow=1)), +# pd_x_cols = list(matrix(1L, ncol=1, nrow=1)), +# pd_probs = c(0), +# n_thread = n_thread, +# write_forest = FALSE, +# run_forest = TRUE, +# verbosity = as.integer(verbose_progress)) +# +# checkout <- c() +# +# for(i in names(args)){ +# print(i) +# if(!is.list(args[[i]]) && !is.function(args[[i]])){ +# if(!all(args[[i]] == args_tmp[[i]])) +# checkout <- c(checkout, i) +# } +# } +# browser() +# orsf_out <- orsf_cpp(x = x_new, +# y = matrix(1, ncol=2), +# w = rep(1, nrow(x_new)), +# tree_type_R = get_tree_type(object), +# tree_seeds = get_tree_seeds(object), +# loaded_forest = object$forest, +# n_tree = get_n_tree(object), +# mtry = get_mtry(object), +# sample_with_replacement = get_sample_with_replacement(object), +# sample_fraction = get_sample_fraction(object), +# vi_type_R = 0, +# vi_max_pvalue = get_vi_max_pvalue(object), +# oobag_R_function = get_f_oobag_eval(object), +# leaf_min_events = get_leaf_min_events(object), +# leaf_min_obs = get_leaf_min_obs(object), +# split_rule_R = switch(get_split_rule(object), +# "logrank" = 1, +# "cstat" = 2), +# split_min_events = get_split_min_events(object), +# split_min_obs = get_split_min_obs(object), +# split_min_stat = get_split_min_stat(object), +# split_max_cuts = get_n_split(object), +# split_max_retry = get_n_retry(object), +# lincomb_R_function = control$lincomb_R_function, +# lincomb_type_R = switch(control$lincomb_type, +# 'glm' = 1, +# 'random' = 2, +# 'net' = 3, +# 'custom' = 4), +# lincomb_eps = control$lincomb_eps, +# lincomb_iter_max = control$lincomb_iter_max, +# lincomb_scale = control$lincomb_scale, +# lincomb_alpha = control$lincomb_alpha, +# lincomb_df_target = control$lincomb_df_target, +# lincomb_ties_method = switch(tolower(control$lincomb_ties_method), +# 'breslow' = 0, +# 'efron' = 1), +# pred_type_R = switch(pred_type, +# "risk" = 1, +# "surv" = 2, +# "chf" = 3, +# "mort" = 4, +# "leaf" = 8), +# pred_mode = TRUE, +# pred_aggregate = pred_aggregate, +# pred_horizon = pred_horizon_ordered, +# oobag = FALSE, +# oobag_eval_type_R = 0, +# oobag_eval_every = get_n_tree(object), +# pd_type_R = 0, +# pd_x_vals = list(matrix(0, ncol=1, nrow=1)), +# pd_x_cols = list(matrix(1L, ncol=1, nrow=1)), +# pd_probs = c(0), +# n_thread = n_thread, +# write_forest = FALSE, +# run_forest = TRUE, +# verbosity = as.integer(verbose_progress)) +