Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package][R-package] load parameters from model file (fixes #2613) #5424

Merged
merged 28 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7399fe1
initial work to retrieve parameters from loaded booster
jmoralez Aug 12, 2022
02ca63a
get parameter types and use to parse
jmoralez Aug 15, 2022
c81f768
add test
jmoralez Aug 16, 2022
2ea2c29
merge master
jmoralez Aug 16, 2022
b33d6a0
True for boolean field if it's equal to '1'
jmoralez Aug 16, 2022
c7a6a22
remove bound on cache
jmoralez Aug 16, 2022
f43934e
remove duplicated code
jmoralez Aug 17, 2022
edf11fc
merge remote
jmoralez Aug 17, 2022
7761124
manually parse json string
jmoralez Aug 17, 2022
26ba91f
dont create temporary map. lint
jmoralez Aug 17, 2022
ec113c0
add doc
jmoralez Aug 17, 2022
39c7a8c
minor fixes
jmoralez Aug 27, 2022
0e6591b
revert _get_string_from_c_api. rename parameter to param
jmoralez Aug 28, 2022
d4e781b
add R-package functions
jmoralez Aug 28, 2022
c574a4a
merge master
jmoralez Aug 28, 2022
483a3f4
rename functions to BoosterGetLoadedParam. override array parameters.…
jmoralez Aug 29, 2022
4ab5dd4
add missing types to tests
jmoralez Aug 29, 2022
bd4eec0
fix R params
jmoralez Aug 30, 2022
9a00fde
assert equal dicts
jmoralez Aug 30, 2022
de6ef8a
use boost_from_average as boolean param
jmoralez Aug 30, 2022
2cec692
set boost_from_average to false
jmoralez Aug 30, 2022
f066dba
simplify R's parse_param
jmoralez Aug 30, 2022
db36cb9
parse types on cpp side
jmoralez Aug 31, 2022
9467814
warn about ignoring parameters passed to constructor
jmoralez Sep 21, 2022
339bb1c
Merge branch 'master' into retrieve-params
jmoralez Sep 21, 2022
4cbf477
trigger ci
jmoralez Sep 24, 2022
6667771
Merge branch 'master' into retrieve-params
jmoralez Oct 10, 2022
17ad0c1
trigger ci
jmoralez Oct 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions R-package/R/aliases.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,24 @@
return(params_to_aliases)
}

# [description] List of parameter types. Wrapped in a function to take advantage of
# lazy evaluation (so it doesn't matter what order R sources files during installation).
# [return] A named list, where each key is a main LightGBM parameter and each value is a character
# vector of corresponding of their type name in C++.
.PARAMETER_TYPES <- function() {
json_str <- .Call(
LGBM_DumpParamTypes_R
)
param_types <- jsonlite::fromJSON(json_str)
# store in cache so the next call to `.PARAMETER_TYPES()` doesn't need to recompute this
assign(
x = "PARAMETER_TYPES"
, value = param_types
, envir = .lgb_session_cache_env
)
return(param_types)
}

# [description]
# Per https://github.com/microsoft/LightGBM/blob/master/docs/Parameters.rst#metric,
# a few different strings can be used to indicate "no metrics".
Expand Down
48 changes: 48 additions & 0 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Booster <- R6::R6Class(
LGBM_BoosterCreateFromModelfile_R
, modelfile
)
params <- private$get_loaded_param(handle)

} else if (!is.null(model_str)) {

Expand Down Expand Up @@ -727,6 +728,53 @@ Booster <- R6::R6Class(

},

get_loaded_param = function(handle) {
params_str <- .Call(
LGBM_BoosterGetLoadedParam_R
, handle
)
params <- jsonlite::fromJSON(params_str)
param_types <- .PARAMETER_TYPES()

type_name_to_fn <- c(
"string" = as.character
, "int" = as.integer
, "double" = as.numeric
, "bool" = function(x) x == "1"
)

parse_param <- function(value, type_name) {
if (grepl("vector", type_name)) {
eltype_name <- sub("vector<(.*)>", "\\1", type_name)
if (grepl("vector", eltype_name)) {
arr_pat <- "\\[(.*?)\\]"
matches <- regmatches(value, gregexpr(arr_pat, value))[[1L]]
# the previous returns the matches with the square brackets
matches <- sapply(matches, function(x) gsub(arr_pat, "\\1", x))
values <- unname(sapply(matches, parse_param, eltype_name))
} else {
parse_fn <- type_name_to_fn[[eltype_name]]
values <- parse_fn(strsplit(value, ",")[[1L]])
}
return(values)
}
parse_fn <- type_name_to_fn[[type_name]]
parse_fn(value)
}

res <- list()
for (param_name in names(params)) {
value <- parse_param(params[[param_name]], param_types[[param_name]])
if (param_name == "interaction_constraints") {
value <- lapply(value, function(x) x + 1L)
}
res[[param_name]] <- value
}

return(res)

},

inner_eval = function(data_name, data_idx, feval = NULL) {

# Check for unknown dataset (over the maximum provided range)
Expand Down
43 changes: 43 additions & 0 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,47 @@ SEXP LGBM_DumpParamAliases_R() {
R_API_END();
}

SEXP LGBM_BoosterGetLoadedParam_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP params_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), buf_len, &out_len, inner_char_buf.data()));
// if aliases string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), out_len, &out_len, inner_char_buf.data()));
}
params_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(params_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
UNPROTECT(2);
return params_str;
R_API_END();
}

SEXP LGBM_DumpParamTypes_R() {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
SEXP types_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_DumpParamTypes(buf_len, &out_len, inner_char_buf.data()));
// if aliases string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_DumpParamTypes(out_len, &out_len, inner_char_buf.data()));
}
types_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(types_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
UNPROTECT(2);
return types_str;
R_API_END();
}

// .Call() calls
static const R_CallMethodDef CallEntries[] = {
{"LGBM_HandleIsNull_R" , (DL_FUNC) &LGBM_HandleIsNull_R , 1},
Expand Down Expand Up @@ -1211,6 +1252,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2},
{"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2},
{"LGBM_BoosterGetNumFeature_R" , (DL_FUNC) &LGBM_BoosterGetNumFeature_R , 1},
{"LGBM_BoosterGetLoadedParam_R" , (DL_FUNC) &LGBM_BoosterGetLoadedParam_R , 1},
{"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1},
{"LGBM_BoosterUpdateOneIterCustom_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R , 4},
{"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1},
Expand Down Expand Up @@ -1238,6 +1280,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
{"LGBM_NullBoosterHandleError_R" , (DL_FUNC) &LGBM_NullBoosterHandleError_R , 0},
{"LGBM_DumpParamAliases_R" , (DL_FUNC) &LGBM_DumpParamAliases_R , 0},
{"LGBM_DumpParamTypes_R" , (DL_FUNC) &LGBM_DumpParamTypes_R , 0},
{NULL, NULL, 0}
};

Expand Down
15 changes: 15 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterLoadModelFromString_R(
SEXP model_str
);

/*!
* \brief Get parameters as JSON string.
* \param handle Booster handle
* \return R character vector (length=1) with parameters in JSON format
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLoadedParam_R(
SEXP handle
);

/*!
* \brief Merge model in two Boosters to first handle
* \param handle handle primary Booster handle, will merge other handle to this
Expand Down Expand Up @@ -838,4 +847,10 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R(
*/
LIGHTGBM_C_EXPORT SEXP LGBM_DumpParamAliases_R();

/*!
* \brief Dump parameter types to JSON
* \return R character vector (length=1) with types JSON
*/
LIGHTGBM_C_EXPORT SEXP LGBM_DumpParamTypes_R();

#endif // LIGHTGBM_R_H_
22 changes: 16 additions & 6 deletions R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,22 @@ test_that("Loading a Booster from a text file works", {
data(agaricus.test, package = "lightgbm")
train <- agaricus.train
test <- agaricus.test
params <- list(
num_leaves = 4L
, boosting = "rf"
, bagging_fraction = 0.8
, bagging_freq = 1L
, force_col_wise = TRUE
, categorical_feature = c(1L, 2L)
, interaction_constraints = list(c(1L, 2L), 1L)
, learning_rate = 1.0
, objective = "binary"
, verbosity = VERBOSITY
)
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, params = list(
num_leaves = 4L
, learning_rate = 1.0
, objective = "binary"
, verbose = VERBOSITY
)
, params = params
, nrounds = 2L
)
expect_true(lgb.is.Booster(bst))
Expand All @@ -199,6 +206,9 @@ test_that("Loading a Booster from a text file works", {
)
pred2 <- predict(bst2, test$data)
expect_identical(pred, pred2)

# check that the parameters are loaded correctly
expect_equal(bst2$params[names(params)], params)
})

test_that("boosters with linear models at leaves can be written to text file and re-loaded successfully", {
Expand Down
34 changes: 34 additions & 0 deletions helpers/parameter_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
along with parameters description in LightGBM/docs/Parameters.rst file
from the information in LightGBM/include/LightGBM/config.h file.
"""
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
Expand Down Expand Up @@ -373,6 +374,39 @@ def gen_parameter_code(
}

"""
str_to_write += """const std::string Config::ParameterTypes() {
std::stringstream str_buf;
str_buf << "{";"""
int_t_pat = re.compile(r'int\d+_t')
first = True
# the following are stored as comma separated strings but are arrays in the wrappers
overrides = {
'categorical_feature': 'vector<int>',
'ignore_column': 'vector<int>',
'interaction_constraints': 'vector<vector<int>>',
}
for x in infos:
for y in x:
name = y["name"][0]
if name == 'task':
continue
if name in overrides:
param_type = overrides[name]
else:
param_type = int_t_pat.sub('int', y["inner_type"][0]).replace('std::', '')
prefix = f'\n str_buf << "'
if first:
first = False
else:
prefix += ','
str_to_write += f'{prefix}\\"{name}\\": \\"{param_type}\\"";'
str_to_write += """
str_buf << "}";
return str_buf.str();
}

"""

str_to_write += "} // namespace LightGBM\n"
with open(config_out_cpp, "w") as config_out_cpp_file:
config_out_cpp_file.write(str_to_write)
Expand Down
2 changes: 2 additions & 0 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ class LIGHTGBM_EXPORT Boosting {
*/
static Boosting* CreateBoosting(const std::string& type, const char* filename);

virtual std::string GetLoadedParam() const = 0;

virtual bool IsLinear() const { return false; }

virtual std::string ParserConfigStr() const = 0;
Expand Down
25 changes: 25 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ LIGHTGBM_C_EXPORT int LGBM_DumpParamAliases(int64_t buffer_len,
int64_t* out_len,
char* out_str);

/*!
* \brief Dump all parameter names with their types to JSON.
* \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer
* \param[out] out_len Actual output length
* \param[out] out_str JSON format string of parameters, should pre-allocate memory
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DumpParamTypes(int64_t buffer_len,
int64_t* out_len,
char* out_str);

/*!
* \brief Register a callback function for log redirecting.
* \param callback The callback function to register
Expand Down Expand Up @@ -595,6 +606,20 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(const char* model_str,
int* out_num_iterations,
BoosterHandle* out);

/*!
* \brief Get parameters as JSON string.
* \param handle Handle of booster.
* \param buffer_len Allocated space for string.
* \param[out] out_len Actual size of string.
* \param[out] out_str JSON string containing parameters.
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLoadedParam(BoosterHandle handle,
int64_t buffer_len,
int64_t* out_len,
char* out_str);


/*!
* \brief Free space for booster.
* \param handle Handle of booster to be freed
Expand Down
1 change: 1 addition & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,7 @@ struct Config {
static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix;
std::vector<std::vector<int>> interaction_constraints_vector;
static const std::string ParameterTypes();
static const std::string DumpAliases();

private:
Expand Down
Loading