Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package][R-package] load parameters from model file (fixes #2613) #5424

Merged
merged 28 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7399fe1
initial work to retrieve parameters from loaded booster
jmoralez Aug 12, 2022
02ca63a
get parameter types and use to parse
jmoralez Aug 15, 2022
c81f768
add test
jmoralez Aug 16, 2022
2ea2c29
merge master
jmoralez Aug 16, 2022
b33d6a0
True for boolean field if it's equal to '1'
jmoralez Aug 16, 2022
c7a6a22
remove bound on cache
jmoralez Aug 16, 2022
f43934e
remove duplicated code
jmoralez Aug 17, 2022
edf11fc
merge remote
jmoralez Aug 17, 2022
7761124
manually parse json string
jmoralez Aug 17, 2022
26ba91f
dont create temporary map. lint
jmoralez Aug 17, 2022
ec113c0
add doc
jmoralez Aug 17, 2022
39c7a8c
minor fixes
jmoralez Aug 27, 2022
0e6591b
revert _get_string_from_c_api. rename parameter to param
jmoralez Aug 28, 2022
d4e781b
add R-package functions
jmoralez Aug 28, 2022
c574a4a
merge master
jmoralez Aug 28, 2022
483a3f4
rename functions to BoosterGetLoadedParam. override array parameters.…
jmoralez Aug 29, 2022
4ab5dd4
add missing types to tests
jmoralez Aug 29, 2022
bd4eec0
fix R params
jmoralez Aug 30, 2022
9a00fde
assert equal dicts
jmoralez Aug 30, 2022
de6ef8a
use boost_from_average as boolean param
jmoralez Aug 30, 2022
2cec692
set boost_from_average to false
jmoralez Aug 30, 2022
f066dba
simplify R's parse_param
jmoralez Aug 30, 2022
db36cb9
parse types on cpp side
jmoralez Aug 31, 2022
9467814
warn about ignoring parameters passed to constructor
jmoralez Sep 21, 2022
339bb1c
Merge branch 'master' into retrieve-params
jmoralez Sep 21, 2022
4cbf477
trigger ci
jmoralez Sep 24, 2022
6667771
Merge branch 'master' into retrieve-params
jmoralez Oct 10, 2022
17ad0c1
trigger ci
jmoralez Oct 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions helpers/parameter_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
along with parameters description in LightGBM/docs/Parameters.rst file
from the information in LightGBM/include/LightGBM/config.h file.
"""
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
Expand Down Expand Up @@ -373,6 +374,30 @@ def gen_parameter_code(
}

"""
str_to_write += """const std::string Config::ParameterTypes() {
std::stringstream str_buf;
str_buf << "{";"""
int_t_pat = re.compile(r'int\d+_t')
first = True
for x in infos:
for y in x:
if "[doc-only]" in y:
continue
param_type = int_t_pat.sub('int', y["inner_type"][0]).replace('std::', '')
name = y["name"][0]
prefix = f'\n str_buf << "'
if first:
first = False
else:
prefix += ','
str_to_write += f'{prefix}\\"{name}\\": \\"{param_type}\\"";'
str_to_write += """
str_buf << "}";
return str_buf.str();
}

"""

str_to_write += "} // namespace LightGBM\n"
with open(config_out_cpp, "w") as config_out_cpp_file:
config_out_cpp_file.write(str_to_write)
Expand Down
2 changes: 2 additions & 0 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ class LIGHTGBM_EXPORT Boosting {
*/
static Boosting* CreateBoosting(const std::string& type, const char* filename);

virtual std::string GetParameters() const = 0;

virtual bool IsLinear() const { return false; }

virtual std::string ParserConfigStr() const = 0;
Expand Down
25 changes: 25 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ LIGHTGBM_C_EXPORT int LGBM_DumpParamAliases(int64_t buffer_len,
int64_t* out_len,
char* out_str);

/*!
* \brief Dump all parameter names with their types to JSON.
* \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer
* \param[out] out_len Actual output length
* \param[out] out_str JSON format string of parameters, should pre-allocate memory
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DumpParameterTypes(int64_t buffer_len,
int64_t* out_len,
char* out_str);

/*!
* \brief Register a callback function for log redirecting.
* \param callback The callback function to register
Expand Down Expand Up @@ -595,6 +606,20 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(const char* model_str,
int* out_num_iterations,
BoosterHandle* out);

/*!
* \brief Get parameters as JSON string.
* \param handle Handle of booster.
* \param buffer_len Allocated space for string.
* \param[out] out_len Actual size of string.
* \param[out] out_str JSON string containing parameters.
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetParameters(BoosterHandle handle,
int64_t buffer_len,
int64_t* out_len,
char* out_str);


/*!
* \brief Free space for booster.
* \param handle Handle of booster to be freed
Expand Down
1 change: 1 addition & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,7 @@ struct Config {
static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix;
std::vector<std::vector<int>> interaction_constraints_vector;
static const std::string ParameterTypes();
static const std::string DumpAliases();

private:
Expand Down
72 changes: 53 additions & 19 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import OrderedDict
from copy import deepcopy
from enum import Enum
from functools import wraps
from functools import lru_cache, wraps
from os import SEEK_END, environ
from os.path import getsize
from pathlib import Path
Expand Down Expand Up @@ -156,6 +156,28 @@ def _safe_call(ret: int) -> None:
raise LightGBMError(_LIB.LGBM_GetLastError().decode('utf-8'))


def _get_string_from_c_api(func: Callable, booster_handle: Optional[ctypes.c_void_p] = None) -> str:
def c_api_call(buffer_len: int, out_len: ctypes.c_int64):
string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
args = (ctypes.c_int64(buffer_len), ctypes.byref(out_len), ptr_string_buffer)
if booster_handle is None:
f = func(*args)
else:
f = func(booster_handle, *args)
_safe_call(f)
return ptr_string_buffer.value.decode('utf-8')

buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0)
res = c_api_call(buffer_len, tmp_out_len)
actual_len = tmp_out_len.value
# if buffer length is not long enough, re-allocate a buffer
if actual_len > buffer_len:
res = c_api_call(actual_len, tmp_out_len)
return res


def _is_numeric(obj: Any) -> bool:
"""Check whether object is a number or not, include numpy number, etc."""
try:
Expand Down Expand Up @@ -357,25 +379,9 @@ class _ConfigAliases:

@staticmethod
def _get_all_param_aliases() -> Dict[str, List[str]]:
buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_DumpParamAliases(
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
actual_len = tmp_out_len.value
# if buffer length is not long enough, re-allocate a buffer
if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_DumpParamAliases(
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
aliases_str = _get_string_from_c_api(_LIB.LGBM_DumpParamAliases)
aliases = json.loads(
string_buffer.value.decode('utf-8'),
aliases_str,
object_hook=lambda obj: {k: [k] + v for k, v in obj.items()}
)
return aliases
Expand Down Expand Up @@ -455,6 +461,14 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va
return params


@lru_cache(maxsize=None)
def _get_parameter_types() -> Dict[str, str]:
types_str = _get_string_from_c_api(_LIB.LGBM_DumpParameterTypes)
res = json.loads(types_str)
res['categorical_feature'] = 'vector<int>'
return res


MAX_INT32 = (1 << 31) - 1

"""Macro definition of data type in C API of LightGBM"""
Expand Down Expand Up @@ -2738,6 +2752,8 @@ def __init__(
else:
raise TypeError('Need at least one training dataset or model file or model string '
'to create Booster instance')
if model_file is not None or model_str is not None:
params = self._get_parameters()
self.params = params

def __del__(self) -> None:
Expand Down Expand Up @@ -2781,6 +2797,24 @@ def __setstate__(self, state):
state['handle'] = handle
self.__dict__.update(state)

def _get_parameters(self) -> Dict[str, Any]:
params_str = _get_string_from_c_api(_LIB.LGBM_BoosterGetParameters, self.handle)
params = json.loads(params_str)
ptypes = _get_parameter_types()
types_dict = {'string': str, 'int': int, 'double': float, 'bool': lambda x: x == '1'}

def parse_param(value: str, type_name: str) -> Union[Any, List[Any]]:
if 'vector' in type_name:
if not value:
return []
eltype_name = type_name[type_name.find('<') + 1:type_name.find('>')]
eltype = types_dict[eltype_name]
return [eltype(v) for v in value.split(',')]
eltype = types_dict[type_name]
return eltype(value)

return {param: parse_param(value, ptypes.get(param, 'string')) for param, value in params.items()}

def free_dataset(self) -> "Booster":
"""Free Booster's Datasets.

Expand Down
27 changes: 27 additions & 0 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,33 @@ class GBDT : public GBDTBase {
*/
int GetCurrentIteration() const override { return static_cast<int>(models_.size()) / num_tree_per_iteration_; }

/*!
* \brief Get parameters as a JSON string
*/
std::string GetParameters() const override {
if (loaded_parameter_.empty()) {
return std::string("{}");
}
std::stringstream str_buf;
str_buf << "{";
const auto lines = Common::Split(loaded_parameter_.c_str(), "\n");
bool first = true;
for (const auto& line : lines) {
const auto pair = Common::Split(line.c_str(), "[:]");
if (pair[1] != " ") {
if (first) {
first = false;
str_buf << "\"";
} else {
str_buf << ",\"";
}
str_buf << pair[0] << "\": \"" << Common::Trim(pair[1]) << "\"";
}
}
str_buf << "}";
return str_buf.str();
}

/*!
* \brief Can use early stopping for prediction or not
* \return True if cannot use early stopping for prediction
Expand Down
27 changes: 27 additions & 0 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,18 @@ int LGBM_DumpParamAliases(int64_t buffer_len,
API_END();
}

int LGBM_DumpParameterTypes(int64_t buffer_len,
int64_t* out_len,
char* out_str) {
API_BEGIN();
std::string ptypes = Config::ParameterTypes();
*out_len = static_cast<int64_t>(ptypes.size()) + 1;
if (*out_len <= buffer_len) {
std::memcpy(out_str, ptypes.c_str(), *out_len);
}
API_END();
}

int LGBM_RegisterLogCallback(void (*callback)(const char*)) {
API_BEGIN();
Log::ResetCallBack(callback);
Expand Down Expand Up @@ -1748,6 +1760,21 @@ int LGBM_BoosterLoadModelFromString(
API_END();
}

int LGBM_BoosterGetParameters(
BoosterHandle handle,
int64_t buffer_len,
int64_t* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string params = ref_booster->GetBoosting()->GetParameters();
*out_len = static_cast<int64_t>(params.size()) + 1;
if (*out_len <= buffer_len) {
std::memcpy(out_str, params.c_str(), *out_len);
}
API_END();
}

#ifdef _MSC_VER
#pragma warning(disable : 4702)
#endif
Expand Down
Loading