From b26e983e7a72d727b1a0b3b0621a580e3d13bf7c Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 7 Jul 2018 17:28:02 +0800 Subject: [PATCH] fix load parameter --- src/boosting/gbdt.h | 2 ++ src/boosting/gbdt_model_text.cpp | 29 +++++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index c96d10f93ae5..d142766f20d6 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -456,8 +456,10 @@ class GBDT : public GBDTBase { std::unique_ptr loaded_objective_; bool average_output_; bool need_re_bagging_; + std::string loaded_parameter_; Json forced_splits_json_; + }; } // namespace LightGBM diff --git a/src/boosting/gbdt_model_text.cpp b/src/boosting/gbdt_model_text.cpp index c57c657159c0..cbf9cef6e305 100644 --- a/src/boosting/gbdt_model_text.cpp +++ b/src/boosting/gbdt_model_text.cpp @@ -280,6 +280,7 @@ std::string GBDT::SaveModelToString(int num_iteration) const { ss << tree_strs[i]; tree_strs[i].clear(); } + ss << "end of trees" << "\n"; std::vector feature_importances = FeatureImportance(num_iteration, 0); // store the importance first @@ -301,8 +302,13 @@ std::string GBDT::SaveModelToString(int num_iteration) const { ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n'; } if (config_ != nullptr) { - ss << "parameters:" << '\n'; + ss << "\nparameters:" << '\n'; ss << config_->ToString() << "\n"; + ss << "end of parameters" << '\n'; + } else if (!loaded_parameter_.empty()) { + ss << "\nparameters:" << '\n'; + ss << loaded_parameter_ << "\n"; + ss << "end of parameters" << '\n'; } return ss.str(); } @@ -465,7 +471,26 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { num_iteration_for_pred_ = static_cast(models_.size()) / num_tree_per_iteration_; num_init_iteration_ = num_iteration_for_pred_; iter_ = 0; - + bool is_inparameter = false; + std::stringstream ss; + while (p < end) { + auto line_len = Common::GetLine(p); + std::string cur_line(p, line_len); + if (line_len > 0) { + if (cur_line == std::string("parameters:")) { + is_inparameter = true; + } else if (cur_line == std::string("end of parameters")) { + break; + } else if (is_inparameter) { + ss << cur_line << "\n"; + } + } + p += line_len; + p = Common::SkipNewLine(p); + } + if (!ss.str().empty()) { + loaded_parameter_ = ss.str(); + } return true; }