Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
* Improve runtime and memory usage training deep trees for classification and
regression. (See {ml-pull}1340[#1340].)
* Improvement in handling large inference model definitions. (See {ml-pull}1349[#1349].)
* Calculate total feature importance as a new result type. (See {ml-pull}1387[#1387].)

=== Bug Fixes

Expand Down
7 changes: 7 additions & 0 deletions include/api/CDataFrameAnalysisRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
#ifndef INCLUDED_ml_api_CDataFrameAnalysisRunner_h
#define INCLUDED_ml_api_CDataFrameAnalysisRunner_h

#include "api/CInferenceModelMetadata.h"
#include <core/CProgramCounters.h>
#include <core/CStatePersistInserter.h>

#include <api/CDataFrameAnalysisInstrumentation.h>
#include <api/CInferenceModelDefinition.h>
#include <api/CInferenceModelMetadata.h>
#include <api/ImportExport.h>

#include <rapidjson/fwd.h>
Expand Down Expand Up @@ -66,6 +68,8 @@ class API_EXPORT CDataFrameAnalysisRunner {
using TProgressRecorder = std::function<void(double)>;
using TStrVecVec = std::vector<TStrVec>;
using TInferenceModelDefinitionUPtr = std::unique_ptr<CInferenceModelDefinition>;
using TOptionalInferenceModelMetadata = boost::optional<const CInferenceModelMetadata&>;
using TInferenceModelMetadataUPtr = std::unique_ptr<CInferenceModelMetadata>;

public:
//! The intention is that concrete objects of this hierarchy are constructed
Expand Down Expand Up @@ -141,6 +145,9 @@ class API_EXPORT CDataFrameAnalysisRunner {
virtual TInferenceModelDefinitionUPtr
inferenceModelDefinition(const TStrVec& fieldNames, const TStrVecVec& categoryNames) const;

//! \return A serialisable metadata of the trained model.
virtual TOptionalInferenceModelMetadata inferenceModelMetadata() const;

//! \return Reference to the analysis instrumentation.
virtual const CDataFrameAnalysisInstrumentation& instrumentation() const = 0;
//! \return Reference to the analysis instrumentation.
Expand Down
2 changes: 2 additions & 0 deletions include/api/CDataFrameAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class API_EXPORT CDataFrameAnalyzer {
core::CRapidJsonConcurrentLineWriter& writer) const;
void writeInferenceModel(const CDataFrameAnalysisRunner& analysis,
core::CRapidJsonConcurrentLineWriter& writer) const;
void writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis,
core::CRapidJsonConcurrentLineWriter& writer) const;

private:
// This has values: -2 (unset), -1 (missing), >= 0 (control field index).
Expand Down
8 changes: 8 additions & 0 deletions include/api/CDataFrameTrainBoostedTreeRegressionRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
#include <maths/CBoostedTreeLoss.h>

#include <api/CDataFrameTrainBoostedTreeRunner.h>
#include <api/CInferenceModelMetadata.h>
#include <api/ImportExport.h>

#include <rapidjson/fwd.h>

#include <boost/optional.hpp>

namespace ml {
namespace api {

Expand Down Expand Up @@ -51,10 +54,15 @@ class API_EXPORT CDataFrameTrainBoostedTreeRegressionRunner final
TInferenceModelDefinitionUPtr
inferenceModelDefinition(const TStrVec& fieldNames,
const TStrVecVec& categoryNameMap) const override;
//! \return A serialisable metadata of the trained regression model.
TOptionalInferenceModelMetadata inferenceModelMetadata() const override;

private:
void validate(const core::CDataFrame& frame,
std::size_t dependentVariableColumn) const override;

private:
CInferenceModelMetadata m_InferenceModelMetadata;
};

//! \brief Makes a core::CDataFrame boosted tree regression runner.
Expand Down
1 change: 1 addition & 0 deletions include/api/CDataFrameTrainBoostedTreeRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeRunner : public CDataFrameAnalysisRun
static const std::string FEATURE_NAME_FIELD_NAME;
static const std::string IMPORTANCE_FIELD_NAME;
static const std::string FEATURE_IMPORTANCE_FIELD_NAME;
static const std::string TOTAL_FEATURE_IMPORTANCE_FIELD_NAME;

public:
~CDataFrameTrainBoostedTreeRunner() override;
Expand Down
2 changes: 2 additions & 0 deletions include/api/CInferenceModelDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

#include <core/CRapidJsonConcurrentLineWriter.h>

#include <maths/CBasicStatistics.h>
#include <maths/CDataFrameCategoryEncoder.h>
#include <maths/CLinearAlgebraEigen.h>

#include <api/ImportExport.h>

Expand Down
42 changes: 42 additions & 0 deletions include/api/CInferenceModelMetadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
#ifndef INCLUDED_ml_api_CInferenceModelMetadata_h
#define INCLUDED_ml_api_CInferenceModelMetadata_h

#include <maths/CBasicStatistics.h>
#include <maths/CLinearAlgebraEigen.h>

#include <api/CInferenceModelDefinition.h>
#include <api/ImportExport.h>

#include <string>

namespace ml {
namespace api {

class API_EXPORT CInferenceModelMetadata : public CSerializableToJsonDocument {
public:
using TVector = maths::CDenseVector<double>;

public:
CInferenceModelMetadata() : m_TotalShapValues(){};
void addToJsonDocument(rapidjson::Value& parentObject, TRapidJsonWriter& writer) const override;
void columnNames(const std::vector<std::string>& columnNames);
const std::string& typeString() const;
void addToFeatureImportance(std::size_t i, const TVector& values);

private:
using TMeanAccumulator = maths::CBasicStatistics::SSampleMean<TVector>::TAccumulator;
using TTotalShapValues = std::unordered_map<std::size_t, TMeanAccumulator>;

private:
TTotalShapValues m_TotalShapValues;
std::vector<std::string> m_ColumnNames;
};
}
}

#endif //INCLUDED_ml_api_CInferenceModelMetadata_h
10 changes: 10 additions & 0 deletions include/maths/CTreeShapFeatureImportance.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
#ifndef INCLUDED_ml_maths_CTreeShapFeatureImportance_h
#define INCLUDED_ml_maths_CTreeShapFeatureImportance_h

#include <maths/CBasicStatistics.h>
#include <maths/CBoostedTree.h>
#include <maths/CLinearAlgebraEigen.h>
#include <maths/ImportExport.h>

#include <unordered_map>
#include <vector>

namespace ml {
Expand Down Expand Up @@ -73,6 +75,10 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
//! Get the maximum depth of any tree in \p forest.
static std::size_t depth(const TTreeVec& forest);

const TStrVec& columnNames() const;

std::size_t numberTopShapValues() const;

private:
//! Collects the elements of the path through decision tree that are updated together
struct SPathElement {
Expand Down Expand Up @@ -158,6 +164,9 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
TDoubleVecItr m_ScaleIterator;
};

using TMeanAccumulator = maths::CBasicStatistics::SSampleMean<TVector>::TAccumulator;
using TTotalShapValues = std::unordered_map<std::size_t, TMeanAccumulator>;

private:
static void computeInternalNodeValues(TTree& tree, std::size_t nodeIndex);
static std::size_t depth(const TTree& tree, std::size_t nodeIndex);
Expand Down Expand Up @@ -194,6 +203,7 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
TVectorVecVec m_PerThreadShapValues;
TVectorVec m_ReducedShapValues;
TSizeVec m_TopShapValues;
TTotalShapValues m_TotalShapValues;
};
}
}
Expand Down
5 changes: 5 additions & 0 deletions lib/api/CDataFrameAnalysisRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ CDataFrameAnalysisRunner::inferenceModelDefinition(const TStrVec& /*fieldNames*/
return TInferenceModelDefinitionUPtr();
}

CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
CDataFrameAnalysisRunner::inferenceModelMetadata() const {
return TOptionalInferenceModelMetadata();
}

CDataFrameAnalysisRunnerFactory::TRunnerUPtr
CDataFrameAnalysisRunnerFactory::make(const CDataFrameAnalysisSpecification& spec) const {
auto result = this->makeImpl(spec);
Expand Down
20 changes: 20 additions & 0 deletions lib/api/CDataFrameAnalyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ void CDataFrameAnalyzer::run() {
analysisRunner->waitToFinish();
this->writeInferenceModel(*analysisRunner, outputWriter);
this->writeResultsOf(*analysisRunner, outputWriter);
this->writeInferenceModelMetadata(*analysisRunner, outputWriter);
}
}

Expand Down Expand Up @@ -270,6 +271,8 @@ void CDataFrameAnalyzer::addRowToDataFrame(const TStrVec& fieldValues) {

void CDataFrameAnalyzer::writeInferenceModel(const CDataFrameAnalysisRunner& analysis,
core::CRapidJsonConcurrentLineWriter& writer) const {
// Write model meta information

// Write the resulting model for inference.
auto modelDefinition = analysis.inferenceModelDefinition(
m_DataFrame->columnNames(), m_DataFrame->categoricalColumnValues());
Expand All @@ -286,6 +289,23 @@ void CDataFrameAnalyzer::writeInferenceModel(const CDataFrameAnalysisRunner& ana
writer.flush();
}

void CDataFrameAnalyzer::writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis,
core::CRapidJsonConcurrentLineWriter& writer) const {
// Write model meta information

// Write the resulting model for inference.
auto modelMetadata = analysis.inferenceModelMetadata();
if (modelMetadata) {
rapidjson::Value metadataObject{writer.makeObject()};
modelMetadata->addToJsonDocument(metadataObject, writer);
writer.StartObject();
writer.Key(modelMetadata->typeString());
writer.write(metadataObject);
writer.EndObject();
}
writer.flush();
}

void CDataFrameAnalyzer::writeResultsOf(const CDataFrameAnalysisRunner& analysis,
core::CRapidJsonConcurrentLineWriter& writer) const {

Expand Down
35 changes: 34 additions & 1 deletion lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <maths/CBoostedTreeLoss.h>
#include <maths/CDataFramePredictiveModel.h>
#include <maths/CDataFrameUtils.h>
#include <maths/CLinearAlgebraEigen.h>
#include <maths/COrderings.h>
#include <maths/CTools.h>
#include <maths/CTreeShapFeatureImportance.h>
Expand All @@ -27,6 +28,7 @@
#include <memory>
#include <numeric>
#include <set>
#include <unordered_map>

namespace ml {
namespace api {
Expand Down Expand Up @@ -162,6 +164,9 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
}

if (featureImportance != nullptr) {
using TVector = maths::CDenseVector<double>;
using TTotalShapValues = std::unordered_map<std::size_t, TVector>;
TTotalShapValues totalShapValues;
int numberClasses{static_cast<int>(classValues.size())};
featureImportance->shap(
row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
Expand All @@ -182,14 +187,42 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
writer.Key(classValues[j]);
writer.Double(shap[i](j));
}
writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME);
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(shap[i].lpNorm<1>());
}
writer.EndObject();
}
}
writer.EndArray();

for (std::size_t i = 0; i < shap.size(); ++i) {
if (shap[i].lpNorm<1>() != 0) {
totalShapValues
.emplace(std::make_pair(i, TVector::Zero(shap[i].size())))
.first->second += shap[i].cwiseAbs();
}
}
});
writer.Key(TOTAL_FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
for (const auto& item : totalShapValues) {
writer.StartObject();
writer.Key(FEATURE_NAME_FIELD_NAME);
writer.String(featureImportance->columnNames()[item.first]);
if (item.second.size() == 1) {
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(item.second(0));
} else {
for (int j = 0; j < item.second.size() && j < numberClasses; ++j) {
writer.Key(classValues[j]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will not work for storage in ES. This index stores the information for all trained models and indexing the class names for the feature importances will not scale.

I propose this format:

{
	"feature_name": "c4",
	"importance": 0.4810469375580312,
	"class_importance": [
		{
			"class_name": "foo",
			"importance": 0.24052346877901588
		},
		{
			"class_name": "baz",
			"importance": 0.19615020783390645
		},
		{
			"class_name": "bar",
			"importance": 0.04437326094510882
		}
	]
}

class_importance will be a nested data type that allows aggregations and searches for specific models and classnames.

writer.Double(item.second(j));
}
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(item.second.lpNorm<1>());
}
writer.EndObject();
}
writer.EndArray();
}
writer.EndObject();
}
Expand Down
39 changes: 35 additions & 4 deletions lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
* you may not use this file except in compliance with the Elastic License.
*/

#include "api/CDataFrameTrainBoostedTreeRunner.h"
#include "api/CInferenceModelMetadata.h"
#include <api/CDataFrameTrainBoostedTreeRegressionRunner.h>

#include <core/CLogger.h>
#include <core/CRapidJsonConcurrentLineWriter.h>

#include <maths/CBasicStatistics.h>
#include <maths/CBoostedTree.h>
#include <maths/CBoostedTreeFactory.h>
#include <maths/CBoostedTreeLoss.h>
#include <maths/CDataFrameUtils.h>
#include <maths/CLinearAlgebraEigen.h>
#include <maths/CTreeShapFeatureImportance.h>

#include <api/CBoostedTreeInferenceModelBuilder.h>
Expand All @@ -24,6 +28,7 @@
#include <memory>
#include <set>
#include <string>
#include <unordered_map>

namespace ml {
namespace api {
Expand Down Expand Up @@ -109,10 +114,11 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow(
writer.Bool(maths::CDataFrameUtils::isMissing(row[columnHoldingDependentVariable]) == false);
auto featureImportance = tree.shap();
if (featureImportance != nullptr) {
const_cast<CDataFrameTrainBoostedTreeRegressionRunner*>(this)->m_InferenceModelMetadata.columnNames(featureImportance->columnNames());
featureImportance->shap(
row, [&writer](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
row, [&writer, this](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
for (auto i : indices) {
Expand All @@ -126,7 +132,27 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow(
}
}
writer.EndArray();

for (int i = 0; i < shap.size(); ++i) {
if (shap[i].lpNorm<1>() != 0) {
const_cast<CDataFrameTrainBoostedTreeRegressionRunner*>(this)->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
}
}
});

// LOG_DEBUG(<< "Total shap size: " << totalShapValues.size());
// writer.Key(TOTAL_FEATURE_IMPORTANCE_FIELD_NAME);
// writer.StartArray();
// for (const auto& item : totalShapValues) {
// writer.StartObject();
// writer.Key(FEATURE_NAME_FIELD_NAME);
// writer.String(featureImportance->columnNames()[item.first]);
// writer.Key(IMPORTANCE_FIELD_NAME);
// writer.Double(maths::CBasicStatistics::mean(item.second)[0]);
// writer.EndObject();
// LOG_DEBUG(<< "Count: " << maths::CBasicStatistics::count(item.second));
// }
// writer.EndArray();
}
writer.EndObject();
}
Expand All @@ -146,6 +172,11 @@ CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition(
return std::make_unique<CInferenceModelDefinition>(builder.build());
}

CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelMetadata() const {
return TOptionalInferenceModelMetadata(m_InferenceModelMetadata);
}

// clang-format off
const std::string CDataFrameTrainBoostedTreeRegressionRunner::STRATIFIED_CROSS_VALIDATION{"stratified_cross_validation"};
const std::string CDataFrameTrainBoostedTreeRegressionRunner::LOSS_FUNCTION{"loss_function"};
Expand All @@ -161,7 +192,7 @@ const std::string& CDataFrameTrainBoostedTreeRegressionRunnerFactory::name() con

CDataFrameTrainBoostedTreeRegressionRunnerFactory::TRunnerUPtr
CDataFrameTrainBoostedTreeRegressionRunnerFactory::makeImpl(const CDataFrameAnalysisSpecification&) const {
HANDLE_FATAL(<< "Input error: classification has a non-optional parameter '"
HANDLE_FATAL(<< "Input error: regression has a non-optional parameter '"
<< CDataFrameTrainBoostedTreeRunner::DEPENDENT_VARIABLE_NAME << "'.")
return nullptr;
}
Expand Down
Loading