From 2a371362580efb30104d7ea2b1400272ca59bb6a Mon Sep 17 00:00:00 2001 From: bjaeger Date: Tue, 10 Oct 2023 22:47:43 -0400 Subject: [PATCH] safer coxph mult and cleaner test --- R/orsf.R | 8 +- R/ref_code.R | 80 ------------ R/srr-stats-standards.R | 122 ------------------ src/Coxph.cpp | 20 +-- src/TreeSurvival.cpp | 2 +- tests/testthat/test-coxph.R | 56 ++++++++ tests/testthat/test-newtraph_cph.R | 170 ------------------------- tests/testthat/test-orsf.R | 16 --- tests/testthat/test-orsf_summary.R | 15 +++ tests/testthat/test-ostree_pred_leaf.R | 63 --------- tests/testthat/test-ref_code.R | 16 +++ 11 files changed, 100 insertions(+), 468 deletions(-) delete mode 100644 R/srr-stats-standards.R create mode 100644 tests/testthat/test-coxph.R delete mode 100644 tests/testthat/test-newtraph_cph.R delete mode 100644 tests/testthat/test-ostree_pred_leaf.R create mode 100644 tests/testthat/test-ref_code.R diff --git a/R/orsf.R b/R/orsf.R index c1d55f35..24a8df11 100644 --- a/R/orsf.R +++ b/R/orsf.R @@ -1221,19 +1221,15 @@ orsf_train_ <- function(object, if(get_oobag_pred(object)){ - # put the oob predictions into the same order as the training data. - # TODO: this can be faster; see predict unsorting - unsorted <- vector(mode = 'integer', length = length(sorted)) - for(i in seq_along(unsorted)) unsorted[ sorted[i] ] <- i - # clear labels for oobag evaluation type - object$eval_oobag$stat_type <- switch(EXPR = as.character(object$eval_oobag$stat_type), "0" = "None", "1" = "Harrell's C-statistic", "2" = "User-specified function") + # put the oob predictions into the same order as the training data. + unsorted <- collapse::radixorder(sorted) object$pred_oobag <- object$pred_oobag[unsorted, , drop = FALSE] # mortality predictions should always be 1 column diff --git a/R/ref_code.R b/R/ref_code.R index 4b5e16b5..1fd47b19 100644 --- a/R/ref_code.R +++ b/R/ref_code.R @@ -93,86 +93,6 @@ ref_code <- function (x_data, fi, names_x_data){ } -# an older version of the function above that didn't use collapse -# (its about 2 times slower) -# ref_code <- function (x_data, fi, names_x_data){ -# -# # Will use these original names to help re-order the output -# -# for(i in seq_along(fi$cols)){ -# -# if(fi$cols[i] %in% names(x_data)){ -# -# if(fi$ordr[i]){ -# -# x_data[[ fi$cols[i] ]] <- as.integer( x_data[[ fi$cols[i] ]] ) -# -# } else { -# -# # make a matrix for each factor -# mat <- matrix(0, -# nrow = nrow(x_data), -# ncol = length(fi$lvls[[i]]) -# ) -# -# colnames(mat) <- fi$keys[[i]] -# -# # missing values of the factor become missing rows -# mat[is.na(x_data[[fi$cols[i]]]), ] <- NA_integer_ -# -# # we will one-hot encode the matrix and then bind it to data, -# # replacing the original factor column. Go through the matrix -# # column by column, where each column corresponds to a level -# # of the current factor (indexed by i). Flip the values -# # of the j'th column to 1 whenever the current factor's value -# # is the j'th level. -# -# for (j in seq(ncol(mat))) { -# -# # find which rows to turn into 1's. These should be the -# # indices in the currect factor where it's value is equal -# # to the j'th level. -# hot_rows <- which( x_data[[fi$cols[i]]] == fi$lvls[[i]][j] ) -# -# # after finding the rows, flip the values from 0 to 1 -# if(!is_empty(hot_rows)){ -# mat[hot_rows , j] <- 1 -# } -# -# } -# -# # data[[fi$cols[i]]] <- NULL -# -# x_data <- cbind(x_data, mat) -# -# } -# -# } -# -# -# -# } -# -# OH_names <- names_x_data -# -# for (i in seq_along(fi$cols)){ -# -# if(fi$cols[i] %in% names_x_data){ -# if(!fi$ordr[i]){ -# OH_names <- insert_vals( -# vec = OH_names, -# where = which(fi$cols[i] == OH_names), -# what = fi$keys[[i]][-1] -# ) -# } -# } -# -# } -# -# select_cols(x_data, OH_names) -# -# } - #' insert some value(s) into a vector #' #' diff --git a/R/srr-stats-standards.R b/R/srr-stats-standards.R deleted file mode 100644 index a1c088d4..00000000 --- a/R/srr-stats-standards.R +++ /dev/null @@ -1,122 +0,0 @@ -#' srr_stats -#' -#' All of the following standards initially have `@srrstatsTODO` tags. -#' These may be moved at any time to any other locations in your code. -#' Once addressed, please modify the tag from `@srrstatsTODO` to `@srrstats`, -#' or `@srrstatsNA`, ensuring that references to every one of the following -#' standards remain somewhere within your code. -#' (These comments may be deleted at any time.) -#' -#' @srrstatsVerbose TRUE -#' - - - - - - - - - - - - - - - - - - - - - -#' @noRd -NULL - -#' NA_standards -#' -#' Any non-applicable standards can have their tags changed from `@srrstatsTODO` -#' to `@srrstatsNA`, and placed together in this block, along with explanations -#' for why each of these standards have been deemed not applicable. -#' (These comments may also be deleted at any time.) -#' -#' @srrstatsNA {G2.4c} *orsf() and other functions in aorsf generally throw errors when data types do not meet expectations. I believe this is the design I want to pursue because I think trusting users to learn how to supply inputs correctly will lead to more accurate and reproducible science than relying on my software to continuously fix mistakes in the inputs.* -#' @srrstatsNA {G2.4d} see above -#' @srrstatsNA {G2.4e} see above -#' -#' @srrstatsNA {G2.14} *I have made orsf and its associated functions throw an error when there is a missing value in the relevant data. Here is why I made this decision: (1) imputation of missing data is an involved process that many other packages have been designed to engage with. I want aorsf to be good at one thing, which is oblique random survival forests. If I try to make routines to handle missing data, I am kind of re-inventing the wheel when I could be working on things that are more relevant to the oblique random survival forest. (2) ignoring missing data would be okay from a programmatic point of view, but I have chosen not to implement this because it would not be helpful if the user was unaware of their missing data. I want orsf to perform a hard stop when it detects missing data because in many cases, junior analysts are not familiar enough with their data to know it has missing values, and perpetuating the unawareness of missing data by handling it on the back-end of analysis functions just creates downstream issues when the analysis is written up.* -#' @srrstatsNA {G2.14a} see above -#' @srrstatsNA {G2.14b} see above -#' @srrstatsNA {G2.14c} see above -#' -#' @srrstatsNA {G3.1} *there is no user-facing covariance calculation.* -#' -#' @srrstatsNA {G3.1a} *Specific covariance methods are not applied, although there are explicitly documented control methods for fitting oblique survival trees.* -#' -#' @srrstatsNA {G4.0} *outputs are not written to local files.* -#' -#' @srrstatsNA {G5.1} *No data sets are created within the package. Random values are created from a set seed, and they can be reproduced this way if needed* -#' -#' @srrstatsNA {G5.4c} *There are no tests where published paper outputs are applicable and where code from original implementations is not available* -#' -#' @srrstatsNA {G5.6b} *Parameter recovery tests are not run with multiple random seeds as neither data simulation nor the algorithms we test contain a random component.* -#' -#' @srrstatsNA {G5.10} *Extended tests are not used in this package* -#' @srrstatsNA {G5.11} *Extended tests are not used in this package* -#' @srrstatsNA {G5.11a} *Extended tests are not used in this package* -#' @srrstatsNA {G5.12} *Extended tests are not used in this package* -#' -#' @srrstatsNA {ML1.0a} *Training and testing data are not eschewed.* -#' @srrstatsNA {ML1.1a, ML1.1b} *The labels 'train' and 'test' are not used by functions in aorsf, so are not explicitly confirmed via pre-processing steps.* -#' -#' @srrstatsNA {ML1.2} *orsf() only requires training data, and predict.aorsf() only requires testing data, so I do not not apply methods to distinguish training and testing data in these functions.* -#' -#' @srrstatsNA {ML1.4} *aorsf does not provide explicit training and test data sets, only a small dataset for illustrative purposes. If testing data are needed for illustration, they are created by sub-setting.* -#' -#' @srrstatsNA {ML1.5} *aorsf is a very lightweight package that focuses on fitting, applying, and interpreting oblique random survival forests. There are numerous other R packages that summarize the contents of data sets, and their extensively developed functions are much better at summarizing training and testing data than something I would write* -#' -#' @srrstatsNA {ML1.7, ML1.7a, ML1.7b, ML1.8} *aorsf does not admit or impute missing values* -#' -#' @srrstatsNA {ML2.1} *aorsf does not use broadcasting to reconcile dimensionally incommensurate input data.* -#' @srrstatsNA {ML2.2, ML2.2a, ML2.2b} *aorsf does not require numeric transformation of input data and therefore does not have a dedicated input data specification stage. When internal transformations are performed, they are always reversed for compatibility with the original input. The strategy used by aorsf is the same as Terry Therneau's routine in the survival::coxph function. I do not intend to allow specification of target values for this transformation. I want to keep my Newton-Raphson scoring procedure as close to identical to Terry Therneau's as possible.* -#' -#' @srrstatsNA {ML3.0b} *Since aorsf doesn't have a dedicated input data specification state, the output of orsf(no_fit=TRUE) is passed directly to orsf_train().* -#' -#' @srrstatsNA {ML3.1} *aorsf allows users to print and fit untrained models. Since aorsf doesn't include a data pre-processing stage, I am not sure if there are any other helpful functions to include for an untrained random forest. For pre-trained models, users can always run saveRDS and readRDS on aorsf objects, which are just lists with attributes. Thus, additional functions for these purposes have not been added* -#' -#' @srrstatsNA {ML3.2} *aorsf does not have a dedicated input data specification step.* -#' -#' @srrstatsNA {ML3.4, ML3.4a, ML3.4b} *aorsf does not use training rates.* -#' -#' @srrstatsNA {ML3.7} *This software uses C++, facilitated through Rcpp, which does not currently allow user-controlled use of either CPUs or GPUs.* -#' -#' @srrstatsNA {ML4.1c} *The random forest trees do not depend on each other, so there is no information used to advance from one tree to the next.* -#' -#' @srrstatsNA {ML4.3} *aorsf does not use batch processing.* -#' -#' @srrstatsNA {ML4.4} *aorsf does not use batch processing.* -#' -#' @srrstatsNA {ML4.6} *aorsf does not use batch jobs.* -#' -#' @srrstatsNA {ML4.7, ML4.8, ML4.8a} *aorsf does not currently include dedicated functions for re-sampling (e.g., glmnet::cv.glmnet or xgboost::xgb.cv). Random forests generally do not need a lot of tuning and there are other R packages that are dedicated to providing robust resampling routines (e.g., rsample).* - -#' @srrstatsNA {ML5.1} *The properties and behaviours of ORSF models were explicitly compared with objects produced by other ML software in Jaeger et al, 2019 (DOI: 10.1214/19-AOAS1261). These comparisons focused on comparing model performance. I am not including comparisons such as this in the aorsf package because I want aorsf to include or suggest including as few other R packages as possible. However, I am managing a separate repository where extensive comparisons are made between aorsf, party, randomForestSRC, xgboost, and ranger. This repo is located at: https://github.com/bcjaeger/aorsf-bench and the comparisons made between aorsf and other software can be viewed here: https://bcjaeger.github.io/aorsf-bench/* -#' -#' @srrstatsNA {ML5.2c} *General functions for saving or serializing objects, such as [`saveRDS`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/readRDS.html) are appropriate for storing local copies of trained aorsf models.* -#' -#' @srrstatsNA {ML7.0} *aorsf does not have text inputs with labels of "test", "train", or "validation" data. However, aorsf does implement a function, check_arg_is_valid(), which assesses validity of text inputs based on a case-sensitive set of valid options.* -#' -#' @srrstatsNA {ML7.2} *aorsf does not impute missing data.* -#' -#' @srrstatsNA {ML7.3, ML7.3a, ML7.3b} *I am not including comparisons such as this in the aorsf package because I want aorsf to include or suggest including as few other R packages as possible. I don't want to overload the imported or suggested packages for aorsf because it becomes exponentially harder to get a package onto CRAN the more it depends on other packages. Jaeger et al, 2019 (DOI: 10.1214/19-AOAS1261) made comparisons like these formally using several ML software packages, and I plan on writing a similar paper for aorsf that will make meaningful comparisons similar to the ones I made in Jaeger et al, 2019. I am developing this paper in the following repo: https://github.com/bcjaeger/aorsf-bench. You can look directly at comparisons of computational efficiency and prediction accuracy here: https://bcjaeger.github.io/aorsf-bench/* -#' -#' @srrstatsNA {ML7.4} *aorsf does not use training rates* -#' -#' @srrstatsNA {ML7.5} *aorsf does not use training rates* -#' -#' @srrstatsNA {ML7.6} *aorsf does not use training epochs.* -#' -#' @srrstatsNA {ML7.11a} *aorsf does not implement multiple metrics and therefore cannot demonstrate relative advantages and disadvantages of different metrics. However, when verifying the accuracy of aorsf's scripts to compute certain metrics (e.g., the likelihood ratio test, and cox PH regression), aorsf tests to make sure the metrics are correctly computed over a wide range of inputs* -#' -#' @noRd -NULL diff --git a/src/Coxph.cpp b/src/Coxph.cpp index ed7c5ac6..3058f3af 100644 --- a/src/Coxph.cpp +++ b/src/Coxph.cpp @@ -459,8 +459,8 @@ break_loop = false; - XB = x_node * beta_new; - Risk = exp(XB) % w_node; + // XB = x_node * beta_new; + // Risk = exp(XB) % w_node; for( ; ; ){ @@ -475,18 +475,18 @@ n_risk++; - xb = XB.at(person); - risk = Risk.at(person); + // xb = XB.at(person); + // risk = Risk.at(person); - // xb = 0; - // - // for(i = 0; i < n_vars; i++){ - // xb += beta.at(i) * x_node.at(person, i); - // } + xb = 0; + + for(i = 0; i < n_vars; i++){ + xb += beta_new.at(i) * x_node.at(person, i); + } w_node_person = w_node.at(person); - // risk = exp(xb) * w_node_person; + risk = exp(xb) * w_node_person; if (y_node.at(person, 1) == 0) { diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp index bcf13c11..4c0ae8fc 100644 --- a/src/TreeSurvival.cpp +++ b/src/TreeSurvival.cpp @@ -348,7 +348,7 @@ leaf_data.at(0, 0) = y_node.at(person, 0); // if no events in this node: - // (TODO: should this case even occur? consider removing) + // should this case even occur? consider removing if(person == y_node.n_rows){ vec temp_surv(1, fill::ones); diff --git a/tests/testthat/test-coxph.R b/tests/testthat/test-coxph.R new file mode 100644 index 00000000..8e8f4c04 --- /dev/null +++ b/tests/testthat/test-coxph.R @@ -0,0 +1,56 @@ + +run_cph_test <- function(x, y, w, method){ + + control <- coxph.control(iter.max = 20, eps = 1e-8) + + start <- Sys.time() + + tt = survival::coxph.fit(x = x, + y = y, + strata = NULL, + offset = NULL, + init = rep(0, ncol(x)), + control = control, + weights = w, + method = if(method == 0) 'breslow' else 'efron', + rownames = NULL, + resid = FALSE, + nocenter = c(0)) + + stop <- Sys.time() + + tt_time <- stop-start + + xx <- x[, , drop = FALSE] + + start <- Sys.time() + + bcj = coxph_fit_exported(xx, + y, + w, + method = method, + cph_eps = control$eps, + cph_iter_max = control$iter.max) + + stop <- Sys.time() + + bcj_time <- stop-start + + expect_equal(as.numeric(tt$coefficients), bcj$beta, tolerance = control$eps) + + expect_equal(diag(tt$var), bcj$var, tolerance = control$eps) + + # list(bcj_time = bcj_time, tt_time = tt_time) + +} + +for(i in seq_along(mat_list_surv)){ + + x <- mat_list_surv[[i]]$x + y <- mat_list_surv[[i]]$y + w <- mat_list_surv[[i]]$w + + run_cph_test(x, Surv(y), w, method = 0) + run_cph_test(x, Surv(y), w, method = 1) + +} diff --git a/tests/testthat/test-newtraph_cph.R b/tests/testthat/test-newtraph_cph.R deleted file mode 100644 index c788ee09..00000000 --- a/tests/testthat/test-newtraph_cph.R +++ /dev/null @@ -1,170 +0,0 @@ - -#' @srrstats {G5.4} **Correctness tests** *test that statistical algorithms produce expected results to some fixed test data sets. I use the pbc data and compare the aorsf newton raphson algorithm to the same algorithm used in survival::coxph().* - -#' @srrstats {G5.4b} *Correctness tests include tests against previous implementations, explicitly calling those implementations in testing.* - -#' @srrstats {G5.0} *tests use the pbc data and the flchain data, two standard datasets in the survival package that are widely studied. The pbc data are also featured in another R package for random forests, i.e., randomForestSRC* - - -#' @srrstats {G5.6} **Parameter recovery tests** *The coxph newton-raphson algorithm returns coefficient values. The aorsf version matches those within a specified tolerance* - -#' @srrstats {ML7.7} *explicitly test optimization algorithms for accuracy. I do not test multiple optimization algorithms because this is the only one I have programmed in aorsf. The optimization algorithm used in coxnet has been thoroughly tested by the glmnet developers.* - -#' @srrstats {ML7.10} *The successful extraction of information on paths taken by optimizers is tested.* - -iter_max = 20 -control <- survival::coxph.control(iter.max = iter_max, eps = 1e-8) - -run_cph_test <- function(x, y, method){ - - wts <- sample(seq(1:5), size = nrow(x), replace = TRUE) - - tt = survival::coxph.fit(x = x, - y = y, - strata = NULL, - offset = NULL, - init = rep(0, ncol(x)), - control = control, - weights = wts, - method = if(method == 0) 'breslow' else 'efron', - rownames = NULL, - resid = FALSE, - nocenter = c(0)) - - tt_fit <- survival::coxph(y~x, - weights = wts, - control = control, - ties = if(method == 0) 'breslow' else 'efron') - - tt_inf <- summary(tt_fit)$coefficients[,'Pr(>|z|)'] - - xx <- x[, , drop = FALSE] - - bcj = coxph_fit_exported(xx, - y, - wts, - method = method, - cph_eps = 1e-8, - cph_iter_max = iter_max) - - expect_equal(as.numeric(tt$coefficients), bcj$beta, tolerance = 1e-9) - - expect_equal(diag(tt$var), bcj$var, tolerance = 1e-9) - -} - - -# pbc data ---------------------------------------------------------------- - -.pbc <- pbc_orsf[order(pbc_orsf$time), ] - -.pbc$trt <- as.numeric(.pbc$trt) -.pbc$ascites <- as.numeric(.pbc$ascites) -.pbc$hepato <- as.numeric(.pbc$hepato) - -x <- as.matrix(.pbc[, c('trt','age','ascites','hepato','bili')]) -y <- survival::Surv(.pbc$time, .pbc$status) - -#' @srrstats {G5.6a} *succeed within a defined tolerance rather than recovering exact values.* - -test_that( - desc = 'similar answers for pbc data', - code = { - run_cph_test(x, y, method = 0) - run_cph_test(x, y, method = 1) - } -) - -# flchain data ------------------------------------------------------------ - -data("flchain", package = 'survival') - -df <- na.omit(flchain) - -df$chapter <- NULL - -time <- 'futime' -status <- 'death' - -df_nomiss <- na.omit(df) - -df_sorted <- df_nomiss[order(df_nomiss[[time]]),] - -df_x <- df_sorted -df_x[[time]] <- NULL -df_x[[status]] <- NULL - -flchain_x <- model.matrix(~.-1, data = df_x) - -flchain_y <- survival::Surv(time = df_sorted[[time]], - event = df_sorted[[status]]) - -x <- flchain_x[, c('age', 'sexF','sample.yr', 'kappa', 'lambda')] -y <- flchain_y - -#' @srrstats {G5.6a} *succeed within a defined tolerance rather than recovering exact values.* - -test_that( - desc = 'similar answers for flchain data', - code = { - run_cph_test(x, y, method = 0) - run_cph_test(x, y, method = 1) - } -) - -# # speed comparison -------------------------------------------------------- -# -# data("flchain", package = 'survival') -# -# df <- na.omit(flchain) -# -# df$chapter <- NULL -# -# time <- 'futime' -# status <- 'death' -# -# df_nomiss <- na.omit(df) -# -# df_sorted <- df_nomiss[order(df_nomiss[[time]]),] -# -# df_x <- df_sorted -# df_x[[time]] <- NULL -# df_x[[status]] <- NULL -# -# flchain_x <- model.matrix(~.-1, data = df_x) -# -# flchain_y <- survival::Surv(time = df_sorted[[time]], -# event = df_sorted[[status]]) -# -# x <- flchain_x[, c('age', 'sexF','sample.yr', 'kappa', 'lambda')] -# y <- flchain_y -# -# wts <- sample(seq(1:2), size = nrow(x), replace = TRUE) -# -# method = 0 -# -# control <- survival::coxph.control(iter.max = 1, eps = 1e-8) -# -# microbenchmark::microbenchmark( -# -# tt = survival::coxph.fit(x = x, -# y = y, -# strata = NULL, -# offset = NULL, -# init = rep(0, ncol(x)), -# control = control, -# weights = wts, -# method = if(method == 0) 'breslow' else 'efron', -# rownames = NULL, -# resid = FALSE, -# nocenter = c(0)), -# -# bcj = coxph_fit_exported(x[, , drop = FALSE], -# y, -# wts, -# method = method, -# cph_eps = 1e-8, -# cph_iter_max = control$iter.max) -# -# ) - diff --git a/tests/testthat/test-orsf.R b/tests/testthat/test-orsf.R index 62a9bb64..67e57ba8 100644 --- a/tests/testthat/test-orsf.R +++ b/tests/testthat/test-orsf.R @@ -762,22 +762,6 @@ test_that( -# high pred horizon -# TODO: move this to test file for summarize -# test_that( -# desc = 'higher pred horizon is not allowed for summary', -# code = { -# -# fit_bad_oob_horizon <- orsf(time + status ~ ., data = pbc_orsf, -# oobag_pred_horizon = 7000) -# -# expect_error(orsf_summarize_uni(fit_bad_oob_horizon), -# regexp = 'prediction horizon') -# -# } -# ) - - # Similar to obliqueRSF? # suppressPackageStartupMessages({ # library(obliqueRSF) diff --git a/tests/testthat/test-orsf_summary.R b/tests/testthat/test-orsf_summary.R index ef5ef9d0..3346b22d 100644 --- a/tests/testthat/test-orsf_summary.R +++ b/tests/testthat/test-orsf_summary.R @@ -108,4 +108,19 @@ test_that("bad inputs caught", { }) +# high pred horizon +test_that( + desc = 'higher pred horizon is not allowed for summary', + code = { + + fit_bad_oob_horizon <- orsf(pbc, + time + status ~ ., + n_tree = 1, + oobag_pred_horizon = 7000) + + expect_error(orsf_summarize_uni(fit_bad_oob_horizon), + regexp = 'prediction horizon') + + } +) diff --git a/tests/testthat/test-ostree_pred_leaf.R b/tests/testthat/test-ostree_pred_leaf.R deleted file mode 100644 index 5b18fe40..00000000 --- a/tests/testthat/test-ostree_pred_leaf.R +++ /dev/null @@ -1,63 +0,0 @@ -#' -#' #' @srrstats {G5.4a} *Testing of leaf assignment implementation is done by comparing results from an initial R implementation to the C++ implementation.* -#' -#' #' @srrstats {G5.5} *Correctness tests are run with a fixed random seed* -#' set.seed(1) -#' -#' formula <- Surv(time, status) ~ . - id -#' -#' aorsf <- orsf( -#' formula = formula, -#' data = pbc_orsf, -#' n_tree = 10, -#' leaf_min_obs = 20 -#' ) -#' -#' formula_terms <- stats::terms(formula, data=pbc_orsf) -#' names_x_data <- attr(formula_terms, 'term.labels') -#' -#' fi <- fctr_info(pbc_orsf, names_x_data) -#' x <- as.matrix(ref_code(pbc_orsf, fi, names_x_data)) -#' -#' for( tr in seq(10) ){ -#' -#' tree <- aorsf$forest[[tr]] -#' leaf_assigned <- rep(0, nrow(x)) -#' -#' for (j in seq(0, ncol(tree$betas)-1 ) ) { -#' -#' jj <- j+1 -#' -#' obs_in_node <- which(leaf_assigned==j) -#' -#' if(tree$children_left[jj] != 0){ -#' -#' lc <- x[obs_in_node, (tree$col_indices[, jj] + 1)] %*% tree$betas[, jj, drop = F] -#' -#' going_left <- lc <= tree$cut_points[jj] -#' going_right <- !going_left -#' -#' leaf_assigned[obs_in_node[going_left]] <- tree$children_left[jj] -#' leaf_assigned[obs_in_node[going_right]] <- tree$children_left[jj] + 1 -#' -#' } -#' -#' } -#' -#' test_that( -#' desc = 'check pred_leaf with R script', -#' code = { -#' -#' leaves <- as.numeric(ostree_pred_leaf_testthat(tree = tree, x_pred_ = x)) -#' #' @srrstats {G5.3} *Test that objects returned contain no missing (`NA`) or undefined (`NaN`, `Inf`) values.* -#' expect_false(any(is.na(leaves))) -#' expect_false(any(is.nan(leaves))) -#' expect_false(any(is.infinite(leaves))) -#' expect_equal(leaves, as.numeric(leaf_assigned)) -#' -#' #' @srrstats {G5.6a} *In this case the results are integers, so no tolerance is used.* -#' -#' } -#' ) -#' -#' } diff --git a/tests/testthat/test-ref_code.R b/tests/testthat/test-ref_code.R new file mode 100644 index 00000000..bbdcb864 --- /dev/null +++ b/tests/testthat/test-ref_code.R @@ -0,0 +1,16 @@ + +fi <- fctr_info(data = pbc_orsf, .names = names(pbc_orsf)) + +pbc_refcoded <- ref_code(pbc_orsf, + fi, + names_x_data = c("age", "sex", "stage")) + +test_that( + desc = "reference coding names and types", + code = { + expect_named(pbc_refcoded, expected = c("age", "sex_f", "stage")) + expect_type(pbc_refcoded$stage, 'integer') + expect_type(pbc_refcoded$sex_f, 'integer') + expect_type(pbc_refcoded$age, 'double') + } +)