diff --git a/src/boosting/gbdt_model_text.cpp b/src/boosting/gbdt_model_text.cpp index 1a7a6352d3e8..ca608de5fd36 100644 --- a/src/boosting/gbdt_model_text.cpp +++ b/src/boosting/gbdt_model_text.cpp @@ -27,11 +27,12 @@ std::string GBDT::DumpModel(int start_iteration, int num_iteration) const { str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << '\n'; str_buf << "\"label_index\":" << label_idx_ << "," << '\n'; str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << '\n'; - str_buf << "\"average_output\":" << (average_output_ ? "true" : "false") << ",\n"; if (objective_function_ != nullptr) { str_buf << "\"objective\":\"" << objective_function_->ToString() << "\",\n"; } + str_buf << "\"average_output\":" << (average_output_ ? "true" : "false") << ",\n"; + str_buf << "\"feature_names\":[\"" << Common::Join(feature_names_, "\",\"") << "\"]," << '\n'; @@ -57,7 +58,26 @@ std::string GBDT::DumpModel(int start_iteration, int num_iteration) const { str_buf << models_[i]->ToJSON(); str_buf << "}"; } - str_buf << "]" << '\n'; + str_buf << "]," << '\n'; + + std::vector feature_importances = FeatureImportance(num_iteration, 0); + // store the importance first + std::vector> pairs; + for (size_t i = 0; i < feature_importances.size(); ++i) { + size_t feature_importances_int = static_cast(feature_importances[i]); + if (feature_importances_int > 0) { + pairs.emplace_back(feature_importances_int, feature_names_[i]); + } + } + str_buf << '\n' << "\"feature_importances\":" << "{"; + if (!pairs.empty()) { + str_buf << "\"" << pairs[0].second << "\":" << std::to_string(pairs[0].first); + for (size_t i = 1; i < pairs.size(); ++i) { + str_buf << ","; + str_buf << "\"" << pairs[i].second << "\":" << std::to_string(pairs[i].first); + } + } + str_buf << "}" << '\n'; str_buf << "}" << '\n'; @@ -325,7 +345,7 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) cons const std::pair& rhs) { return lhs.first > rhs.first; }); - ss << '\n' << "feature importances:" << '\n'; + ss << '\n' << "feature_importances:" << '\n'; for (size_t i = 0; i < pairs.size(); ++i) { ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n'; }