-
Notifications
You must be signed in to change notification settings - Fork 66
[ML] Return total SHAP per feature as a new result type #1387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
10e5a6c
df13ed0
874140c
6f40db9
38a180a
08bd4ec
13357af
3f7ec0a
0ba97a0
30c109b
f7689c3
d3758d4
8700f5f
6fe6399
f8126c8
1672019
3f8f6c2
1c3cfaf
1452f28
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| #include <maths/CBoostedTreeLoss.h> | ||
| #include <maths/CDataFramePredictiveModel.h> | ||
| #include <maths/CDataFrameUtils.h> | ||
| #include <maths/CLinearAlgebraEigen.h> | ||
| #include <maths/COrderings.h> | ||
| #include <maths/CTools.h> | ||
| #include <maths/CTreeShapFeatureImportance.h> | ||
|
|
@@ -27,6 +28,7 @@ | |
| #include <memory> | ||
| #include <numeric> | ||
| #include <set> | ||
| #include <unordered_map> | ||
|
|
||
| namespace ml { | ||
| namespace api { | ||
|
|
@@ -162,6 +164,9 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( | |
| } | ||
|
|
||
| if (featureImportance != nullptr) { | ||
| using TVector = maths::CDenseVector<double>; | ||
| using TTotalShapValues = std::unordered_map<std::size_t, TVector>; | ||
| TTotalShapValues totalShapValues; | ||
| int numberClasses{static_cast<int>(classValues.size())}; | ||
| featureImportance->shap( | ||
| row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices, | ||
|
|
@@ -182,14 +187,44 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( | |
| writer.Key(classValues[j]); | ||
| writer.Double(shap[i](j)); | ||
| } | ||
| writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME); | ||
| writer.Key(IMPORTANCE_FIELD_NAME); | ||
| writer.Double(shap[i].lpNorm<1>()); | ||
| } | ||
| writer.EndObject(); | ||
| } | ||
| } | ||
| writer.EndArray(); | ||
|
|
||
| for (std::size_t i = 0; i < shap.size(); ++i) { | ||
| if (shap[i].lpNorm<1>() != 0) { | ||
| if (totalShapValues.find(i) != totalShapValues.end()) { | ||
| totalShapValues[i] += shap[i].cwiseAbs(); | ||
| } else { | ||
| totalShapValues[i] = shap[i].cwiseAbs(); | ||
| } | ||
tveasey marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
| }); | ||
| writer.Key(TOTAL_FEATURE_IMPORTANCE_FIELD_NAME); | ||
| writer.StartArray(); | ||
| for (const auto& item : totalShapValues) { | ||
| writer.StartObject(); | ||
| writer.Key(FEATURE_NAME_FIELD_NAME); | ||
| writer.String(featureImportance->columnNames()[item.first]); | ||
| if (item.second.size() == 1) { | ||
| writer.Key(IMPORTANCE_FIELD_NAME); | ||
| writer.Double(item.second(0)); | ||
tveasey marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } else { | ||
| for (int j = 0; j < item.second.size() && j < numberClasses; ++j) { | ||
| writer.Key(classValues[j]); | ||
|
||
| writer.Double(item.second(j)); | ||
| } | ||
| writer.Key(IMPORTANCE_FIELD_NAME); | ||
| writer.Double(item.second.lpNorm<1>()); | ||
| } | ||
| writer.EndObject(); | ||
tveasey marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| writer.EndArray(); | ||
| } | ||
| writer.EndObject(); | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.