diff --git a/R-package/R/lgb.Predictor.R b/R-package/R/lgb.Predictor.R index c9106b488afa..fa2b14c94614 100644 --- a/R-package/R/lgb.Predictor.R +++ b/R-package/R/lgb.Predictor.R @@ -142,6 +142,10 @@ Predictor <- R6::R6Class( # Check if data is a matrix if (is.matrix(data)) { + # Check whether matrix is the correct type first ("double") + if (storage.mode(data) != "double") { + storage.mode(data) <- "double" + } preds <- lgb.call( "LGBM_BoosterPredictForMat_R" , ret = preds diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R new file mode 100644 index 000000000000..963046619869 --- /dev/null +++ b/R-package/tests/testthat/test_Predictor.R @@ -0,0 +1,18 @@ +context("Predictor") + +test_that("predictions do not fail for integer input", { + X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L) + y <- iris[["Sepal.Length"]] + dtrain <- lgb.Dataset(X, label = y) + fit <- lgb.train( + data = dtrain + , objective = "regression" + , verbose = -1L + ) + X_double <- X[c(1L, 51L, 101L), , drop = FALSE] + X_integer <- X_double + storage.mode(X_double) <- "double" + pred_integer <- predict(fit, X_integer) + pred_double <- predict(fit, X_double) + expect_equal(pred_integer, pred_double) +})