From 517ecfbc235a16241fe464df6bd63f31e93fc343 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 30 Sep 2024 11:57:41 -0700 Subject: [PATCH 1/4] don't turn sparse matrix into dense matrix for glmnet prediction --- R/glmnet-engines.R | 11 +++++++++++ R/linear_reg_data.R | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/R/glmnet-engines.R b/R/glmnet-engines.R index 296468e49..65f0a7ac8 100644 --- a/R/glmnet-engines.R +++ b/R/glmnet-engines.R @@ -138,6 +138,15 @@ predict_raw._glmnetfit <- predict_raw_glmnet unname(x[, 1]) } +organize_glmnet_pre_pred <- function(x, object) { + if (is_sparse_matrix(x)) { + return(x) + } + + as.matrix(x[, rownames(object$fit$beta), drop = FALSE]) +} + + organize_glmnet_class <- function(x, object) { prob_to_class_2(x[, 1], object) } @@ -166,6 +175,8 @@ organize_multnet_prob <- function(x, object) { x } + + # ------------------------------------------------------------------------- multi_predict_glmnet <- function(object, diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index bdf6a3753..82e13fc49 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -249,7 +249,7 @@ set_pred( args = list( object = expr(object$fit), - newx = expr(as.matrix(new_data[, rownames(object$fit$beta), drop = FALSE])), + newx = expr(.organize_glmnet_pre_pred(new_data, object)), type = "response", s = expr(object$spec$args$penalty) ) From 69046446761bd7fce4285c33b285ac4ecb49862b Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 30 Sep 2024 12:02:56 -0700 Subject: [PATCH 2/4] remember to do subsetting --- R/glmnet-engines.R | 3 ++- R/linear_reg_data.R | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/R/glmnet-engines.R b/R/glmnet-engines.R index 65f0a7ac8..e80591585 100644 --- a/R/glmnet-engines.R +++ b/R/glmnet-engines.R @@ -139,11 +139,12 @@ predict_raw._glmnetfit <- predict_raw_glmnet } organize_glmnet_pre_pred <- function(x, object) { + x <- x[, rownames(object$fit$beta), drop = FALSE] if (is_sparse_matrix(x)) { return(x) } - as.matrix(x[, rownames(object$fit$beta), drop = FALSE]) + as.matrix(x) } diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 82e13fc49..c8163fa2e 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -249,7 +249,7 @@ set_pred( args = list( object = expr(object$fit), - newx = expr(.organize_glmnet_pre_pred(new_data, object)), + newx = expr(organize_glmnet_pre_pred(new_data, object)), type = "response", s = expr(object$spec$args$penalty) ) From 6a22f858bf03c2fc98badf52e9a925cd1240faa2 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 30 Sep 2024 12:04:15 -0700 Subject: [PATCH 3/4] don't add random newlines --- R/glmnet-engines.R | 2 -- 1 file changed, 2 deletions(-) diff --git a/R/glmnet-engines.R b/R/glmnet-engines.R index e80591585..067825c7b 100644 --- a/R/glmnet-engines.R +++ b/R/glmnet-engines.R @@ -176,8 +176,6 @@ organize_multnet_prob <- function(x, object) { x } - - # ------------------------------------------------------------------------- multi_predict_glmnet <- function(object, From ba56e2a68fe30c4002a8711259d777a80e4c6c42 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 2 Oct 2024 09:54:32 -0700 Subject: [PATCH 4/4] test glmnet predict doesn't remove sparseness --- tests/testthat/_snaps/sparsevctrs.md | 8 ++++++++ tests/testthat/test-sparsevctrs.R | 29 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 3619bf8dd..798fe0a72 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -127,3 +127,11 @@ Error in `maybe_sparse_matrix()`: ! no sparse vectors detected +# we don't run as.matrix() on sparse matrix for glmnet pred #1210 + + Code + predict(lm_fit, hotel_data) + Condition + Error in `predict.elnet()`: + ! data is sparse + diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 7759006e8..8dcbd796b 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -313,3 +313,32 @@ test_that("maybe_sparse_matrix() is used correctly", { fit_xy(spec, x = tibble::as_tibble(mtcars)[, -1], y = tibble::as_tibble(mtcars)[, 1]) ) }) + +test_that("we don't run as.matrix() on sparse matrix for glmnet pred #1210", { + skip_if_not_installed("glmnet") + + local_mocked_bindings( + predict.elnet = function(object, newx, ...) { + if (is_sparse_matrix(newx)) { + stop("data is sparse") + } else { + stop("data isn't sparse (should not happen)") + } + }, + .package = "glmnet" + ) + + hotel_data <- sparse_hotel_rates() + + spec <- linear_reg(penalty = 0) %>% + set_mode("regression") %>% + set_engine("glmnet") + + lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + + expect_snapshot( + error = TRUE, + predict(lm_fit, hotel_data) + ) +}) +