Skip to content

Commit

Permalink
[R-package] Added unit tests on creating Booster from file or string (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Apr 8, 2020
1 parent 5c201e4 commit 91ce04b
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 20 deletions.
33 changes: 14 additions & 19 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -795,32 +795,27 @@ predict.lgb.Booster <- function(object,
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {

if (is.null(filename) && is.null(model_str)) {
stop("lgb.load: either filename or model_str must be given")
}

# Load from filename
if (!is.null(filename) && !is.character(filename)) {
stop("lgb.load: filename should be character")
}
filename_provided <- !is.null(filename)
model_str_provided <- !is.null(model_str)

# Return new booster
if (!is.null(filename) && !file.exists(filename)) {
stop("lgb.load: file does not exist for supplied filename")
}
if (!is.null(filename)) {
if (filename_provided) {
if (!is.character(filename)) {
stop("lgb.load: filename should be character")
}
if (!file.exists(filename)) {
stop(sprintf("lgb.load: file '%s' passed to filename does not exist", filename))
}
return(invisible(Booster$new(modelfile = filename)))
}

# Load from model_str
if (!is.null(model_str) && !is.character(model_str)) {
stop("lgb.load: model_str should be character")
}
# Return new booster
if (!is.null(model_str)) {
if (model_str_provided) {
if (!is.character(model_str)) {
stop("lgb.load: model_str should be character")
}
return(invisible(Booster$new(model_str = model_str)))
}

stop("lgb.load: either filename or model_str must be given")
}

#' @name lgb.save
Expand Down
1 change: 0 additions & 1 deletion R-package/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ export CC=/usr/local/bin/gcc-8
Rscript build_r.R

# Get coverage
rm -rf lightgbm_r/build
Rscript -e " \
coverage <- covr::package_coverage('./lightgbm_r', quiet=FALSE);
print(coverage);
Expand Down
139 changes: 139 additions & 0 deletions R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,142 @@ test_that("lgb.get.eval.result() should throw an informative error for incorrect
)
}, regexp = "Only the following eval_names exist for dataset.*\\: \\[l2\\]", fixed = FALSE)
})

context("lgb.load()")

test_that("lgb.load() gives the expected error messages given different incorrect inputs", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
train <- agaricus.train
test <- agaricus.test
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = 2L
, objective = "binary"
)

# you have to give model_str or filename
expect_error({
lgb.load()
}, regexp = "either filename or model_str must be given")
expect_error({
lgb.load(filename = NULL, model_str = NULL)
}, regexp = "either filename or model_str must be given")

# if given, filename should be a string that points to an existing file
out_file <- "lightgbm.model"
expect_error({
lgb.load(filename = list(out_file))
}, regexp = "filename should be character")
file_to_check <- paste0("a.model")
while (file.exists(file_to_check)) {
file_to_check <- paste0("a", file_to_check)
}
expect_error({
lgb.load(filename = file_to_check)
}, regexp = "passed to filename does not exist")

# if given, model_str should be a string
expect_error({
lgb.load(model_str = c(4.0, 5.0, 6.0))
}, regexp = "model_str should be character")

})

test_that("Loading a Booster from a file works", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
train <- agaricus.train
test <- agaricus.test
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = 2L
, objective = "binary"
)
expect_true(lgb.is.Booster(bst))

pred <- predict(bst, test$data)
lgb.save(bst, "lightgbm.model")

# finalize the booster and destroy it so you know we aren't cheating
bst$finalize()
expect_null(bst$.__enclos_env__$private$handle)
rm(bst)

bst2 <- lgb.load(
filename = "lightgbm.model"
)
pred2 <- predict(bst2, test$data)
expect_identical(pred, pred2)
})

test_that("Loading a Booster from a string works", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
train <- agaricus.train
test <- agaricus.test
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = 2L
, objective = "binary"
)
expect_true(lgb.is.Booster(bst))

pred <- predict(bst, test$data)
model_string <- bst$save_model_to_string()

# finalize the booster and destroy it so you know we aren't cheating
bst$finalize()
expect_null(bst$.__enclos_env__$private$handle)
rm(bst)

bst2 <- lgb.load(
model_str = model_string
)
pred2 <- predict(bst2, test$data)
expect_identical(pred, pred2)
})

test_that("If a string and a file are both passed to lgb.load() the file is used model_str is totally ignored", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
train <- agaricus.train
test <- agaricus.test
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = 2L
, objective = "binary"
)
expect_true(lgb.is.Booster(bst))

pred <- predict(bst, test$data)
lgb.save(bst, "lightgbm.model")

# finalize the booster and destroy it so you know we aren't cheating
bst$finalize()
expect_null(bst$.__enclos_env__$private$handle)
rm(bst)

bst2 <- lgb.load(
filename = "lightgbm.model"
, model_str = 4.0
)
pred2 <- predict(bst2, test$data)
expect_identical(pred, pred2)
})

0 comments on commit 91ce04b

Please sign in to comment.