Skip to content

Commit

Permalink
fix: Automated fixes for issue #346 (#361)
Browse files Browse the repository at this point in the history
Co-authored-by: Your Name (aider) <[email protected]>
Co-authored-by: zachmayer <[email protected]>
fix: Automated fixes for issue #346
  • Loading branch information
3 people authored Nov 30, 2024
1 parent f0218fb commit a9ba676
Show file tree
Hide file tree
Showing 5 changed files with 890 additions and 7 deletions.
8 changes: 6 additions & 2 deletions R/caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ caretList <- function(
tuneList = NULL,
metric = NULL,
continue_on_fail = FALSE,
trim = TRUE) {
trim = TRUE,
sort_preds = TRUE) {
# Checks
if (is.null(tuneList) && is.null(methodList)) {
stop("Please either define a methodList or tuneList", call. = FALSE)
Expand Down Expand Up @@ -79,7 +80,10 @@ caretList <- function(
global_args[["metric"]] <- metric

# Loop through the tuneLists and fit caret models with those specs
modelList <- lapply(tuneList, caretTrain, global_args = global_args, continue_on_fail = continue_on_fail, trim = trim)
modelList <- lapply(tuneList, caretTrain,
global_args = global_args,
continue_on_fail = continue_on_fail, trim = trim, sort_preds = sort_preds
)
names(modelList) <- names(tuneList)
nulls <- vapply(modelList, is.null, logical(1L))
modelList <- modelList[!nulls]
Expand Down
13 changes: 8 additions & 5 deletions R/caretPredict.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ caretPredict <- function(object, newdata = NULL, excluded_class_id = 1L, ...) {
#' If `TRUE`, the function will remove some elements that are not needed from the output model.
#' @return The output of the `train` function.
#' @keywords internal
caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim = TRUE) {
caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim = TRUE, sort_preds = TRUE) {
# Combine args
# I think my handling here is correct (update globals with locals, which allows locals be partial)
# but it would be nice to have some tests
Expand All @@ -100,7 +100,7 @@ caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim =

# Only save stacked predictions for the best model
if ("pred" %in% names(model)) {
model[["pred"]] <- extractBestPreds(model)
model[["pred"]] <- extractBestPreds(model, sort_preds = sort_preds)
}

if (trim) {
Expand Down Expand Up @@ -147,9 +147,10 @@ aggregate_mean_or_first <- function(x) {
#' @title Extract the best predictions from a train object
#' @description Extract the best predictions from a train object.
#' @param x a train object
#' @param sort_preds logical, should predictions be sorted by rowIndex. Default TRUE.
#' @return a data.table::data.table with predictions
#' @keywords internal
extractBestPreds <- function(x) {
extractBestPreds <- function(x, sort_preds = TRUE) {
stopifnot(methods::is(x, "train"))
if (is.null(x$pred)) {
stop("No predictions saved during training. Please set savePredictions = 'final' in trainControl", call. = FALSE)
Expand All @@ -173,8 +174,10 @@ extractBestPreds <- function(x) {
data.table::setkeyv(pred, keys)
pred <- pred[, lapply(.SD, aggregate_mean_or_first), by = keys]

# Order results consistently
data.table::setorderv(pred, keys)
# Order results consistently if requested
if (sort_preds) {
data.table::setorderv(pred, keys)
}

# Return
pred
Expand Down
Loading

0 comments on commit a9ba676

Please sign in to comment.