-
Notifications
You must be signed in to change notification settings - Fork 66
[ML] Return total SHAP per feature as a new result type #1387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
10e5a6c
df13ed0
874140c
6f40db9
38a180a
08bd4ec
13357af
3f7ec0a
0ba97a0
30c109b
f7689c3
d3758d4
8700f5f
6fe6399
f8126c8
1672019
3f8f6c2
1c3cfaf
1452f28
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License; | ||
| * you may not use this file except in compliance with the Elastic License. | ||
| */ | ||
| #ifndef INCLUDED_ml_api_CInferenceModelMetadata_h | ||
| #define INCLUDED_ml_api_CInferenceModelMetadata_h | ||
|
|
||
| #include <maths/CBasicStatistics.h> | ||
| #include <maths/CLinearAlgebraEigen.h> | ||
|
|
||
| #include <api/CInferenceModelDefinition.h> | ||
| #include <api/ImportExport.h> | ||
|
|
||
| #include <cstddef> | ||
| #include <string> | ||
| #include <unordered_map> | ||
| #include <vector> | ||
|
|
||
| namespace ml { | ||
| namespace api { | ||
|
|
||
| //! \brief Metadata about an inference model that can be added to a JSON document. | ||
| //! | ||
| //! Holds a map from a std::size_t index to a mean accumulator over dense | ||
| //! vectors (m_TotalShapValues) together with a list of column names. | ||
| //! | ||
| //! NOTE(review): the member functions are only declared here; the behaviour | ||
| //! notes on them are inferred from the declarations and should be confirmed | ||
| //! against the implementation file. | ||
| class API_EXPORT CInferenceModelMetadata : public CSerializableToJsonDocument { | ||
| public: | ||
| //! Dense vector of doubles; the value type passed to addToFeatureImportance. | ||
| using TVector = maths::CDenseVector<double>; | ||
|
|
||
| public: | ||
| //! Default construction leaves the accumulator map and column names empty. | ||
| CInferenceModelMetadata() = default; | ||
| //! Overrides the base class hook which adds this object to \p parentObject | ||
| //! using \p writer. | ||
| void addToJsonDocument(rapidjson::Value& parentObject, TRapidJsonWriter& writer) const override; | ||
| //! Supply the column names. | ||
| //! NOTE(review): presumably stored in m_ColumnNames and looked up by the | ||
| //! index passed to addToFeatureImportance - confirm. | ||
| void columnNames(const std::vector<std::string>& columnNames); | ||
| //! Get a string naming this type; the value is defined in the implementation. | ||
| const std::string& typeString() const; | ||
| //! Add \p values for index \p i. | ||
| //! NOTE(review): presumably updates the accumulator for \p i in | ||
| //! m_TotalShapValues - confirm. | ||
| void addToFeatureImportance(std::size_t i, const TVector& values); | ||
|
|
||
| private: | ||
| using TMeanAccumulator = maths::CBasicStatistics::SSampleMean<TVector>::TAccumulator; | ||
| using TTotalShapValues = std::unordered_map<std::size_t, TMeanAccumulator>; | ||
|
|
||
| private: | ||
| //! Mean accumulators of vector values keyed by index. | ||
| TTotalShapValues m_TotalShapValues; | ||
| //! Column names. | ||
| std::vector<std::string> m_ColumnNames; | ||
| }; | ||
| } | ||
| } | ||
|
|
||
| #endif //INCLUDED_ml_api_CInferenceModelMetadata_h |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| #include <maths/CBoostedTreeLoss.h> | ||
| #include <maths/CDataFramePredictiveModel.h> | ||
| #include <maths/CDataFrameUtils.h> | ||
| #include <maths/CLinearAlgebraEigen.h> | ||
| #include <maths/COrderings.h> | ||
| #include <maths/CTools.h> | ||
| #include <maths/CTreeShapFeatureImportance.h> | ||
|
|
@@ -27,6 +28,7 @@ | |
| #include <memory> | ||
| #include <numeric> | ||
| #include <set> | ||
| #include <unordered_map> | ||
|
|
||
| namespace ml { | ||
| namespace api { | ||
|
|
@@ -162,6 +164,9 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( | |
| } | ||
|
|
||
| if (featureImportance != nullptr) { | ||
| using TVector = maths::CDenseVector<double>; | ||
| using TTotalShapValues = std::unordered_map<std::size_t, TVector>; | ||
| TTotalShapValues totalShapValues; | ||
| int numberClasses{static_cast<int>(classValues.size())}; | ||
| featureImportance->shap( | ||
| row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices, | ||
|
|
@@ -182,14 +187,42 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( | |
| writer.Key(classValues[j]); | ||
| writer.Double(shap[i](j)); | ||
| } | ||
| writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME); | ||
| writer.Key(IMPORTANCE_FIELD_NAME); | ||
| writer.Double(shap[i].lpNorm<1>()); | ||
| } | ||
| writer.EndObject(); | ||
| } | ||
| } | ||
| writer.EndArray(); | ||
|
|
||
| for (std::size_t i = 0; i < shap.size(); ++i) { | ||
| if (shap[i].lpNorm<1>() != 0) { | ||
| totalShapValues | ||
| .emplace(std::make_pair(i, TVector::Zero(shap[i].size()))) | ||
| .first->second += shap[i].cwiseAbs(); | ||
| } | ||
| } | ||
| }); | ||
| writer.Key(TOTAL_FEATURE_IMPORTANCE_FIELD_NAME); | ||
| writer.StartArray(); | ||
| for (const auto& item : totalShapValues) { | ||
| writer.StartObject(); | ||
| writer.Key(FEATURE_NAME_FIELD_NAME); | ||
| writer.String(featureImportance->columnNames()[item.first]); | ||
| if (item.second.size() == 1) { | ||
| writer.Key(IMPORTANCE_FIELD_NAME); | ||
| writer.Double(item.second(0)); | ||
tveasey marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } else { | ||
| for (int j = 0; j < item.second.size() && j < numberClasses; ++j) { | ||
| writer.Key(classValues[j]); | ||
|
||
| writer.Double(item.second(j)); | ||
| } | ||
| writer.Key(IMPORTANCE_FIELD_NAME); | ||
| writer.Double(item.second.lpNorm<1>()); | ||
| } | ||
| writer.EndObject(); | ||
tveasey marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| writer.EndArray(); | ||
| } | ||
| writer.EndObject(); | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.