Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Keep row names in output from predict #4977

Merged
merged 15 commits into from
Apr 5, 2022
9 changes: 9 additions & 0 deletions R-package/R/lgb.Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ Predictor <- R6::R6Class(
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
}

# Keep row names if possible
if (NROW(row.names(data)) && NROW(data) == NROW(preds)) {
if (is.null(dim(preds))) {
names(preds) <- row.names(data)
} else {
row.names(preds) <- row.names(data)
}
}

return(preds)
}

Expand Down
111 changes: 111 additions & 0 deletions R-package/tests/testthat/test_Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ VERBOSITY <- as.integer(
Sys.getenv("LIGHTGBM_TEST_VERBOSITY", "-1")
)

library(Matrix)

test_that("Predictor$finalize() should not fail", {
X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L)
y <- iris[["Sepal.Length"]]
Expand Down Expand Up @@ -112,6 +114,115 @@ test_that("start_iteration works correctly", {
expect_equal(pred_leaf1, pred_leaf2)
})

# Assert that predictions carry the same row names as the data they came from.
#
# pred: output of predict(). names() is inspected when is.vector(pred) is TRUE,
#       row.names() otherwise.
# X:    the data that was passed to predict(); its row.names() are the
#       expected values.
.expect_has_row_names <- function(pred, X) {
    if (is.vector(pred)) {
        rnames <- names(pred)
    } else {
        rnames <- row.names(pred)
    }
    expect_false(is.null(rnames))
    expect_true(is.vector(rnames))
    # expect_gt() reports the offending length on failure
    expect_gt(length(rnames), 0L)
    # observed value goes first, expected value second, so that a failure
    # message labels the two sides correctly
    expect_equal(rnames, row.names(X))
}

# Assert that predictions carry no row names at all.
#
# pred: output of predict(). row.names() must be NULL for anything that is
#       not a plain vector; names() must be NULL for a plain vector.
.expect_doesnt_have_row_names <- function(pred) {
    if (!is.vector(pred)) {
        return(expect_null(row.names(pred)))
    }
    expect_null(names(pred))
}

# Run the full set of row-name checks for a trained model.
#
# bst: model object handed to predict()
# X:   dense matrix with row names; it is also converted to a
#      "CsparseMatrix" so the same checks run on sparse input
#
# Order of checks: dense with names, dense without names,
# sparse with names, sparse without names.
.check_all_row_name_expectations <- function(bst, X) {

    # every prediction type must carry over the row names of 'newdata'
    .names_are_kept <- function(newdata) {
        .expect_has_row_names(predict(bst, newdata), newdata)
        .expect_has_row_names(predict(bst, newdata, rawscore = TRUE), newdata)
        .expect_has_row_names(predict(bst, newdata, predleaf = TRUE), newdata)
        .expect_has_row_names(predict(bst, newdata, predcontrib = TRUE), newdata)
    }

    # once row names are stripped from (a local copy of) 'newdata',
    # predictions must come back unnamed
    .names_are_absent <- function(newdata) {
        row.names(newdata) <- NULL
        .expect_doesnt_have_row_names(predict(bst, newdata))
    }

    # dense matrix
    .names_are_kept(X)
    .names_are_absent(X)

    # sparse matrix
    X_sparse <- as(X, "CsparseMatrix")
    .names_are_kept(X_sparse)
    .names_are_absent(X_sparse)
}

test_that("predict() keeps row names from data (regression)", {
    data("mtcars")
    # first column is the regression target, the remaining columns are features;
    # no row names are assigned here, the ones already on mtcars are relied upon
    features <- as.matrix(mtcars[, -1L])
    target <- as.numeric(mtcars[, 1L])
    # NOTE(review): the small bin / leaf sizes are presumably needed because
    # the dataset has very few rows -- confirm
    dataset_params <- list(
        max_bins = 5L
        , min_data_in_bin = 1L
    )
    train_set <- lgb.Dataset(
        features
        , label = target
        , params = dataset_params
    )
    booster <- lgb.train(
        params = list(min_data_in_leaf = 1L)
        , data = train_set
        , nrounds = 5L
        , obj = "regression"
        , verbose = VERBOSITY
    )
    .check_all_row_name_expectations(booster, features)
})

test_that("predict() keeps row names from data (binary classification)", {
    data(agaricus.train, package = "lightgbm")
    X <- as.matrix(agaricus.train$data)
    y <- agaricus.train$label
    # assign explicit row names so there is something for predict() to preserve;
    # seq_len() stays correct even for zero rows, unlike seq(1L, n)
    row.names(X) <- paste0("rname", seq_len(nrow(X)))
    dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
    bst <- lgb.train(
        data = dtrain
        , obj = "binary"
        , nrounds = 5L
        , verbose = VERBOSITY
    )
    .check_all_row_name_expectations(bst, X)
})

test_that("predict() keeps row names from data (multi-class classification)", {
    data(iris)
    # class labels are shifted to start at 0
    y <- as.numeric(iris$Species) - 1.0
    X <- as.matrix(iris[, names(iris) != "Species"])
    # assign explicit row names so there is something for predict() to preserve;
    # seq_len() stays correct even for zero rows, unlike seq(1L, n)
    row.names(X) <- paste0("rname", seq_len(nrow(X)))
    dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
    bst <- lgb.train(
        data = dtrain
        , obj = "multiclass"
        , params = list(num_class = 3L)
        , nrounds = 5L
        , verbose = VERBOSITY
    )
    .check_all_row_name_expectations(bst, X)
})

test_that("predictions for regression and binary classification are returned as vectors", {
data(mtcars)
X <- as.matrix(mtcars[, -1L])
Expand Down