Skip to content

Commit

Permalink
happy path for sparse matrix passed to fit()
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Sep 6, 2024
1 parent 0e37996 commit 236a39b
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 10 deletions.
8 changes: 1 addition & 7 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,7 @@ fit.model_spec <-
}

if (is_sparse_matrix(data)) {
outcome_names <- all.names(rlang::f_lhs(formula))
outcome_ind <- match(outcome_names, colnames(data))

y <- data[, outcome_ind]
x <- data[, -outcome_ind, drop = TRUE]

return(fit_xy(object, x, y, case_weights, control, ...))
data <- sparsevctrs::coerce_to_sparse_tibble(data)
}

dots <- quos(...)
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/sparsevctrs.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
Code
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
Condition
Error in `fit_xy()`:
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.
Warning:
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse tibble can be passed to `fit_xy()

Expand Down
1 change: 0 additions & 1 deletion tests/testthat/test-sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ test_that("sparse matrix can be passed to `fit()", {
set_engine("lm")

expect_snapshot(
error = TRUE,
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
)
})
Expand Down

0 comments on commit 236a39b

Please sign in to comment.