Skip to content

Commit

Permalink
Add greedyOptimizer class (#304)
Browse files Browse the repository at this point in the history
* maybe

* try this

* IT WORKS

* add tests

* fix tests

* tests pass

* tests-pass
  • Loading branch information
zachmayer authored Aug 8, 2024
1 parent d3af788 commit 62725d3
Show file tree
Hide file tree
Showing 9 changed files with 323 additions and 1 deletion.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ S3method(plot,caretList)
S3method(plot,caretStack)
S3method(predict,caretList)
S3method(predict,caretStack)
S3method(predict,greedyMSE)
S3method(print,caretStack)
S3method(print,greedyMSE)
S3method(print,summary.caretList)
S3method(print,summary.caretStack)
S3method(summary,caretList)
Expand All @@ -26,6 +28,7 @@ export(caretList)
export(caretModelSpec)
export(caretStack)
export(extractMetric)
export(greedyMSE)
export(is.caretList)
export(is.caretStack)
export(permutationImportance)
Expand Down
86 changes: 86 additions & 0 deletions R/greedyOpt.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#' @title Greedy optimization for MSE
#' @description Greedy optimization for minimizing the mean squared error.
#' Works for classification and regression.
#' @details Each target column of \code{Y} is fitted independently. Starting
#' from all-zero counts, every iteration tries adding one unit of weight to
#' each column of \code{X} in turn, rescales the candidate weights to sum to
#' one, and keeps the candidate with the lowest mean squared error (ties go
#' to the first column). After \code{max_iter} iterations the counts are
#' divided by their column totals, so the weights for each target are
#' non-negative and sum to one.
#' @param X A numeric matrix of features with at least one row and one
#' column. All values must be finite.
#' @param Y A numeric matrix of target values with the same number of rows as
#' \code{X} and at least one column. All values must be finite.
#' @param max_iter A single positive integer giving the number of greedy
#' iterations.
#' @return A list of class \code{greedyMSE} with components:
#' \item{model_weights}{A numeric matrix of model_weights, with one row per
#' column of \code{X} and one column per column of \code{Y}.}
#' \item{RMSE}{A numeric scalar of the root mean squared error.}
#' \item{max_iter}{An integer scalar of the maximum number of iterations.}
#' @export
greedyMSE <- function(X, Y, max_iter = 100L) {
  stopifnot(
    is.matrix(X), is.matrix(Y),
    is.numeric(X), is.numeric(Y),
    is.finite(X), is.finite(Y),
    # Zero-row input would otherwise yield NaN weights and a NaN RMSE
    nrow(X) == nrow(Y), nrow(X) >= 1L,
    ncol(X) >= 1L, ncol(Y) >= 1L,
    # stopifnot() all()-reduces vectors, so scalar-ness is checked explicitly
    is.integer(max_iter), length(max_iter) == 1L, max_iter > 0L
  )

  n_models <- ncol(X)
  model_weights <- matrix(0L, nrow = n_models, ncol = ncol(Y))
  # Column j of the identity matrix is "one extra pick of model j"
  model_update <- diag(n_models)

  for (iter in seq_len(max_iter)) {
    for (y_col in seq_len(ncol(Y))) {
      target <- Y[, y_col]
      w <- model_weights[, y_col]

      # Calculate MSE for incrementing each weight.
      # Column j of w_new holds the current counts plus one pick of model j;
      # every candidate therefore has the same total, sum(w) + 1.
      w_new <- w + model_update
      w_new <- w_new / (sum(w) + 1L)
      predictions <- X %*% w_new
      MSE <- colMeans((predictions - target)^2.0)

      # Update the best weight
      best_id <- which.min(MSE)
      model_weights[best_id, y_col] <- model_weights[best_id, y_col] + 1L
    }
  }

  # Output: normalise each target's counts column-wise so they sum to one
  model_weights <- sweep(model_weights, 2L, colSums(model_weights), "/")
  rownames(model_weights) <- colnames(X)
  colnames(model_weights) <- colnames(Y)
  RMSE <- sqrt(mean((X %*% model_weights - Y)^2.0))
  out <- list(
    model_weights = model_weights,
    RMSE = RMSE,
    max_iter = max_iter
  )
  class(out) <- "greedyMSE"
  out
}

#' @title Print method for greedyMSE
#' @description Print method for greedyMSE objects. Writes a header line, the
#' stored RMSE, and the weight matrix to the console.
#' @param x A greedyMSE object.
#' @param ... Additional arguments. Ignored.
#' @return \code{x}, invisibly, so the call can be chained or piped.
#' @export
print.greedyMSE <- function(x, ...) {
  cat("Greedy MSE\n")
  cat("RMSE: ", x$RMSE, "\n")
  cat("Weights:\n")
  print(x$model_weights)
  # Print methods conventionally hand back their argument invisibly
  invisible(x)
}

#' @title Predict method for greedyMSE
#' @description Blend the columns of \code{newdata} using the weights stored
#' in a greedyMSE object. When the model has more than one target column,
#' each row of the result is divided by its row sum.
#' @param object A greedyMSE object.
#' @param newdata A numeric matrix of new data, with one column per row of
#' the model's weight matrix.
#' @param ... Additional arguments. Ignored.
#' @return A numeric matrix of predictions.
#' @export
predict.greedyMSE <- function(object, newdata, ...) {
  stopifnot(
    is.matrix(newdata),
    is.numeric(newdata),
    is.finite(newdata),
    ncol(newdata) == nrow(object$model_weights)
  )

  blended <- newdata %*% object$model_weights

  # A single target column is returned as-is
  if (ncol(blended) <= 1L) {
    return(blended)
  }

  # Several target columns: rescale every row by its own total
  blended / rowSums(blended)
}
1 change: 0 additions & 1 deletion R/permutationImportance.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ permutationImportance <- function(
is.numeric(preds_orig),
is.finite(preds_orig)
)

# Error of shuffled variables
mae_vars <- shuffled_mae(model, newdata, preds_orig, pred_type, shuffle_idx)

Expand Down
Binary file modified coverage.rds
Binary file not shown.
1 change: 1 addition & 0 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ ggplot
ggplot2
github
glm
greedyMSE
importances
kable
knitr
Expand Down
25 changes: 25 additions & 0 deletions man/greedyMSE.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions man/predict.greedyMSE.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions man/print.greedyMSE.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

171 changes: 171 additions & 0 deletions tests/testthat/test-greedyMSE.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Build a small synthetic dataset for the tests below: uniform random
# features, coefficients rescaled so each column's absolute values sum to 1,
# and optional Gaussian noise on the targets.
create_dataset <- function(
    n_samples,
    n_features = 5L,
    n_targets = 1L,
    coef = NULL,
    noise = 0L) {
  features <- matrix(stats::runif(n_samples * n_features), nrow = n_samples)

  # Draw random coefficients when none are supplied
  if (is.null(coef)) {
    coef <- matrix(stats::runif(n_features * n_targets), nrow = n_features)
  }
  # A bare vector is treated as a single-target coefficient column
  if (is.vector(coef)) {
    coef <- matrix(coef, ncol = 1L)
  }

  # Scale every coefficient column by the sum of its absolute values
  scale_l1 <- function(v) v / sum(abs(v))
  coef <- apply(coef, 2L, scale_l1)

  jitter <- matrix(
    stats::rnorm(n_samples * n_targets, mean = 0.0, sd = noise),
    nrow = n_samples
  )
  targets <- features %*% coef + jitter

  if (n_targets > 1L) {
    # Shift so every value is positive, then make each row sum to 1
    shifted <- targets - min(targets) + 1e-6
    targets <- shifted / rowSums(shifted)
  }

  list(X = features, Y = targets, coef = coef)
}

# Create datasets for reuse
# The seed is fixed so the shared fixtures below are reproducible; the order
# of the create_dataset() calls matters because each one consumes random draws.
set.seed(42L)
N <- 100L
# Single-target data (default 5 features) with Gaussian noise of sd 0.1
regression_data <- create_dataset(N, noise = 0.1)
# Three-target data; create_dataset() rescales each row of Y to sum to 1
multi_regression_data <- create_dataset(N, n_targets = 3L)

# 0/1 indicator targets: 1 where a value exceeds the overall mean of Y
Y_binary <- matrix(as.integer(regression_data$Y > mean(regression_data$Y)), nrow = N)
Y_multi_binary <- matrix(as.integer(multi_regression_data$Y > mean(multi_regression_data$Y)), nrow = N)

# Regression with a single target column
testthat::test_that("greedyMSE works for regression", {
  X <- regression_data$X
  Y <- regression_data$Y
  model <- greedyMSE(X, Y)

  # The fitted blend should beat the spread of the target itself
  testthat::expect_lt(model$RMSE, stats::sd(Y))

  # Predictions should track the true values closely
  fitted <- predict(model, X)
  testthat::expect_gt(stats::cor(fitted, Y), 0.8)

  # The print method emits each of its section headers
  for (header in c("Greedy MSE", "RMSE", "Weights")) {
    testthat::expect_output(print(model), header)
  }
})

# Binary classification with a single 0/1 target column
testthat::test_that("greedyMSE works for binary classification", {
  X <- regression_data$X
  model <- greedyMSE(X, Y_binary)

  # Threshold the blended score at 0.5 and compare with the 0/1 labels
  predicted_positive <- predict(model, X) > 0.5
  accuracy <- mean(predicted_positive == Y_binary)
  testthat::expect_gt(accuracy, 0.7) # Must beat random guessing

  # The print method emits each of its section headers
  for (header in c("Greedy MSE", "RMSE", "Weights")) {
    testthat::expect_output(print(model), header)
  }
})

# Regression with several target columns
testthat::test_that("greedyMSE works for multiple regression", {
  X <- multi_regression_data$X
  Y <- multi_regression_data$Y
  model <- greedyMSE(X, Y)

  testthat::expect_lt(model$RMSE, 0.3)

  # Correlate each predicted column with its own target column
  per_target_cor <- diag(stats::cor(predict(model, X), Y))
  testthat::expect_gt(min(per_target_cor), 0.7) # Every target must fit well
})

# Multiclass classification with several indicator columns
testthat::test_that("greedyMSE works for multiclass classification", {
  X <- multi_regression_data$X
  model <- greedyMSE(X, Y_multi_binary)

  # The predicted class is the column with the largest score in each row;
  # the reference class is the largest column of the continuous targets
  predicted_class <- apply(predict(model, X), 1L, which.max)
  reference_class <- apply(multi_regression_data$Y, 1L, which.max)
  accuracy <- mean(predicted_class == reference_class)

  testthat::expect_gt(accuracy, 0.6) # Must beat random guessing
})

# Edge cases
testthat::test_that("greedyMSE handles edge cases", {
  # 1. Single feature: the only weight is 1 and the data are noise-free,
  # so the fit is exact
  data <- create_dataset(100L, 1L, 1L)
  testthat::expect_equal(greedyMSE(data$X, data$Y)$RMSE, 0.0, tolerance = 1e-6)

  # 2. Perfect multicollinearity: both features are the constant 1, so any
  # weights that sum to 1 predict exactly 1 and the RMSE is the noise RMS.
  # Y must be a matrix to satisfy greedyMSE's input checks.
  X <- matrix(1L, nrow = 100L, ncol = 2L)
  Y <- matrix(X[, 1L] + stats::rnorm(100L, 0.0, 0.1), ncol = 1L)
  model <- greedyMSE(X, Y)
  testthat::expect_equal(sum(model$model_weights), 1.0)
  testthat::expect_equal(model$RMSE, sqrt(mean((Y - 1.0)^2.0)), tolerance = 1e-6)

  # 3. All zero features: every prediction is 0, so the RMSE is close to the
  # spread of the target
  X <- matrix(0L, nrow = 100L, ncol = 5L)
  Y <- matrix(stats::rnorm(100L), ncol = 1L)
  model <- greedyMSE(X, Y)
  testthat::expect_equal(model$RMSE, stats::sd(Y), tolerance = 1e-2)

  # 4. Constant target
  X <- matrix(stats::runif(500L), nrow = 100L)
  Y <- matrix(rep(1L, 100L), ncol = 1L)
  model <- greedyMSE(X, Y)
  testthat::expect_equal(model$RMSE, 0.5, tolerance = 0.1)

  # 5. Very large values
  X <- matrix(stats::runif(500L, 1.0e6, 1.0e7), nrow = 100L)
  Y <- matrix(rowSums(X) + stats::rnorm(100L, 0L, 1.0e5), ncol = 1L)
  model <- greedyMSE(X, Y)
  pred <- predict(model, X)
  testthat::expect_gt(stats::cor(pred, Y), 0.4)
})

# Regression ensembling test
testthat::test_that("greedyMSE can be used for regression ensembling with GLM and rpart", {
  # rpart is only needed by this test; skip rather than error when absent
  testthat::skip_if_not_installed("rpart")

  X <- regression_data$X
  Y <- regression_data$Y

  # GLM
  glm_model <- stats::glm(Y ~ X, family = stats::gaussian())
  glm_pred <- stats::predict(glm_model, newdata = data.frame(X))

  # rpart
  rpart_model <- rpart::rpart(Y ~ X)
  rpart_pred <- stats::predict(rpart_model, newdata = data.frame(X))

  # greedyMSE
  greedy_model <- greedyMSE(X, Y)
  greedy_pred <- predict(greedy_model, X)

  # Ensemble: each base model's predictions become one feature column
  ensemble_X <- cbind(glm_pred, rpart_pred, greedy_pred)
  ensemble_model <- greedyMSE(ensemble_X, Y)

  # Check if ensemble performs better than the best individual model
  individual_rmse <- c(
    sqrt(mean((Y - glm_pred)^2L)),
    sqrt(mean((Y - rpart_pred)^2L)),
    greedy_model$RMSE
  )
  testthat::expect_lte(ensemble_model$RMSE, min(individual_rmse))
})

# Classification ensembling test
testthat::test_that("greedyMSE can be used for classification ensembling with GLM, rpart, and RF", {
  # These packages are only needed by this test; skip rather than error
  # when either one is absent
  testthat::skip_if_not_installed("rpart")
  testthat::skip_if_not_installed("randomForest")

  X <- regression_data$X
  # 0/1 labels split at the median of the continuous target
  Y_binary <- as.integer(regression_data$Y > stats::median(regression_data$Y))

  # GLM (logistic regression)
  glm_model <- stats::glm(Y_binary ~ X, family = stats::binomial())
  glm_pred <- stats::predict(glm_model, newdata = data.frame(X), type = "response")

  # rpart: keep the second probability column
  rpart_model <- rpart::rpart(Y_binary ~ X, method = "class")
  rpart_pred <- stats::predict(rpart_model, newdata = data.frame(X), type = "prob")[, 2L]

  # Random Forest: keep the second probability column
  rf_model <- randomForest::randomForest(X, as.factor(Y_binary))
  rf_pred <- stats::predict(rf_model, newdata = X, type = "prob")[, 2L]

  # greedyMSE (the target must be a one-column matrix)
  greedy_model <- greedyMSE(X, matrix(Y_binary, ncol = 1L))
  greedy_pred <- predict(greedy_model, X)

  # Ensemble: each base model's predictions become one feature column
  ensemble_X <- cbind(glm_pred, rpart_pred, rf_pred, greedy_pred)
  ensemble_model <- greedyMSE(ensemble_X, matrix(Y_binary, ncol = 1L))

  # Check if ensemble performs better than the best individual model
  individual_rmse <- c(
    sqrt(mean((Y_binary - glm_pred)^2L)),
    sqrt(mean((Y_binary - rpart_pred)^2L)),
    sqrt(mean((Y_binary - rf_pred)^2L)),
    greedy_model$RMSE
  )
  testthat::expect_lte(ensemble_model$RMSE, min(individual_rmse))
})

0 comments on commit 62725d3

Please sign in to comment.