Skip to content

Commit

Permalink
safer coxph mult and cleaner test
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 11, 2023
1 parent 056b1a8 commit 2a37136
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 468 deletions.
8 changes: 2 additions & 6 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -1221,19 +1221,15 @@ orsf_train_ <- function(object,

if(get_oobag_pred(object)){

# put the oob predictions into the same order as the training data.
# TODO: this can be faster; see predict unsorting
unsorted <- vector(mode = 'integer', length = length(sorted))
for(i in seq_along(unsorted)) unsorted[ sorted[i] ] <- i

# clear labels for oobag evaluation type

object$eval_oobag$stat_type <-
switch(EXPR = as.character(object$eval_oobag$stat_type),
"0" = "None",
"1" = "Harrell's C-statistic",
"2" = "User-specified function")

# put the oob predictions into the same order as the training data.
unsorted <- collapse::radixorder(sorted)
object$pred_oobag <- object$pred_oobag[unsorted, , drop = FALSE]

# mortality predictions should always be 1 column
Expand Down
80 changes: 0 additions & 80 deletions R/ref_code.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,86 +93,6 @@ ref_code <- function (x_data, fi, names_x_data){

}

# an older version of the function above that didn't use collapse
# (it's about 2 times slower)
# ref_code <- function (x_data, fi, names_x_data){
#
# # Will use these original names to help re-order the output
#
# for(i in seq_along(fi$cols)){
#
# if(fi$cols[i] %in% names(x_data)){
#
# if(fi$ordr[i]){
#
# x_data[[ fi$cols[i] ]] <- as.integer( x_data[[ fi$cols[i] ]] )
#
# } else {
#
# # make a matrix for each factor
# mat <- matrix(0,
# nrow = nrow(x_data),
# ncol = length(fi$lvls[[i]])
# )
#
# colnames(mat) <- fi$keys[[i]]
#
# # missing values of the factor become missing rows
# mat[is.na(x_data[[fi$cols[i]]]), ] <- NA_integer_
#
# # we will one-hot encode the matrix and then bind it to data,
# # replacing the original factor column. Go through the matrix
# # column by column, where each column corresponds to a level
# # of the current factor (indexed by i). Flip the values
# # of the j'th column to 1 whenever the current factor's value
# # is the j'th level.
#
# for (j in seq(ncol(mat))) {
#
# # find which rows to turn into 1's. These should be the
# # indices in the current factor where its value is equal
# # to the j'th level.
# hot_rows <- which( x_data[[fi$cols[i]]] == fi$lvls[[i]][j] )
#
# # after finding the rows, flip the values from 0 to 1
# if(!is_empty(hot_rows)){
# mat[hot_rows , j] <- 1
# }
#
# }
#
# # data[[fi$cols[i]]] <- NULL
#
# x_data <- cbind(x_data, mat)
#
# }
#
# }
#
#
#
# }
#
# OH_names <- names_x_data
#
# for (i in seq_along(fi$cols)){
#
# if(fi$cols[i] %in% names_x_data){
# if(!fi$ordr[i]){
# OH_names <- insert_vals(
# vec = OH_names,
# where = which(fi$cols[i] == OH_names),
# what = fi$keys[[i]][-1]
# )
# }
# }
#
# }
#
# select_cols(x_data, OH_names)
#
# }

#' insert some value(s) into a vector
#'
#'
Expand Down
122 changes: 0 additions & 122 deletions R/srr-stats-standards.R

This file was deleted.

20 changes: 10 additions & 10 deletions src/Coxph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,8 @@

break_loop = false;

XB = x_node * beta_new;
Risk = exp(XB) % w_node;
// XB = x_node * beta_new;
// Risk = exp(XB) % w_node;


for( ; ; ){
Expand All @@ -475,18 +475,18 @@

n_risk++;

xb = XB.at(person);
risk = Risk.at(person);
// xb = XB.at(person);
// risk = Risk.at(person);

// xb = 0;
//
// for(i = 0; i < n_vars; i++){
// xb += beta.at(i) * x_node.at(person, i);
// }
xb = 0;

for(i = 0; i < n_vars; i++){
xb += beta_new.at(i) * x_node.at(person, i);
}

w_node_person = w_node.at(person);

// risk = exp(xb) * w_node_person;
risk = exp(xb) * w_node_person;

if (y_node.at(person, 1) == 0) {

Expand Down
2 changes: 1 addition & 1 deletion src/TreeSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@
leaf_data.at(0, 0) = y_node.at(person, 0);

// if no events in this node:
// (TODO: should this case even occur? consider removing)
// should this case even occur? consider removing
if(person == y_node.n_rows){

vec temp_surv(1, fill::ones);
Expand Down
56 changes: 56 additions & 0 deletions tests/testthat/test-coxph.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

# Fit the same Cox model with survival::coxph.fit (the reference) and with
# coxph_fit_exported (the implementation under test), then expect the
# coefficients and the diagonal of the variance matrix to agree.
#
# x:      numeric predictor matrix; one starting value of 0 is supplied to
#         the reference fit for each column.
# y:      outcome object, passed unchanged to both fitters.
# w:      observation weights, passed unchanged to both fitters.
# method: 0 selects 'breslow' tie handling in the reference fit, any other
#         value selects 'efron'; the value itself is forwarded as-is to
#         coxph_fit_exported.
#
# Returns the result of the final expect_equal() call.
run_cph_test <- function(x, y, w, method) {

  # qualify with survival:: so the helper does not depend on survival
  # being attached to the search path
  control <- survival::coxph.control(iter.max = 20, eps = 1e-8)

  reference <- survival::coxph.fit(
    x = x,
    y = y,
    strata = NULL,
    offset = NULL,
    init = rep(0, ncol(x)),
    control = control,
    weights = w,
    method = if (method == 0) 'breslow' else 'efron',
    rownames = NULL,
    resid = FALSE,
    nocenter = c(0)
  )

  # NOTE(review): presumably taken so the compiled routine works on its own
  # copy of the predictors rather than the caller's matrix -- confirm.
  x_copy <- x[, , drop = FALSE]

  candidate <- coxph_fit_exported(
    x_copy,
    y,
    w,
    method = method,
    cph_eps = control$eps,
    cph_iter_max = control$iter.max
  )

  # both fits use the same convergence threshold, so compare at that tolerance
  expect_equal(
    as.numeric(reference$coefficients),
    candidate$beta,
    tolerance = control$eps
  )

  expect_equal(
    diag(reference$var),
    candidate$var,
    tolerance = control$eps
  )

}

# Compare coxph_fit_exported against survival::coxph.fit on every entry of
# mat_list_surv, once per tie-handling method (0 and 1). Running inside
# test_that() attributes any failed expectation to a named test and keeps
# the loop variables out of the test file's environment.
test_that("coxph_fit_exported agrees with survival::coxph.fit", {

  for (i in seq_along(mat_list_surv)) {

    x <- mat_list_surv[[i]]$x
    y <- mat_list_surv[[i]]$y
    w <- mat_list_surv[[i]]$w

    # qualify Surv so the test does not rely on survival being attached
    run_cph_test(x, survival::Surv(y), w, method = 0)
    run_cph_test(x, survival::Surv(y), w, method = 1)

  }

})
Loading

0 comments on commit 2a37136

Please sign in to comment.