-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* maybe * try this * IT WORKS * add tests * fix tests * tests pass * tests-pass
- Loading branch information
Showing
9 changed files
with
323 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#' @title Greedy optimization for MSE
#' @description Greedy optimization for minimizing the mean squared error.
#' Works for classification and regression.
#'
#' Starting from all-zero weights, every iteration adds one unit of weight to
#' the single feature whose increment gives the lowest mean squared error.
#' This is done independently for each column of \code{Y}. The accumulated
#' counts are finally rescaled so that each column of weights sums to one.
#' @param X A numeric matrix of features. Must be finite, with at least one
#' row and one column.
#' @param Y A numeric matrix of target values. Must be finite, with the same
#' number of rows as \code{X} and at least one column.
#' @param max_iter An integer scalar of the maximum number of iterations.
#' Must be a single positive integer (e.g. \code{100L}).
#' @return A list with components:
#' \item{model_weights}{A numeric matrix of model_weights.}
#' \item{RMSE}{A numeric scalar of the root mean squared error.}
#' \item{max_iter}{An integer scalar of the maximum number of iterations.}
#' @export
greedyMSE <- function(X, Y, max_iter = 100L) {
  stopifnot(
    is.matrix(X), is.matrix(Y),
    is.numeric(X), is.numeric(Y),
    is.finite(X), is.finite(Y),
    nrow(X) == nrow(Y), nrow(X) >= 1L,
    ncol(X) >= 1L, ncol(Y) >= 1L,
    is.integer(max_iter), length(max_iter) == 1L, max_iter > 0L
  )

  n_features <- ncol(X)
  model_weights <- matrix(0L, nrow = n_features, ncol = ncol(Y))
  # Column j of the identity matrix is "add one unit to feature j"
  model_update <- diag(n_features)

  for (iter in seq_len(max_iter)) {
    for (y_col in seq_len(ncol(Y))) {
      target <- Y[, y_col]
      w <- model_weights[, y_col]

      # Calculate MSE for incrementing each weight.
      # Column j of w_new is the current weights with feature j incremented;
      # every such column sums to sum(w) + 1, so a single scalar normalizes
      # all candidates.
      w_new <- (w + model_update) / (sum(w) + 1.0)
      predictions <- X %*% w_new
      MSE <- colMeans((predictions - target)^2.0)

      # Update the best weight (which.min keeps the first minimum on ties)
      best_id <- which.min(MSE)
      model_weights[best_id, y_col] <- model_weights[best_id, y_col] + 1L
    }
  }

  # Output
  # Divide each column by its own total. sweep() makes the per-column
  # division explicit instead of relying on vector recycling.
  model_weights <- sweep(model_weights, 2L, colSums(model_weights), "/")
  rownames(model_weights) <- colnames(X)
  colnames(model_weights) <- colnames(Y)
  RMSE <- sqrt(mean((X %*% model_weights - Y)^2.0))
  out <- list(
    model_weights = model_weights,
    RMSE = RMSE,
    max_iter = max_iter
  )
  class(out) <- "greedyMSE"
  out
}
|
||
#' @title Print method for greedyMSE
#' @description Print method for greedyMSE objects. Writes a header, the
#' RMSE and the weight matrix to the console.
#' @param x A greedyMSE object.
#' @param ... Additional arguments. Ignored.
#' @return \code{x}, invisibly.
#' @export
print.greedyMSE <- function(x, ...) {
  cat("Greedy MSE\n")
  cat("RMSE: ", x$RMSE, "\n")
  cat("Weights:\n")
  print(x$model_weights)
  # Print methods return their argument invisibly so calls can be chained
  invisible(x)
}
|
||
#' @title Predict method for greedyMSE
#' @description Predict method for greedyMSE objects. Multiplies the new data
#' by the fitted weights; when there is more than one target column, each row
#' of predictions is divided by its row sum.
#' @param object A greedyMSE object.
#' @param newdata A numeric matrix of new data, with one column per row of
#' the fitted weight matrix.
#' @param ... Additional arguments. Ignored.
#' @return A numeric matrix of predictions.
#' @export
predict.greedyMSE <- function(object, newdata, ...) {
  stopifnot(
    is.matrix(newdata),
    is.numeric(newdata),
    is.finite(newdata),
    ncol(newdata) == nrow(object$model_weights)
  )
  predictions <- newdata %*% object$model_weights

  # A single target column is returned as-is
  if (ncol(predictions) <= 1L) {
    return(predictions)
  }

  # Several target columns: rescale every row by its own total
  predictions / rowSums(predictions)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ ggplot | |
ggplot2 | ||
github | ||
glm | ||
greedyMSE | ||
importances | ||
kable | ||
knitr | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
# Helper function to create a simple dataset.
#
# Draws uniform features, forms targets as a linear combination of them plus
# optional Gaussian noise and, when there are several targets, shifts the
# targets to be positive and rescales each row to sum to one.
#
# Arguments:
#   n_samples  Number of rows to generate.
#   n_features Number of feature columns.
#   n_targets  Number of target columns.
#   coef       Optional coefficients (vector or matrix); drawn uniformly at
#              random when NULL.
#   noise      Standard deviation of the Gaussian noise added to the targets.
#
# Returns a list with the feature matrix X, the target matrix Y and the
# normalized coefficient matrix coef.
create_dataset <- function(n_samples, n_features = 5L, n_targets = 1L, coef = NULL, noise = 0L) {
  X <- matrix(stats::runif(n_samples * n_features), nrow = n_samples)

  if (is.null(coef)) {
    coef <- matrix(stats::runif(n_features * n_targets), nrow = n_features)
  } else if (is.vector(coef)) {
    coef <- matrix(coef, ncol = 1L)
  }

  # Normalize coefficients so the absolute values in each column sum to one
  coef_dim <- dim(coef)
  coef <- apply(coef, 2L, function(col) col / sum(abs(col)))
  # apply() collapses the result to a plain vector when there is only one
  # row of coefficients; restore the matrix shape in that case
  if (!is.matrix(coef)) {
    coef <- matrix(coef, nrow = coef_dim[1L], ncol = coef_dim[2L])
  }

  Y <- X %*% coef + matrix(stats::rnorm(n_samples * n_targets, 0.0, noise), nrow = n_samples)

  if (n_targets > 1L) {
    # Ensure all values are positive
    Y <- Y - min(Y) + 1e-6
    # Normalize rows to sum to 1
    Y <- Y / rowSums(Y)
  }

  list(X = X, Y = Y, coef = coef)
}
|
||
# Shared fixtures, generated once from a fixed seed
set.seed(42L)
N <- 100L
regression_data <- create_dataset(n_samples = N, noise = 0.1)
multi_regression_data <- create_dataset(n_samples = N, n_targets = 3L)

# Binary versions of the targets: 1 where a value exceeds the overall mean
# of its target matrix, 0 otherwise
Y_binary <- matrix(
  as.integer(regression_data$Y > mean(regression_data$Y)),
  nrow = N
)
Y_multi_binary <- matrix(
  as.integer(multi_regression_data$Y > mean(multi_regression_data$Y)),
  nrow = N
)
|
||
# Test for regression (one col)
testthat::test_that("greedyMSE works for regression", {
  X <- regression_data$X
  Y <- regression_data$Y
  fit <- greedyMSE(X, Y)

  # Model should be better than the baseline given by the target's spread
  testthat::expect_lt(fit$RMSE, stats::sd(Y))
  # Predictions should correlate highly with the true values
  testthat::expect_gt(stats::cor(predict(fit, X), Y), 0.8)

  # Every section header must appear in the printed output
  for (header in c("Greedy MSE", "RMSE", "Weights")) {
    testthat::expect_output(print(fit), header)
  }
})
|
||
# Test for binary classification (one col)
testthat::test_that("greedyMSE works for binary classification", {
  fit <- greedyMSE(regression_data$X, Y_binary)

  # Threshold the predictions at 0.5 and compare with the 0/1 labels
  predicted_positive <- predict(fit, regression_data$X) > 0.5
  hit_rate <- mean(predicted_positive == Y_binary)
  testthat::expect_gt(hit_rate, 0.7) # Accuracy should be better than random guessing

  # Every section header must appear in the printed output
  for (header in c("Greedy MSE", "RMSE", "Weights")) {
    testthat::expect_output(print(fit), header)
  }
})
|
||
# Test for multiple regression (many cols)
testthat::test_that("greedyMSE works for multiple regression", {
  X <- multi_regression_data$X
  Y <- multi_regression_data$Y
  fit <- greedyMSE(X, Y)

  testthat::expect_lt(fit$RMSE, 0.3)

  # The diagonal of the cross-correlation matrix pairs each predicted
  # column with its own target column
  per_target_cor <- diag(stats::cor(predict(fit, X), Y))
  testthat::expect_gt(min(per_target_cor), 0.7) # High correlation for all targets
})
|
||
# Test for multiclass classification (many cols)
testthat::test_that("greedyMSE works for multiclass classification", {
  fit <- greedyMSE(multi_regression_data$X, Y_multi_binary)

  # Predicted class = column with the largest prediction in each row;
  # reference class = column with the largest continuous target in each row
  predicted_class <- apply(predict(fit, multi_regression_data$X), 1L, which.max)
  reference_class <- apply(multi_regression_data$Y, 1L, which.max)

  hit_rate <- mean(predicted_class == reference_class)
  testthat::expect_gt(hit_rate, 0.6) # Accuracy should be better than random guessing
})
|
||
# Edge cases
testthat::test_that("greedyMSE handles edge cases", {
  # 1. Single feature
  data <- create_dataset(100L, 1L, 1L)
  testthat::expect_equal(greedyMSE(data$X, data$Y)$RMSE, 0.0, tolerance = 1e-6)

  # 2. Perfect multicollinearity
  # Both columns are identically one and the fitted weights sum to one, so
  # every prediction is one however the weight is split between the columns.
  X <- matrix(1L, nrow = 100L, ncol = 2L)
  Y <- matrix(X[, 1L] + stats::rnorm(100L, 0.0, 0.1), ncol = 1L)
  testthat::expect_equal(greedyMSE(X, Y)$RMSE, sqrt(mean((Y - 1.0)^2.0)))

  # 3. All zero features
  X <- matrix(0L, nrow = 100L, ncol = 5L)
  Y <- matrix(stats::rnorm(100L), ncol = 1L)
  model <- greedyMSE(X, Y)
  testthat::expect_equal(model$RMSE, stats::sd(Y), tolerance = 1e-2)

  # 4. Constant target
  X <- matrix(stats::runif(500L), nrow = 100L)
  Y <- matrix(rep(1L, 100L), ncol = 1L)
  model <- greedyMSE(X, Y)
  testthat::expect_equal(model$RMSE, 0.5, tolerance = 0.1)

  # 5. Very large values
  X <- matrix(stats::runif(500L, 1.0e6, 1.0e7), nrow = 100L)
  Y <- matrix(rowSums(X) + stats::rnorm(100L, 0L, 1.0e5), ncol = 1L)
  model <- greedyMSE(X, Y)
  pred <- predict(model, X)
  testthat::expect_gt(stats::cor(pred, Y), 0.4)
})
|
||
# Regression ensembling test
# Stacks GLM, rpart and greedyMSE predictions and fits greedyMSE on top.
testthat::test_that("greedyMSE can be used for regression ensembling with GLM and rpart", {
  # rpart is an optional dependency: skip instead of erroring when absent
  testthat::skip_if_not_installed("rpart")

  X <- regression_data$X
  Y <- regression_data$Y

  # GLM
  glm_model <- stats::glm(Y ~ X, family = stats::gaussian())
  glm_pred <- stats::predict(glm_model, newdata = data.frame(X))

  # rpart
  rpart_model <- rpart::rpart(Y ~ X)
  rpart_pred <- stats::predict(rpart_model, newdata = data.frame(X))

  # greedyMSE
  greedy_model <- greedyMSE(X, Y)
  greedy_pred <- predict(greedy_model, X)

  # Ensemble
  ensemble_X <- cbind(glm_pred, rpart_pred, greedy_pred)
  ensemble_model <- greedyMSE(ensemble_X, Y)

  # Check that the ensemble is at least as good as the best individual model
  individual_rmse <- c(
    sqrt(mean((Y - glm_pred)^2L)),
    sqrt(mean((Y - rpart_pred)^2L)),
    greedy_model$RMSE
  )
  testthat::expect_lte(ensemble_model$RMSE, min(individual_rmse))
})
|
||
# Classification ensembling test
# Stacks GLM, rpart, random forest and greedyMSE class-1 probabilities and
# fits greedyMSE on top.
testthat::test_that("greedyMSE can be used for classification ensembling with GLM, rpart, and RF", {
  # Optional dependencies: skip instead of erroring when they are absent
  testthat::skip_if_not_installed("rpart")
  testthat::skip_if_not_installed("randomForest")

  X <- regression_data$X
  Y_binary <- as.integer(regression_data$Y > stats::median(regression_data$Y))

  # GLM (logistic regression)
  glm_model <- stats::glm(Y_binary ~ X, family = stats::binomial())
  glm_pred <- stats::predict(glm_model, newdata = data.frame(X), type = "response")

  # rpart
  rpart_model <- rpart::rpart(Y_binary ~ X, method = "class")
  rpart_pred <- stats::predict(rpart_model, newdata = data.frame(X), type = "prob")[, 2L]

  # Random Forest
  rf_model <- randomForest::randomForest(X, as.factor(Y_binary))
  rf_pred <- stats::predict(rf_model, newdata = X, type = "prob")[, 2L]

  # greedyMSE
  greedy_model <- greedyMSE(X, matrix(Y_binary, ncol = 1L))
  greedy_pred <- predict(greedy_model, X)

  # Ensemble
  ensemble_X <- cbind(glm_pred, rpart_pred, rf_pred, greedy_pred)
  ensemble_model <- greedyMSE(ensemble_X, matrix(Y_binary, ncol = 1L))

  # Check that the ensemble is at least as good as the best individual model
  individual_rmse <- c(
    sqrt(mean((Y_binary - glm_pred)^2L)),
    sqrt(mean((Y_binary - rpart_pred)^2L)),
    sqrt(mean((Y_binary - rf_pred)^2L)),
    greedy_model$RMSE
  )
  testthat::expect_lte(ensemble_model$RMSE, min(individual_rmse))
})