Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add greedyOptimizer class #304

Merged
merged 7 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ S3method(plot,caretList)
S3method(plot,caretStack)
S3method(predict,caretList)
S3method(predict,caretStack)
S3method(predict,greedyMSE)
S3method(print,caretStack)
S3method(print,greedyMSE)
S3method(print,summary.caretList)
S3method(print,summary.caretStack)
S3method(summary,caretList)
Expand All @@ -26,6 +28,7 @@ export(caretList)
export(caretModelSpec)
export(caretStack)
export(extractMetric)
export(greedyMSE)
export(is.caretList)
export(is.caretStack)
export(permutationImportance)
Expand Down
86 changes: 86 additions & 0 deletions R/greedyOpt.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#' @title Greedy optimization for MSE
#' @description Greedy forward selection of non-negative weights that
#' minimize the mean squared error. For every column of \code{Y}, each
#' iteration adds one unit to whichever weight gives the lowest MSE once the
#' weights are rescaled to sum to one.
#' Works for classification and regression.
#' @param X A numeric matrix of features.
#' @param Y A numeric matrix of target values.
#' @param max_iter An integer scalar of the maximum number of iterations.
#' @return A list of class \code{greedyMSE} with components:
#' \item{model_weights}{A numeric matrix of model_weights.}
#' \item{RMSE}{A numeric scalar of the root mean squared error.}
#' \item{max_iter}{An integer scalar of the maximum number of iterations.}
#' @export
greedyMSE <- function(X, Y, max_iter = 100L) {
  stopifnot(
    is.matrix(X), is.matrix(Y),
    is.numeric(X), is.numeric(Y),
    is.finite(X), is.finite(Y),
    nrow(X) == nrow(Y), ncol(X) >= 1L, ncol(Y) >= 1L,
    is.integer(max_iter), max_iter > 0L
  )

  n_features <- ncol(X)
  target_ids <- seq_len(ncol(Y))

  # Integer selection counts: one row per feature, one column per target
  model_weights <- matrix(0L, nrow = n_features, ncol = length(target_ids))
  # Column j of the identity matrix is the unit increment for weight j
  unit_steps <- diag(n_features)

  for (pass in seq_len(max_iter)) {
    for (k in target_ids) {
      # Column j holds the current counts with weight j incremented by one
      # (the count vector is recycled down every column of the identity).
      candidates <- model_weights[, k] + unit_steps
      # All candidate columns share the same sum, so recycling the vector of
      # column sums rescales every column to sum to one.
      candidates <- candidates / colSums(candidates)

      # One MSE per candidate; the target is recycled down each column
      squared_error <- (X %*% candidates - Y[, k])^2.0
      winner <- which.min(colMeans(squared_error))

      # Keep the increment that produced the lowest MSE
      model_weights[winner, k] <- model_weights[winner, k] + 1L
    }
  }

  # Output
  # Every pass adds one unit per target column, so the column sums are
  # expected to be equal and the recycled division rescales each column.
  model_weights <- model_weights / colSums(model_weights)
  rownames(model_weights) <- colnames(X)
  colnames(model_weights) <- colnames(Y)

  residuals <- X %*% model_weights - Y
  structure(
    list(
      model_weights = model_weights,
      RMSE = sqrt(mean(residuals^2.0)),
      max_iter = max_iter
    ),
    class = "greedyMSE"
  )
}

#' @title Print method for greedyMSE
#' @description Print method for greedyMSE objects. Writes the RMSE and the
#' weight matrix to the console.
#' @param x A greedyMSE object.
#' @param ... Additional arguments. Ignored.
#' @return \code{x}, invisibly.
#' @export
print.greedyMSE <- function(x, ...) {
  cat("Greedy MSE\n")
  cat("RMSE: ", x$RMSE, "\n")
  cat("Weights:\n")
  print(x$model_weights)
  # Print methods return their argument invisibly so the object can still be
  # assigned or piped after printing.
  invisible(x)
}

#' @title Predict method for greedyMSE
#' @description Multiplies new data by the fitted weight matrix. When the
#' model has more than one target column, each row of the result is divided
#' by its row sum.
#' @param object A greedyMSE object.
#' @param newdata A numeric matrix of new data.
#' @param ... Additional arguments. Ignored.
#' @return A numeric matrix of predictions.
#' @export
predict.greedyMSE <- function(object, newdata, ...) {
  stopifnot(
    is.matrix(newdata),
    is.numeric(newdata),
    is.finite(newdata),
    ncol(newdata) == nrow(object$model_weights)
  )

  blended <- newdata %*% object$model_weights

  # A single target column is returned as-is
  if (ncol(blended) <= 1L) {
    return(blended)
  }

  # rowSums() has one entry per row, so recycling divides each row by its sum
  blended / rowSums(blended)
}
1 change: 0 additions & 1 deletion R/permutationImportance.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ permutationImportance <- function(
is.numeric(preds_orig),
is.finite(preds_orig)
)

# Error of shuffled variables
mae_vars <- shuffled_mae(model, newdata, preds_orig, pred_type, shuffle_idx)

Expand Down
Binary file modified coverage.rds
Binary file not shown.
1 change: 1 addition & 0 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ ggplot
ggplot2
github
glm
greedyMSE
importances
kable
knitr
Expand Down
25 changes: 25 additions & 0 deletions man/greedyMSE.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions man/predict.greedyMSE.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions man/print.greedyMSE.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

171 changes: 171 additions & 0 deletions tests/testthat/test-greedyMSE.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Helper function to create a simple dataset
#
# Draws uniform features, builds targets as a linear combination of them plus
# optional Gaussian noise, and returns list(X, Y, coef).
#   n_samples  number of rows
#   n_features number of feature columns
#   n_targets  number of target columns
#   coef       optional coefficients (vector or matrix); random when NULL
#   noise      standard deviation of the Gaussian noise added to the targets
create_dataset <- function(
    n_samples,
    n_features = 5L,
    n_targets = 1L,
    coef = NULL,
    noise = 0L) {
  X <- matrix(stats::runif(n_samples * n_features), nrow = n_samples)

  if (is.null(coef)) {
    coef <- matrix(stats::runif(n_features * n_targets), nrow = n_features)
  } else if (is.vector(coef)) {
    coef <- matrix(coef, ncol = 1L)
  }

  # Normalize coefficients so the absolute values in each column sum to one.
  # sweep() is used instead of apply() because apply() collapses the result to
  # a plain vector when there is a single feature row.
  coef <- sweep(coef, 2L, colSums(abs(coef)), "/")

  noise_matrix <- matrix(
    stats::rnorm(n_samples * n_targets, 0.0, noise),
    nrow = n_samples
  )
  Y <- X %*% coef + noise_matrix

  if (n_targets > 1L) {
    # Ensure all values are positive
    Y <- Y - min(Y) + 1e-6
    # Normalize rows to sum to 1
    Y <- Y / rowSums(Y)
  }

  list(X = X, Y = Y, coef = coef)
}

# Create datasets for reuse
# The seed is fixed so the random fixtures below are reproducible; the order
# of these statements matters because each call consumes random numbers.
set.seed(42L)
N <- 100L
# Single-target data with Gaussian noise (sd = 0.1) added to the target
regression_data <- create_dataset(N, noise = 0.1)
# Three-target data without noise; create_dataset rescales rows to sum to one
multi_regression_data <- create_dataset(N, n_targets = 3L)

# 0/1 labels: 1 where a target value exceeds the mean of all target values
Y_binary <- matrix(as.integer(regression_data$Y > mean(regression_data$Y)), nrow = N)
Y_multi_binary <- matrix(as.integer(multi_regression_data$Y > mean(multi_regression_data$Y)), nrow = N)

# Test for regression (one col)
testthat::test_that("greedyMSE works for regression", {
  fit <- greedyMSE(regression_data$X, regression_data$Y)

  # The fitted RMSE should beat the standard-deviation baseline
  baseline <- stats::sd(regression_data$Y)
  testthat::expect_lt(fit$RMSE, baseline)

  # Predictions should correlate strongly with the true values
  preds <- predict(fit, regression_data$X)
  testthat::expect_gt(stats::cor(preds, regression_data$Y), 0.8)

  # The print method should emit every section of its report
  for (section in c("Greedy MSE", "RMSE", "Weights")) {
    testthat::expect_output(print(fit), section)
  }
})

# Test for binary classification (one col)
testthat::test_that("greedyMSE works for binary classification", {
  fit <- greedyMSE(regression_data$X, Y_binary)

  # Threshold the predictions at 0.5 and compare with the 0/1 labels
  predicted_positive <- predict(fit, regression_data$X) > 0.5
  hit_rate <- mean(predicted_positive == Y_binary)
  testthat::expect_gt(hit_rate, 0.7) # Should beat random guessing

  # The print method should emit every section of its report
  for (section in c("Greedy MSE", "RMSE", "Weights")) {
    testthat::expect_output(print(fit), section)
  }
})

# Test for multiple regression (many cols)
testthat::test_that("greedyMSE works for multiple regression", {
  targets <- multi_regression_data$Y
  fit <- greedyMSE(multi_regression_data$X, targets)

  testthat::expect_lt(fit$RMSE, 0.3)

  # The diagonal pairs each predicted column with its own target column
  preds <- predict(fit, multi_regression_data$X)
  per_target_cor <- diag(stats::cor(preds, targets))
  testthat::expect_gt(min(per_target_cor), 0.7) # High for every target
})

# Test for multiclass classification (many cols)
testthat::test_that("greedyMSE works for multiclass classification", {
  fit <- greedyMSE(multi_regression_data$X, Y_multi_binary)
  preds <- predict(fit, multi_regression_data$X)

  # Compare the highest-scoring predicted column with the largest column of
  # the continuous targets the binary labels were derived from
  predicted_class <- apply(preds, 1L, which.max)
  reference_class <- apply(multi_regression_data$Y, 1L, which.max)
  hit_rate <- mean(predicted_class == reference_class)
  testthat::expect_gt(hit_rate, 0.6) # Should beat random guessing
})

# Edge cases
testthat::test_that("greedyMSE handles edge cases", {
  # 1. Single feature
  # With one feature the only possible weight is 1, and the noiseless target
  # equals that feature, so the fit is exact.
  data <- create_dataset(100L, 1L, 1L)
  testthat::expect_equal(greedyMSE(data$X, data$Y)$RMSE, 0.0, tolerance = 1e-6)

  # 2. Perfect multicollinearity
  # Both features are identically one and the weights sum to one, so every
  # prediction is one and the RMSE is the root mean square of the noise.
  # Y must be a matrix: greedyMSE rejects plain vectors.
  X <- matrix(1L, nrow = 100L, ncol = 2L)
  Y <- matrix(X[, 1L] + stats::rnorm(100L, 0.0, 0.1), ncol = 1L)
  model <- greedyMSE(X, Y)
  testthat::expect_equal(sum(model$model_weights), 1.0)
  testthat::expect_equal(model$RMSE, sqrt(mean((Y - 1.0)^2.0)), tolerance = 1e-6)

  # 3. All zero features
  # Every prediction is zero, so the RMSE is close to the spread of Y
  X <- matrix(0L, nrow = 100L, ncol = 5L)
  Y <- matrix(stats::rnorm(100L), ncol = 1L)
  model <- greedyMSE(X, Y)
  testthat::expect_equal(model$RMSE, stats::sd(Y), tolerance = 1e-2)

  # 4. Constant target
  X <- matrix(stats::runif(500L), nrow = 100L)
  Y <- matrix(rep(1L, 100L), ncol = 1L)
  model <- greedyMSE(X, Y)
  testthat::expect_equal(model$RMSE, 0.5, tolerance = 0.1)

  # 5. Very large values
  X <- matrix(stats::runif(500L, 1.0e6, 1.0e7), nrow = 100L)
  Y <- matrix(rowSums(X) + stats::rnorm(100L, 0L, 1.0e5), ncol = 1L)
  model <- greedyMSE(X, Y)
  pred <- predict(model, X)
  testthat::expect_gt(stats::cor(pred, Y), 0.4)
})

# Regression ensembling test
# Fits a GLM, an rpart tree, and a greedyMSE model on the shared regression
# data, then uses greedyMSE to blend their predictions.
testthat::test_that("greedyMSE can be used for regression ensembling with GLM and rpart", {
  # rpart is an optional dependency of this test; skip when it is missing
  testthat::skip_if_not_installed("rpart")

  X <- regression_data$X
  Y <- regression_data$Y

  # GLM
  glm_model <- stats::glm(Y ~ X, family = stats::gaussian())
  glm_pred <- stats::predict(glm_model, newdata = data.frame(X))

  # rpart
  rpart_model <- rpart::rpart(Y ~ X)
  rpart_pred <- stats::predict(rpart_model, newdata = data.frame(X))

  # greedyMSE
  greedy_model <- greedyMSE(X, Y)
  greedy_pred <- predict(greedy_model, X)

  # Ensemble: one column of predictions per base model
  ensemble_X <- cbind(glm_pred, rpart_pred, greedy_pred)
  ensemble_model <- greedyMSE(ensemble_X, Y)

  # Check if ensemble performs at least as well as the best individual model
  individual_rmse <- c(
    sqrt(mean((Y - glm_pred)^2L)),
    sqrt(mean((Y - rpart_pred)^2L)),
    greedy_model$RMSE
  )
  testthat::expect_lte(ensemble_model$RMSE, min(individual_rmse))
})

# Classification ensembling test
# Fits a logistic GLM, an rpart tree, a random forest, and a greedyMSE model
# on median-split labels, then uses greedyMSE to blend their probabilities.
testthat::test_that("greedyMSE can be used for classification ensembling with GLM, rpart, and RF", {
  # rpart and randomForest are optional dependencies of this test; skip when
  # either one is missing
  testthat::skip_if_not_installed("rpart")
  testthat::skip_if_not_installed("randomForest")

  X <- regression_data$X
  Y_binary <- as.integer(regression_data$Y > stats::median(regression_data$Y))

  # GLM (logistic regression)
  glm_model <- stats::glm(Y_binary ~ X, family = stats::binomial())
  glm_pred <- stats::predict(glm_model, newdata = data.frame(X), type = "response")

  # rpart: keep the second probability column
  rpart_model <- rpart::rpart(Y_binary ~ X, method = "class")
  rpart_pred <- stats::predict(rpart_model, newdata = data.frame(X), type = "prob")[, 2L]

  # Random Forest: keep the second probability column
  rf_model <- randomForest::randomForest(X, as.factor(Y_binary))
  rf_pred <- stats::predict(rf_model, newdata = X, type = "prob")[, 2L]

  # greedyMSE (needs the labels as a one-column matrix)
  greedy_model <- greedyMSE(X, matrix(Y_binary, ncol = 1L))
  greedy_pred <- predict(greedy_model, X)

  # Ensemble: one column of predicted probabilities per base model
  ensemble_X <- cbind(glm_pred, rpart_pred, rf_pred, greedy_pred)
  ensemble_model <- greedyMSE(ensemble_X, matrix(Y_binary, ncol = 1L))

  # Check if ensemble performs at least as well as the best individual model
  individual_rmse <- c(
    sqrt(mean((Y_binary - glm_pred)^2L)),
    sqrt(mean((Y_binary - rpart_pred)^2L)),
    sqrt(mean((Y_binary - rf_pred)^2L)),
    greedy_model$RMSE
  )
  testthat::expect_lte(ensemble_model$RMSE, min(individual_rmse))
})