Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
* Improve runtime and memory usage training deep trees for classification and
regression. (See {ml-pull}1340[#1340].)
* Improvement in handling large inference model definitions. (See {ml-pull}1349[#1349].)
* Calculate total feature importance as a new result type. (See {ml-pull}1387[#1387].)

=== Bug Fixes

Expand Down
1 change: 1 addition & 0 deletions include/api/CDataFrameTrainBoostedTreeRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeRunner : public CDataFrameAnalysisRun
static const std::string FEATURE_NAME_FIELD_NAME;
static const std::string IMPORTANCE_FIELD_NAME;
static const std::string FEATURE_IMPORTANCE_FIELD_NAME;
static const std::string TOTAL_FEATURE_IMPORTANCE_FIELD_NAME;

public:
~CDataFrameTrainBoostedTreeRunner() override;
Expand Down
2 changes: 2 additions & 0 deletions include/maths/CTreeShapFeatureImportance.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
//! Get the maximum depth of any tree in \p forest.
static std::size_t depth(const TTreeVec& forest);

const TStrVec& columnNames() const;

private:
//! Collects the elements of the path through decision tree that are updated together
struct SPathElement {
Expand Down
37 changes: 36 additions & 1 deletion lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <maths/CBoostedTreeLoss.h>
#include <maths/CDataFramePredictiveModel.h>
#include <maths/CDataFrameUtils.h>
#include <maths/CLinearAlgebraEigen.h>
#include <maths/COrderings.h>
#include <maths/CTools.h>
#include <maths/CTreeShapFeatureImportance.h>
Expand All @@ -27,6 +28,7 @@
#include <memory>
#include <numeric>
#include <set>
#include <unordered_map>

namespace ml {
namespace api {
Expand Down Expand Up @@ -162,6 +164,9 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
}

if (featureImportance != nullptr) {
using TVector = maths::CDenseVector<double>;
using TTotalShapValues = std::unordered_map<std::size_t, TVector>;
TTotalShapValues totalShapValues;
int numberClasses{static_cast<int>(classValues.size())};
featureImportance->shap(
row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
Expand All @@ -182,14 +187,44 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
writer.Key(classValues[j]);
writer.Double(shap[i](j));
}
writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME);
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(shap[i].lpNorm<1>());
}
writer.EndObject();
}
}
writer.EndArray();

for (std::size_t i = 0; i < shap.size(); ++i) {
if (shap[i].lpNorm<1>() != 0) {
if (totalShapValues.find(i) != totalShapValues.end()) {
totalShapValues[i] += shap[i].cwiseAbs();
} else {
totalShapValues[i] = shap[i].cwiseAbs();
}
}
}
});
writer.Key(TOTAL_FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
for (const auto& item : totalShapValues) {
writer.StartObject();
writer.Key(FEATURE_NAME_FIELD_NAME);
writer.String(featureImportance->columnNames()[item.first]);
if (item.second.size() == 1) {
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(item.second(0));
} else {
for (int j = 0; j < item.second.size() && j < numberClasses; ++j) {
writer.Key(classValues[j]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will not work for storage in ES. This index stores the information for all trained models and indexing the class names for the feature importances will not scale.

I propose this format:

{
	"feature_name": "c4",
	"importance": 0.4810469375580312,
	"class_importance": [
		{
			"class_name": "foo",
			"importance": 0.24052346877901588
		},
		{
			"class_name": "baz",
			"importance": 0.19615020783390645
		},
		{
			"class_name": "bar",
			"importance": 0.04437326094510882
		}
	]
}

class_importance will be a nested data type that allows aggregations and searches for specific models and classnames.

writer.Double(item.second(j));
}
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(item.second.lpNorm<1>());
}
writer.EndObject();
}
writer.EndArray();
}
writer.EndObject();
}
Expand Down
33 changes: 30 additions & 3 deletions lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <maths/CBoostedTreeFactory.h>
#include <maths/CBoostedTreeLoss.h>
#include <maths/CDataFrameUtils.h>
#include <maths/CLinearAlgebraEigen.h>
#include <maths/CTreeShapFeatureImportance.h>

#include <api/CBoostedTreeInferenceModelBuilder.h>
Expand All @@ -24,6 +25,7 @@
#include <memory>
#include <set>
#include <string>
#include <unordered_map>

namespace ml {
namespace api {
Expand Down Expand Up @@ -109,10 +111,14 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow(
writer.Bool(maths::CDataFrameUtils::isMissing(row[columnHoldingDependentVariable]) == false);
auto featureImportance = tree.shap();
if (featureImportance != nullptr) {
using TVector = maths::CDenseVector<double>;
using TTotalShapValues = std::unordered_map<std::size_t, TVector>;
TTotalShapValues totalShapValues;
featureImportance->shap(
row, [&writer](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
row, [&writer, &totalShapValues](
const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
for (auto i : indices) {
Expand All @@ -126,7 +132,28 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow(
}
}
writer.EndArray();

for (int i = 0; i < shap.size(); ++i) {
if (shap[i].lpNorm<1>() != 0) {
if (totalShapValues.find(i) != totalShapValues.end()) {
totalShapValues[i] += shap[i].cwiseAbs();
} else {
totalShapValues[i] = shap[i].cwiseAbs();
}
}
}
});
writer.Key(TOTAL_FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
for (const auto& item : totalShapValues) {
writer.StartObject();
writer.Key(FEATURE_NAME_FIELD_NAME);
writer.String(featureImportance->columnNames()[item.first]);
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(item.second[0]);
writer.EndObject();
}
writer.EndArray();
}
writer.EndObject();
}
Expand Down
1 change: 1 addition & 0 deletions lib/api/CDataFrameTrainBoostedTreeRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ const std::string CDataFrameTrainBoostedTreeRunner::IS_TRAINING_FIELD_NAME{"is_t
const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME{"feature_name"};
const std::string CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME{"importance"};
const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME{"feature_importance"};
const std::string CDataFrameTrainBoostedTreeRunner::TOTAL_FEATURE_IMPORTANCE_FIELD_NAME{"total_feature_importance"};
// clang-format on
}
}
18 changes: 17 additions & 1 deletion lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) {

TMeanVarAccumulator bias;
double c1Sum{0.0}, c2Sum{0.0}, c3Sum{0.0}, c4Sum{0.0};
bool hasTotalFeatureImportance{false};
for (const auto& result : results.GetArray()) {
if (result.HasMember("row_results")) {
double c1{readShapValue(result, "c1")};
Expand All @@ -456,6 +457,9 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) {
c4Sum += std::fabs(c4);
// assert that no SHAP value for the dependent variable is returned
BOOST_REQUIRE_EQUAL(readShapValue(result, "target"), 0.0);
if (result["row_results"]["results"]["ml"].HasMember("total_feature_importance")) {
hasTotalFeatureImportance = true;
}
}
}

Expand All @@ -471,6 +475,7 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) {
BOOST_REQUIRE_CLOSE(c3Sum, c4Sum, 5.0); // c3 and c4 within 5% of each other
// make sure the local approximation differs from the prediction always by the same bias (up to a numeric error)
BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6);
BOOST_TEST_REQUIRE(hasTotalFeatureImportance);
}

BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoImportance, SFixture) {
Expand Down Expand Up @@ -510,6 +515,7 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) {
auto results{runBinaryClassification(topShapValues, {0.5, -0.7, 0.2, -0.2})};

double c1Sum{0.0}, c2Sum{0.0}, c3Sum{0.0}, c4Sum{0.0};
bool hasTotalFeatureImportance{false};
for (const auto& result : results.GetArray()) {
if (result.HasMember("row_results")) {
double c1{readShapValue(result, "c1")};
Expand All @@ -536,6 +542,10 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) {
c2Sum += std::fabs(c2);
c3Sum += std::fabs(c3);
c4Sum += std::fabs(c4);

if (result["row_results"]["results"]["ml"].HasMember("total_feature_importance")) {
hasTotalFeatureImportance = true;
}
}
}

Expand All @@ -548,13 +558,14 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) {
BOOST_REQUIRE_CLOSE(c3Sum, c4Sum, 40.0); // c3 and c4 within 40% of each other
// make sure the local approximation differs from the prediction always by the same bias (up to a numeric error)
BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6);
BOOST_TEST_REQUIRE(hasTotalFeatureImportance);
}

BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SFixture) {

std::size_t topShapValues{4};
auto results{runMultiClassClassification(topShapValues, {0.5, -0.7, 0.2, -0.2})};

bool hasTotalFeatureImportance{false};
for (const auto& result : results.GetArray()) {
if (result.HasMember("row_results")) {
double c1{readShapValue(result, "c1")};
Expand Down Expand Up @@ -584,8 +595,13 @@ BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SF
double c4bar{readShapValue(result, "c4", "bar")};
double c4baz{readShapValue(result, "c4", "baz")};
BOOST_REQUIRE_CLOSE(c4, std::abs(c4f) + std::abs(c4bar) + std::abs(c4baz), 1e-6);

if (result["row_results"]["results"]["ml"].HasMember("total_feature_importance")) {
hasTotalFeatureImportance = true;
}
}
}
BOOST_TEST_REQUIRE(hasTotalFeatureImportance);
}

BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoShap, SFixture) {
Expand Down
4 changes: 4 additions & 0 deletions lib/maths/CTreeShapFeatureImportance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,5 +362,9 @@ void CTreeShapFeatureImportance::unwindPath(CSplitPath& path, int pathIndex, int
}
--nextIndex;
}

//! \return The names of the feature columns, indexed consistently with the
//! SHAP vectors this object produces (used by callers to map a column index
//! from the SHAP callback back to a human-readable feature name).
const CTreeShapFeatureImportance::TStrVec& CTreeShapFeatureImportance::columnNames() const {
    return this->m_ColumnNames;
}
}
}