Skip to content

Commit

Permalink
made test_booster_load_params_when_passed_model_str pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhuyi Xue committed Nov 14, 2021
1 parent 874e635 commit ef774a3
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 2 deletions.
2 changes: 2 additions & 0 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ class LIGHTGBM_EXPORT Boosting {
static Boosting* CreateBoosting(const std::string& type, const char* filename);

// Whether this boosting object uses linear models; default implementation
// returns false (virtual, so derived boostings may override).
virtual bool IsLinear() const { return false; }

// Raw parameter text captured when a model is loaded; read by
// LGBM_BoosterGetConfig in c_api.cpp.
// NOTE(review): public mutable data member on the abstract interface —
// consider exposing it through a const accessor instead.
std::string loaded_parameter_;
};

class GBDTBase : public Boosting {
Expand Down
5 changes: 5 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,11 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(const char* model_str,
int* out_num_iterations,
BoosterHandle* out);

/*!
 * \brief Get the parameter string stored in a booster
 *        (the ``loaded_parameter_`` text of a loaded model).
 * \param handle Handle of booster
 * \param buffer_len Size of the allocated ``out_str`` buffer, in bytes
 * \param[out] out_len Required size in bytes, including the terminating null;
 *             if this exceeds ``buffer_len``, ``out_str`` is left unfilled and
 *             the call should be retried with a larger buffer
 * \param[out] out_str Buffer that receives the null-terminated parameter string
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetConfig(BoosterHandle handle,
int64_t buffer_len,
int64_t* out_len,
char* out_str);

/*!
* \brief Free space for booster.
* \param handle Handle of booster to be freed
Expand Down
21 changes: 20 additions & 1 deletion python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Wrapper for C API of LightGBM."""
import abc
import ctypes
import io
import json
import warnings
from collections import OrderedDict
Expand Down Expand Up @@ -2664,6 +2665,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical
self.train_set_version = train_set.version
self.params = params
elif model_file is not None:
# Prediction task
out_num_iterations = ctypes.c_int(0)
Expand All @@ -2678,12 +2680,13 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
self.params = params
elif model_str is not None:
self.params = {}
self.model_from_string(model_str, verbose="_silent_false")
else:
raise TypeError('Need at least one training dataset or model file or model string '
'to create Booster instance')
self.params = params

def __del__(self):
try:
Expand Down Expand Up @@ -3384,6 +3387,22 @@ def model_from_string(self, model_str, verbose='warn'):
c_str(model_str),
ctypes.byref(out_num_iterations),
ctypes.byref(self.handle)))

buffer_len = 2 << 20
tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterGetConfig(
self.handle,
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
_params_str = string_buffer.value.decode('utf-8')
# print(f'{_params_str=:}')
for line in io.StringIO(_params_str):
if line.startswith('[boosting: '):
self.params['boosting'] = line.strip().replace(f"[boosting: ", "").replace("]", "")

out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses(
self.handle,
Expand Down
1 change: 0 additions & 1 deletion src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,6 @@ class GBDT : public GBDTBase {
bool average_output_;
bool need_re_bagging_;
bool balanced_bagging_;
std::string loaded_parameter_;
std::vector<int8_t> monotone_constraints_;
const int bagging_rand_block_ = 1024;
std::vector<Random> bagging_rands_;
Expand Down
18 changes: 18 additions & 0 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,24 @@ int LGBM_BoosterLoadModelFromString(
API_END();
}

/*!
 * \brief Copy the parameter string stored in the booster into ``out_str``.
 *
 * Follows the usual two-call C-API convention: ``*out_len`` is always set to
 * the required size (string length plus terminating null), and the buffer is
 * written only when ``buffer_len`` is large enough, so a caller can probe the
 * size first and retry with an adequately sized buffer.
 */
int LGBM_BoosterGetConfig(
  BoosterHandle handle,
  int64_t buffer_len,
  int64_t* out_len,
  char* out_str
) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  // Borrow the stored parameter text instead of copying the whole string.
  const std::string& params = ref_booster->GetBoosting()->loaded_parameter_;
  *out_len = static_cast<int64_t>(params.size()) + 1;  // +1 for the '\0'
  if (*out_len <= buffer_len) {
    std::memcpy(out_str, params.c_str(), *out_len);
  }
  API_END();
}

#ifdef _MSC_VER
#pragma warning(disable : 4702)
#endif
Expand Down
66 changes: 66 additions & 0 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# coding: utf-8
import filecmp
import numbers
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
import pytest
from scipy import sparse
from sklearn.datasets import dump_svmlight_file, load_svmlight_file
Expand Down Expand Up @@ -557,3 +559,67 @@ def test_init_score_for_multiclass_classification(init_score_type):
ds = lgb.Dataset(data, init_score=init_score).construct()
np.testing.assert_equal(ds.get_field('init_score'), init_score)
np.testing.assert_equal(ds.init_score, init_score)


@pytest.fixture(name="fake_model")
def _fake_model() -> lgb.Booster:
    """Train a small 3-round binary-classification booster on the bundled
    example dataset and return it for model-serialization round-trip tests.

    TODO: remove the dependency on the examples data directory later.
    """
    data_dir = Path(__file__).parent.parent.parent / "examples/binary_classification"

    df_train = pd.read_csv(data_dir / "binary.train", header=None, sep="\t")
    weights = pd.read_csv(data_dir / "binary.train.weight", header=None)[0]

    # Column 0 is the label; the remaining columns are features.
    X_train = df_train.drop(0, axis=1)
    y_train = df_train[0]

    params = {
        "boosting_type": "gbdt",
        "objective": "binary",
        "metric": "binary_logloss",
        "num_leaves": 31,
        "learning_rate": 0.05,
        "feature_fraction": 0.9,
        "bagging_fraction": 0.8,
        "bagging_freq": 5,
        "verbose": 0,
    }
    lgb_train = lgb.Dataset(X_train, y_train, weight=weights, free_raw_data=False)
    feature_name = [f"feature_{col}" for col in X_train.columns]

    gbm = lgb.train(
        params,
        lgb_train,
        num_boost_round=3,
        valid_sets=lgb_train,  # eval training data
        feature_name=feature_name,
        categorical_feature=[21],  # treat feature index 21 as categorical
    )

    return gbm

def test_booster_load_params_when_passed_model_file(fake_model: lgb.Booster) -> None:
    """A booster loaded from a saved model file should expose training params."""
    with tempfile.TemporaryDirectory() as temp_dir:
        model_file = Path(temp_dir) / "model.txt"
        fake_model.save_model(model_file)

        loaded = lgb.Booster(model_file=model_file)

        # TODO: needs to parse more params
        assert 'boosting' in loaded.params

def test_booster_load_params_when_passed_model_str(fake_model: lgb.Booster) -> None:
    """A booster rebuilt from a serialized model string should expose its boosting type."""
    serialized = fake_model.model_to_string()

    restored = lgb.Booster(model_str=serialized)

    # TODO: needs to parse more params
    assert 'boosting' in restored.params
    assert restored.params['boosting'] == 'gbdt'

0 comments on commit ef774a3

Please sign in to comment.