diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 37f1e8b52b..427d3880f8 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -75,6 +75,7 @@ regression. (See {ml-pull}1340[#1340].) * Improvement in handling large inference model definitions. (See {ml-pull}1349[#1349].) * Add a peak_model_bytes field to model_size_stats. (See {ml-pull}1389[#1389].) +* Calculate total feature importance as a new result type. (See {ml-pull}1387[#1387].) === Bug Fixes diff --git a/include/api/CDataFrameAnalysisRunner.h b/include/api/CDataFrameAnalysisRunner.h index abb11a3208..dbd89d09cc 100644 --- a/include/api/CDataFrameAnalysisRunner.h +++ b/include/api/CDataFrameAnalysisRunner.h @@ -12,10 +12,13 @@ #include #include +#include #include #include +#include + #include #include #include @@ -66,6 +69,7 @@ class API_EXPORT CDataFrameAnalysisRunner { using TProgressRecorder = std::function; using TStrVecVec = std::vector; using TInferenceModelDefinitionUPtr = std::unique_ptr; + using TOptionalInferenceModelMetadata = boost::optional; public: //! The intention is that concrete objects of this hierarchy are constructed @@ -141,6 +145,9 @@ class API_EXPORT CDataFrameAnalysisRunner { virtual TInferenceModelDefinitionUPtr inferenceModelDefinition(const TStrVec& fieldNames, const TStrVecVec& categoryNames) const; + //! \return A serialisable metadata of the trained model. + virtual TOptionalInferenceModelMetadata inferenceModelMetadata() const; + //! \return Reference to the analysis instrumentation. virtual const CDataFrameAnalysisInstrumentation& instrumentation() const = 0; //! \return Reference to the analysis instrumentation. 
diff --git a/include/api/CDataFrameAnalyzer.h b/include/api/CDataFrameAnalyzer.h index 6f98ffea86..675ed81fe8 100644 --- a/include/api/CDataFrameAnalyzer.h +++ b/include/api/CDataFrameAnalyzer.h @@ -87,6 +87,8 @@ class API_EXPORT CDataFrameAnalyzer { core::CRapidJsonConcurrentLineWriter& writer) const; void writeInferenceModel(const CDataFrameAnalysisRunner& analysis, core::CRapidJsonConcurrentLineWriter& writer) const; + void writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis, + core::CRapidJsonConcurrentLineWriter& writer) const; private: // This has values: -2 (unset), -1 (missing), >= 0 (control field index). diff --git a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h index ec10300bb4..2ee662eca5 100644 --- a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h +++ b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -40,6 +41,8 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final static const std::string NUM_TOP_CLASSES; static const std::string PREDICTION_FIELD_TYPE; static const std::string CLASS_ASSIGNMENT_OBJECTIVE; + static const std::string CLASSES_FIELD_NAME; + static const std::string CLASS_NAME_FIELD_NAME; static const TStrVec CLASS_ASSIGNMENT_OBJECTIVE_VALUES; public: @@ -70,6 +73,9 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final inferenceModelDefinition(const TStrVec& fieldNames, const TStrVecVec& categoryNames) const override; + //! \return A serialisable metadata of the trained classification model.
+ TOptionalInferenceModelMetadata inferenceModelMetadata() const override; + private: static TLossFunctionUPtr loss(std::size_t numberClasses); @@ -82,6 +88,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final private: std::size_t m_NumTopClasses; EPredictionFieldType m_PredictionFieldType; + mutable CInferenceModelMetadata m_InferenceModelMetadata; }; //! \brief Makes a core::CDataFrame boosted tree classification runner. diff --git a/include/api/CDataFrameTrainBoostedTreeRegressionRunner.h b/include/api/CDataFrameTrainBoostedTreeRegressionRunner.h index c9eb8be6bc..3ed92f00f2 100644 --- a/include/api/CDataFrameTrainBoostedTreeRegressionRunner.h +++ b/include/api/CDataFrameTrainBoostedTreeRegressionRunner.h @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -51,10 +52,15 @@ class API_EXPORT CDataFrameTrainBoostedTreeRegressionRunner final TInferenceModelDefinitionUPtr inferenceModelDefinition(const TStrVec& fieldNames, const TStrVecVec& categoryNameMap) const override; + //! \return A serialisable metadata of the trained regression model. + TOptionalInferenceModelMetadata inferenceModelMetadata() const override; private: void validate(const core::CDataFrame& frame, std::size_t dependentVariableColumn) const override; + +private: + mutable CInferenceModelMetadata m_InferenceModelMetadata; }; //! \brief Makes a core::CDataFrame boosted tree regression runner. diff --git a/include/api/CInferenceModelMetadata.h b/include/api/CInferenceModelMetadata.h new file mode 100644 index 0000000000..75ea2ae2c9 --- /dev/null +++ b/include/api/CInferenceModelMetadata.h @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
 + */ +#ifndef INCLUDED_ml_api_CInferenceModelMetadata_h +#define INCLUDED_ml_api_CInferenceModelMetadata_h + +#include +#include + +#include +#include + +#include + +namespace ml { +namespace api { + +//! \brief Class controls the serialization of the model meta information +//! (such as total feature importance) into JSON format. +class API_EXPORT CInferenceModelMetadata {
+public: + static const std::string JSON_CLASS_NAME_TAG; + static const std::string JSON_CLASSES_TAG; + static const std::string JSON_FEATURE_NAME_TAG; + static const std::string JSON_IMPORTANCE_TAG; + static const std::string JSON_MAX_TAG; + static const std::string JSON_MEAN_MAGNITUDE_TAG; + static const std::string JSON_MIN_TAG; + static const std::string JSON_MODEL_METADATA_TAG; + static const std::string JSON_TOTAL_FEATURE_IMPORTANCE_TAG; + +public: + using TVector = maths::CDenseVector; + using TStrVec = std::vector; + using TRapidJsonWriter = core::CRapidJsonConcurrentLineWriter; + +public: + //! Writes metadata using \p writer. + void write(TRapidJsonWriter& writer) const; + void columnNames(const TStrVec& columnNames); + void classValues(const TStrVec& classValues); + const std::string& typeString() const; + //! Add importances \p values to the feature with index \p i to calculate total feature importance. + //! Total feature importance is the mean of the magnitudes of importances for individual data points.
+ void addToFeatureImportance(std::size_t i, const TVector& values); + +private: + using TMeanVarAccumulator = maths::CBasicStatistics::SSampleMeanVar::TAccumulator; + using TMinMaxAccumulator = std::vector>; + using TSizeMeanVarAccumulatorUMap = std::unordered_map; + using TSizeMinMaxAccumulatorUMap = std::unordered_map; + +private: + void writeTotalFeatureImportance(TRapidJsonWriter& writer) const; + +private: + TSizeMeanVarAccumulatorUMap m_TotalShapValuesMeanVar; + TSizeMinMaxAccumulatorUMap m_TotalShapValuesMinMax; + TStrVec m_ColumnNames; + TStrVec m_ClassValues; +}; +} +} + +#endif //INCLUDED_ml_api_CInferenceModelMetadata_h diff --git a/include/maths/CBasicStatistics.h b/include/maths/CBasicStatistics.h index b6c930db21..b934f627b3 100644 --- a/include/maths/CBasicStatistics.h +++ b/include/maths/CBasicStatistics.h @@ -245,7 +245,7 @@ class MATHS_EXPORT CBasicStatistics { if (ORDER > 1) { T r{x - s_Moments[0]}; - T r2{r * r}; + T r2{las::componentwise(r) * las::componentwise(r)}; T dMean{mean - s_Moments[0]}; T dMean2{las::componentwise(dMean) * las::componentwise(dMean)}; T variance{s_Moments[1]}; diff --git a/include/maths/CTreeShapFeatureImportance.h b/include/maths/CTreeShapFeatureImportance.h index e9846e8bf8..cd0a7c2d3e 100644 --- a/include/maths/CTreeShapFeatureImportance.h +++ b/include/maths/CTreeShapFeatureImportance.h @@ -73,6 +73,9 @@ class MATHS_EXPORT CTreeShapFeatureImportance { //! Get the maximum depth of any tree in \p forest. static std::size_t depth(const TTreeVec& forest); + //! Get the column names. + const TStrVec& columnNames() const; + private: //! 
Collects the elements of the path through decision tree that are updated together struct SPathElement { diff --git a/lib/api/CDataFrameAnalysisRunner.cc b/lib/api/CDataFrameAnalysisRunner.cc index dc3d15d0a7..c4492558a5 100644 --- a/lib/api/CDataFrameAnalysisRunner.cc +++ b/lib/api/CDataFrameAnalysisRunner.cc @@ -193,6 +193,11 @@ CDataFrameAnalysisRunner::inferenceModelDefinition(const TStrVec& /*fieldNames*/ return TInferenceModelDefinitionUPtr(); } +CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata +CDataFrameAnalysisRunner::inferenceModelMetadata() const { + return TOptionalInferenceModelMetadata(); +} + CDataFrameAnalysisRunnerFactory::TRunnerUPtr CDataFrameAnalysisRunnerFactory::make(const CDataFrameAnalysisSpecification& spec) const { auto result = this->makeImpl(spec); diff --git a/lib/api/CDataFrameAnalyzer.cc b/lib/api/CDataFrameAnalyzer.cc index edb698747d..90fe79858e 100644 --- a/lib/api/CDataFrameAnalyzer.cc +++ b/lib/api/CDataFrameAnalyzer.cc @@ -144,6 +144,8 @@ void CDataFrameAnalyzer::run() { analysisRunner->waitToFinish(); this->writeInferenceModel(*analysisRunner, outputWriter); this->writeResultsOf(*analysisRunner, outputWriter); + // TODO reactivate once Java parsing is ready + // this->writeInferenceModelMetadata(*analysisRunner, outputWriter); } } @@ -286,6 +288,21 @@ void CDataFrameAnalyzer::writeInferenceModel(const CDataFrameAnalysisRunner& ana writer.flush(); } +void CDataFrameAnalyzer::writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis, + core::CRapidJsonConcurrentLineWriter& writer) const { + // Write model meta information + auto modelMetadata = analysis.inferenceModelMetadata(); + if (modelMetadata) { + writer.StartObject(); + writer.Key(modelMetadata->typeString()); + writer.StartObject(); + modelMetadata->write(writer); + writer.EndObject(); + writer.EndObject(); + } + writer.flush(); +} + void CDataFrameAnalyzer::writeResultsOf(const CDataFrameAnalysisRunner& analysis, core::CRapidJsonConcurrentLineWriter& 
writer) const { diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index e011970505..0867fa8391 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -41,7 +42,6 @@ const std::string IS_TRAINING_FIELD_NAME{"is_training"}; const std::string PREDICTION_PROBABILITY_FIELD_NAME{"prediction_probability"}; const std::string PREDICTION_SCORE_FIELD_NAME{"prediction_score"}; const std::string TOP_CLASSES_FIELD_NAME{"top_classes"}; -const std::string CLASS_NAME_FIELD_NAME{"class_name"}; const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"}; const std::string CLASS_SCORE_FIELD_NAME{"class_score"}; @@ -162,7 +162,9 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( } if (featureImportance != nullptr) { - int numberClasses{static_cast(classValues.size())}; + std::size_t numberClasses{classValues.size()}; + m_InferenceModelMetadata.columnNames(featureImportance->columnNames()); + m_InferenceModelMetadata.classValues(classValues); featureImportance->shap( row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices, const TStrVec& featureNames, @@ -175,20 +177,47 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( writer.Key(FEATURE_NAME_FIELD_NAME); writer.String(featureNames[i]); if (shap[i].size() == 1) { - writer.Key(IMPORTANCE_FIELD_NAME); - writer.Double(shap[i](0)); + // output feature importance for individual classes in binary case + writer.Key(CLASSES_FIELD_NAME); + writer.StartArray(); + for (std::size_t j = 0; j < numberClasses; ++j) { + double importance{(j == predictedClassId) + ? 
shap[i](0) + : -shap[i](0)}; + writer.StartObject(); + writer.Key(CLASS_NAME_FIELD_NAME); + writer.String(classValues[j]); + writer.Key(IMPORTANCE_FIELD_NAME); + writer.Double(importance); + writer.EndObject(); + } + writer.EndArray(); } else { - for (int j = 0; j < shap[i].size() && j < numberClasses; ++j) { - writer.Key(classValues[j]); + // output feature importance for individual classes in multiclass case + writer.Key(CLASSES_FIELD_NAME); + writer.StartArray(); + for (std::size_t j = 0; + j < shap[i].size() && j < numberClasses; ++j) { + writer.StartObject(); + writer.Key(CLASS_NAME_FIELD_NAME); + writer.String(classValues[j]); + writer.Key(IMPORTANCE_FIELD_NAME); writer.Double(shap[i](j)); + writer.EndObject(); } - writer.Key(CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME); - writer.Double(shap[i].lpNorm<1>()); + writer.EndArray(); } writer.EndObject(); } } writer.EndArray(); + + for (std::size_t i = 0; i < shap.size(); ++i) { + if (shap[i].lpNorm<1>() != 0) { + const_cast(this) + ->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]); + } + } }); } writer.EndObject(); @@ -257,6 +286,11 @@ CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition( return std::make_unique(builder.build()); } +CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata +CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata() const { + return m_InferenceModelMetadata; +} + // clang-format off // The MAX_NUMBER_CLASSES must match the value used in the Java code. See the // MAX_DEPENDENT_VARIABLE_CARDINALITY in the x-pack classification code. 
@@ -291,5 +325,7 @@ CDataFrameTrainBoostedTreeClassifierRunnerFactory::makeImpl( } const std::string CDataFrameTrainBoostedTreeClassifierRunnerFactory::NAME{"classification"}; +const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASSES_FIELD_NAME{"classes"}; +const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASS_NAME_FIELD_NAME{"class_name"}; } } diff --git a/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc b/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc index eec5837000..44c3e703e6 100644 --- a/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc @@ -109,10 +109,11 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow( writer.Bool(maths::CDataFrameUtils::isMissing(row[columnHoldingDependentVariable]) == false); auto featureImportance = tree.shap(); if (featureImportance != nullptr) { + m_InferenceModelMetadata.columnNames(featureImportance->columnNames()); featureImportance->shap( - row, [&writer](const maths::CTreeShapFeatureImportance::TSizeVec& indices, - const TStrVec& featureNames, - const maths::CTreeShapFeatureImportance::TVectorVec& shap) { + row, [&writer, this](const maths::CTreeShapFeatureImportance::TSizeVec& indices, + const TStrVec& featureNames, + const maths::CTreeShapFeatureImportance::TVectorVec& shap) { writer.Key(FEATURE_IMPORTANCE_FIELD_NAME); writer.StartArray(); for (auto i : indices) { @@ -126,6 +127,13 @@ void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow( } } writer.EndArray(); + + for (int i = 0; i < shap.size(); ++i) { + if (shap[i].lpNorm<1>() != 0) { + const_cast(this) + ->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]); + } + } }); } writer.EndObject(); @@ -145,6 +153,11 @@ CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition( return std::make_unique(builder.build()); } +CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata 
+CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelMetadata() const { + return TOptionalInferenceModelMetadata(m_InferenceModelMetadata); +} + // clang-format off const std::string CDataFrameTrainBoostedTreeRegressionRunner::STRATIFIED_CROSS_VALIDATION{"stratified_cross_validation"}; const std::string CDataFrameTrainBoostedTreeRegressionRunner::LOSS_FUNCTION{"loss_function"}; @@ -160,7 +173,7 @@ const std::string& CDataFrameTrainBoostedTreeRegressionRunnerFactory::name() con CDataFrameTrainBoostedTreeRegressionRunnerFactory::TRunnerUPtr CDataFrameTrainBoostedTreeRegressionRunnerFactory::makeImpl(const CDataFrameAnalysisSpecification&) const { - HANDLE_FATAL(<< "Input error: classification has a non-optional parameter '" + HANDLE_FATAL(<< "Input error: regression has a non-optional parameter '" << CDataFrameTrainBoostedTreeRunner::DEPENDENT_VARIABLE_NAME << "'.") return nullptr; } diff --git a/lib/api/CInferenceModelMetadata.cc b/lib/api/CInferenceModelMetadata.cc new file mode 100644 index 0000000000..ab2206f49f --- /dev/null +++ b/lib/api/CInferenceModelMetadata.cc @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +#include + +namespace ml { +namespace api { + +void CInferenceModelMetadata::write(TRapidJsonWriter& writer) const { + this->writeTotalFeatureImportance(writer); +} + +void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writer) const { + writer.Key(JSON_TOTAL_FEATURE_IMPORTANCE_TAG); + writer.StartArray(); + for (const auto& item : m_TotalShapValuesMeanVar) { + writer.StartObject(); + writer.Key(JSON_FEATURE_NAME_TAG); + writer.String(m_ColumnNames[item.first]); + auto meanFeatureImportance = maths::CBasicStatistics::mean(item.second); + const auto& minMaxFeatureImportance = m_TotalShapValuesMinMax.at(item.first); + if (meanFeatureImportance.size() == 1 && m_ClassValues.empty()) { + // Regression + writer.Key(JSON_IMPORTANCE_TAG); + writer.StartObject(); + writer.Key(JSON_MEAN_MAGNITUDE_TAG); + writer.Double(meanFeatureImportance[0]); + writer.Key(JSON_MIN_TAG); + writer.Double(minMaxFeatureImportance[0].min()); + writer.Key(JSON_MAX_TAG); + writer.Double(minMaxFeatureImportance[0].max()); + writer.EndObject(); + } else if (meanFeatureImportance.size() == 1 && m_ClassValues.empty() == false) { + // Binary classification + // since we track the min/max only for one class, this will make the range more robust + double minimum{std::min(minMaxFeatureImportance[0].min(), + -minMaxFeatureImportance[0].max())}; + double maximum{-minimum}; + writer.Key(JSON_CLASSES_TAG); + writer.StartArray(); + for (std::size_t j = 0; j < m_ClassValues.size(); ++j) { + writer.StartObject(); + writer.Key(JSON_CLASS_NAME_TAG); + writer.String(m_ClassValues[j]); + writer.Key(JSON_IMPORTANCE_TAG); + writer.StartObject(); + writer.Key(JSON_MEAN_MAGNITUDE_TAG); + // mean magnitude is the same for both classes + writer.Double(meanFeatureImportance[0]); + writer.Key(JSON_MIN_TAG); + writer.Double(minimum); + writer.Key(JSON_MAX_TAG); + writer.Double(maximum); + writer.EndObject(); + writer.EndObject(); + } + writer.EndArray(); + } else { + // Multiclass 
classification + writer.Key(JSON_CLASSES_TAG); + writer.StartArray(); + for (std::size_t j = 0; + j < meanFeatureImportance.size() && j < m_ClassValues.size(); ++j) { + writer.StartObject(); + writer.Key(JSON_CLASS_NAME_TAG); + writer.String(m_ClassValues[j]); + writer.Key(JSON_IMPORTANCE_TAG); + writer.StartObject(); + writer.Key(JSON_MEAN_MAGNITUDE_TAG); + writer.Double(meanFeatureImportance[j]); + writer.Key(JSON_MIN_TAG); + writer.Double(minMaxFeatureImportance[j].min()); + writer.Key(JSON_MAX_TAG); + writer.Double(minMaxFeatureImportance[j].max()); + writer.EndObject(); + writer.EndObject(); + } + writer.EndArray(); + } + writer.EndObject(); + } + writer.EndArray(); +} + +const std::string& CInferenceModelMetadata::typeString() const { + return JSON_MODEL_METADATA_TAG; +} + +void CInferenceModelMetadata::columnNames(const TStrVec& columnNames) { + m_ColumnNames = columnNames; +} + +void CInferenceModelMetadata::classValues(const TStrVec& classValues) { + m_ClassValues = classValues; +} + +void CInferenceModelMetadata::addToFeatureImportance(std::size_t i, const TVector& values) { + m_TotalShapValuesMeanVar + .emplace(std::make_pair(i, TVector::Zero(values.size()))) + .first->second.add(values.cwiseAbs()); + auto& minMaxVector = + m_TotalShapValuesMinMax + .emplace(std::make_pair(i, TMinMaxAccumulator(values.size()))) + .first->second; + for (std::size_t j = 0; j < minMaxVector.size(); ++j) { + minMaxVector[j].add(values[j]); + } +} + +// clang-format off +const std::string CInferenceModelMetadata::JSON_CLASS_NAME_TAG{"class_name"}; +const std::string CInferenceModelMetadata::JSON_CLASSES_TAG{"classes"}; +const std::string CInferenceModelMetadata::JSON_FEATURE_NAME_TAG{"feature_name"}; +const std::string CInferenceModelMetadata::JSON_IMPORTANCE_TAG{"importance"}; +const std::string CInferenceModelMetadata::JSON_MAX_TAG{"max"}; +const std::string CInferenceModelMetadata::JSON_MEAN_MAGNITUDE_TAG{"mean_magnitude"}; +const std::string 
CInferenceModelMetadata::JSON_MIN_TAG{"min"}; +const std::string CInferenceModelMetadata::JSON_MODEL_METADATA_TAG{"model_metadata"}; +const std::string CInferenceModelMetadata::JSON_TOTAL_FEATURE_IMPORTANCE_TAG{"total_feature_importance"}; +// clang-format on +} +} diff --git a/lib/api/Makefile.first b/lib/api/Makefile.first index 15538533cd..21525f33e3 100644 --- a/lib/api/Makefile.first +++ b/lib/api/Makefile.first @@ -45,6 +45,7 @@ CForecastRunner.cc \ CGlobalCategoryId.cc \ CHierarchicalResultsWriter.cc \ CInferenceModelDefinition.cc \ +CInferenceModelMetadata.cc \ CInputParser.cc \ CIoManager.cc \ CJsonOutputWriter.cc \ diff --git a/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc b/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc index 6594549af8..31e98379fe 100644 --- a/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc +++ b/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc @@ -14,7 +14,9 @@ #include #include +#include #include +#include #include #include @@ -229,7 +231,6 @@ struct SFixture { BOOST_TEST_REQUIRE( core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) < core::CProgramCounters::counter(counter_t::E_DFTPMEstimatedPeakMemoryUsage)); - rapidjson::Document results; rapidjson::ParseResult ok(results.Parse(s_Output.str())); BOOST_TEST_REQUIRE(static_cast(ok) == true); @@ -418,8 +419,53 @@ double readShapValue(const RESULTS& results, std::string shapField, std::string .GetArray()) { if (shapResult[api::CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME] .GetString() == shapField) { - if (shapResult.HasMember(className)) { - return shapResult[className].GetDouble(); + for (const auto& item : + shapResult[api::CDataFrameTrainBoostedTreeClassifierRunner::CLASSES_FIELD_NAME] + .GetArray()) { + if (item[api::CDataFrameTrainBoostedTreeClassifierRunner::CLASS_NAME_FIELD_NAME] + .GetString() == className) { + return item[api::CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME] + .GetDouble(); + 
} + } + } + } + } + return 0.0; +} + +template +double readTotalShapValue(const RESULTS& results, std::string shapField) { + using TModelMetadata = api::CInferenceModelMetadata; + if (results[TModelMetadata::JSON_MODEL_METADATA_TAG].HasMember( + TModelMetadata::JSON_TOTAL_FEATURE_IMPORTANCE_TAG)) { + for (const auto& shapResult : + results[TModelMetadata::JSON_MODEL_METADATA_TAG][TModelMetadata::JSON_TOTAL_FEATURE_IMPORTANCE_TAG] + .GetArray()) { + if (shapResult[TModelMetadata::JSON_FEATURE_NAME_TAG].GetString() == shapField) { + return shapResult[TModelMetadata::JSON_IMPORTANCE_TAG][TModelMetadata::JSON_MEAN_MAGNITUDE_TAG] + .GetDouble(); + } + } + } + return 0.0; +} + +template +double readTotalShapValue(const RESULTS& results, std::string shapField, std::string className) { + using TModelMetadata = api::CInferenceModelMetadata; + if (results[TModelMetadata::JSON_MODEL_METADATA_TAG].HasMember( + TModelMetadata::JSON_TOTAL_FEATURE_IMPORTANCE_TAG)) { + for (const auto& shapResult : + results[TModelMetadata::JSON_MODEL_METADATA_TAG][TModelMetadata::JSON_TOTAL_FEATURE_IMPORTANCE_TAG] + .GetArray()) { + if (shapResult[TModelMetadata::JSON_FEATURE_NAME_TAG].GetString() == shapField) { + for (const auto& item : + shapResult[TModelMetadata::JSON_CLASSES_TAG].GetArray()) { + if (item[TModelMetadata::JSON_CLASS_NAME_TAG].GetString() == className) { + return item[TModelMetadata::JSON_IMPORTANCE_TAG][TModelMetadata::JSON_MEAN_MAGNITUDE_TAG] + .GetDouble(); + } } } } @@ -439,7 +485,14 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) { auto results{runRegression(topShapValues, weights)}; TMeanVarAccumulator bias; + TMeanAccumulator c1TotalShapExpected; + TMeanAccumulator c2TotalShapExpected; + TMeanAccumulator c3TotalShapExpected; + TMeanAccumulator c4TotalShapExpected; double c1Sum{0.0}, c2Sum{0.0}, c3Sum{0.0}, c4Sum{0.0}; + double c1TotalShapActual{0.0}, c2TotalShapActual{0.0}, + c3TotalShapActual{0.0}, c4TotalShapActual{0.0}; + bool 
hasTotalFeatureImportance{false}; for (const auto& result : results.GetArray()) { if (result.HasMember("row_results")) { double c1{readShapValue(result, "c1")}; @@ -454,8 +507,21 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) { c2Sum += std::fabs(c2); c3Sum += std::fabs(c3); c4Sum += std::fabs(c4); + c1TotalShapExpected.add(std::fabs(c1)); + c2TotalShapExpected.add(std::fabs(c2)); + c3TotalShapExpected.add(std::fabs(c3)); + c4TotalShapExpected.add(std::fabs(c4)); // assert that no SHAP value for the dependent variable is returned BOOST_REQUIRE_EQUAL(readShapValue(result, "target"), 0.0); + } else if (result.HasMember("model_metadata")) { + if (result["model_metadata"].HasMember("total_feature_importance")) { + hasTotalFeatureImportance = true; + c1TotalShapActual = readTotalShapValue(result, "c1"); + c2TotalShapActual = readTotalShapValue(result, "c2"); + c3TotalShapActual = readTotalShapValue(result, "c3"); + c4TotalShapActual = readTotalShapValue(result, "c4"); + } + // TODO check that the total feature importance is calculated correctly } } @@ -471,6 +537,16 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) { BOOST_REQUIRE_CLOSE(c3Sum, c4Sum, 5.0); // c3 and c4 within 5% of each other // make sure the local approximation differs from the prediction always by the same bias (up to a numeric error) BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6); + // TODO reactivate once Java parsing is ready + // BOOST_TEST_REQUIRE(hasTotalFeatureImportance); + // BOOST_REQUIRE_CLOSE(c1TotalShapActual, + // maths::CBasicStatistics::mean(c1TotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c2TotalShapActual, + // maths::CBasicStatistics::mean(c2TotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c3TotalShapActual, + // maths::CBasicStatistics::mean(c3TotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c4TotalShapActual, + // maths::CBasicStatistics::mean(c4TotalShapExpected), 1.0); } 
BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoImportance, SFixture) { @@ -508,18 +584,27 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) { std::size_t topShapValues{4}; TMeanVarAccumulator bias; auto results{runBinaryClassification(topShapValues, {0.5, -0.7, 0.2, -0.2})}; - + TMeanAccumulator c1TotalShapExpected; + TMeanAccumulator c2TotalShapExpected; + TMeanAccumulator c3TotalShapExpected; + TMeanAccumulator c4TotalShapExpected; double c1Sum{0.0}, c2Sum{0.0}, c3Sum{0.0}, c4Sum{0.0}; + double c1FooTotalShapActual{0.0}, c2FooTotalShapActual{0.0}, + c3FooTotalShapActual{0.0}, c4FooTotalShapActual{0.0}; + double c1BarTotalShapActual{0.0}, c2BarTotalShapActual{0.0}, + c3BarTotalShapActual{0.0}, c4BarTotalShapActual{0.0}; + bool hasTotalFeatureImportance{false}; for (const auto& result : results.GetArray()) { if (result.HasMember("row_results")) { - double c1{readShapValue(result, "c1")}; - double c2{readShapValue(result, "c2")}; - double c3{readShapValue(result, "c3")}; - double c4{readShapValue(result, "c4")}; - double predictionProbability{ - result["row_results"]["results"]["ml"]["prediction_probability"].GetDouble()}; std::string targetPrediction{ result["row_results"]["results"]["ml"]["target_prediction"].GetString()}; + double c1{readShapValue(result, "c1", targetPrediction)}; + double c2{readShapValue(result, "c2", targetPrediction)}; + double c3{readShapValue(result, "c3", targetPrediction)}; + double c4{readShapValue(result, "c4", targetPrediction)}; + double predictionProbability{ + result["row_results"]["results"]["ml"]["prediction_probability"].GetDouble()}; + double logOdds{0.0}; if (targetPrediction == "bar") { logOdds = std::log(predictionProbability / @@ -536,6 +621,23 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) { c2Sum += std::fabs(c2); c3Sum += std::fabs(c3); c4Sum += std::fabs(c4); + c1TotalShapExpected.add(std::fabs(c1)); + c2TotalShapExpected.add(std::fabs(c2)); + 
c3TotalShapExpected.add(std::fabs(c3)); + c4TotalShapExpected.add(std::fabs(c4)); + } else if (result.HasMember("model_metadata")) { + if (result["model_metadata"].HasMember("total_feature_importance")) { + hasTotalFeatureImportance = true; + } + // TODO reactivate once Java parsing is ready + c1FooTotalShapActual = readTotalShapValue(result, "c1", "foo"); + c2FooTotalShapActual = readTotalShapValue(result, "c2", "foo"); + c3FooTotalShapActual = readTotalShapValue(result, "c3", "foo"); + c4FooTotalShapActual = readTotalShapValue(result, "c4", "foo"); + c1BarTotalShapActual = readTotalShapValue(result, "c1", "bar"); + c2BarTotalShapActual = readTotalShapValue(result, "c2", "bar"); + c3BarTotalShapActual = readTotalShapValue(result, "c3", "bar"); + c4BarTotalShapActual = readTotalShapValue(result, "c4", "bar"); } } @@ -548,44 +650,131 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) { BOOST_REQUIRE_CLOSE(c3Sum, c4Sum, 40.0); // c3 and c4 within 40% of each other // make sure the local approximation differs from the prediction always by the same bias (up to a numeric error) BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6); + // TODO reactivate once Java parsing is ready + // BOOST_TEST_REQUIRE(hasTotalFeatureImportance); + // BOOST_REQUIRE_CLOSE(c1FooTotalShapActual, + // maths::CBasicStatistics::mean(c1TotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c2FooTotalShapActual, + // maths::CBasicStatistics::mean(c2TotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c3FooTotalShapActual, + // maths::CBasicStatistics::mean(c3TotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c4FooTotalShapActual, + // maths::CBasicStatistics::mean(c4TotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c1BarTotalShapActual, + // maths::CBasicStatistics::mean(c1TotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c2BarTotalShapActual, + // maths::CBasicStatistics::mean(c2TotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c3BarTotalShapActual, + // 
maths::CBasicStatistics::mean(c3TotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c4BarTotalShapActual, + // maths::CBasicStatistics::mean(c4TotalShapExpected), 1.0); } BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SFixture) { std::size_t topShapValues{4}; auto results{runMultiClassClassification(topShapValues, {0.5, -0.7, 0.2, -0.2})}; - + TMeanAccumulator c1FooTotalShapExpected; + TMeanAccumulator c2FooTotalShapExpected; + TMeanAccumulator c3FooTotalShapExpected; + TMeanAccumulator c4FooTotalShapExpected; + TMeanAccumulator c1BarTotalShapExpected; + TMeanAccumulator c2BarTotalShapExpected; + TMeanAccumulator c3BarTotalShapExpected; + TMeanAccumulator c4BarTotalShapExpected; + TMeanAccumulator c1BazTotalShapExpected; + TMeanAccumulator c2BazTotalShapExpected; + TMeanAccumulator c3BazTotalShapExpected; + TMeanAccumulator c4BazTotalShapExpected; + double c1FooTotalShapActual{0.0}, c2FooTotalShapActual{0.0}, + c3FooTotalShapActual{0.0}, c4FooTotalShapActual{0.0}; + double c1BarTotalShapActual{0.0}, c2BarTotalShapActual{0.0}, + c3BarTotalShapActual{0.0}, c4BarTotalShapActual{0.0}; + double c1BazTotalShapActual{0.0}, c2BazTotalShapActual{0.0}, + c3BazTotalShapActual{0.0}, c4BazTotalShapActual{0.0}; + bool hasTotalFeatureImportance{false}; for (const auto& result : results.GetArray()) { if (result.HasMember("row_results")) { - double c1{readShapValue(result, "c1")}; - double c2{readShapValue(result, "c2")}; - double c3{readShapValue(result, "c3")}; - double c4{readShapValue(result, "c4")}; - // We should have at least one feature that is important - BOOST_TEST_REQUIRE((c1 > 0.0 || c2 > 0.0 || c3 > 0.0 || c4 > 0.0)); - // class shap values should sum(abs()) to the overall feature importance double c1f{readShapValue(result, "c1", "foo")}; double c1bar{readShapValue(result, "c1", "bar")}; double c1baz{readShapValue(result, "c1", "baz")}; - BOOST_REQUIRE_CLOSE(c1, std::abs(c1f) + std::abs(c1bar) + std::abs(c1baz), 1e-6); + double 
c1{std::abs(c1f) + std::abs(c1bar) + std::abs(c1baz)}; + c1FooTotalShapExpected.add(std::fabs(c1f)); + c1BarTotalShapExpected.add(std::fabs(c1bar)); + c1BazTotalShapExpected.add(std::fabs(c1baz)); double c2f{readShapValue(result, "c2", "foo")}; double c2bar{readShapValue(result, "c2", "bar")}; double c2baz{readShapValue(result, "c2", "baz")}; - BOOST_REQUIRE_CLOSE(c2, std::abs(c2f) + std::abs(c2bar) + std::abs(c2baz), 1e-6); + double c2{std::abs(c2f) + std::abs(c2bar) + std::abs(c2baz)}; + c2FooTotalShapExpected.add(std::fabs(c2f)); + c2BarTotalShapExpected.add(std::fabs(c2bar)); + c2BazTotalShapExpected.add(std::fabs(c2baz)); double c3f{readShapValue(result, "c3", "foo")}; double c3bar{readShapValue(result, "c3", "bar")}; double c3baz{readShapValue(result, "c3", "baz")}; - BOOST_REQUIRE_CLOSE(c3, std::abs(c3f) + std::abs(c3bar) + std::abs(c3baz), 1e-6); + double c3{std::abs(c3f) + std::abs(c3bar) + std::abs(c3baz)}; + c3FooTotalShapExpected.add(std::fabs(c3f)); + c3BarTotalShapExpected.add(std::fabs(c3bar)); + c3BazTotalShapExpected.add(std::fabs(c3baz)); double c4f{readShapValue(result, "c4", "foo")}; double c4bar{readShapValue(result, "c4", "bar")}; double c4baz{readShapValue(result, "c4", "baz")}; - BOOST_REQUIRE_CLOSE(c4, std::abs(c4f) + std::abs(c4bar) + std::abs(c4baz), 1e-6); + double c4{std::abs(c4f) + std::abs(c4bar) + std::abs(c4baz)}; + c4FooTotalShapExpected.add(std::fabs(c4f)); + c4BarTotalShapExpected.add(std::fabs(c4bar)); + c4BazTotalShapExpected.add(std::fabs(c4baz)); + + // We should have at least one feature that is important + BOOST_TEST_REQUIRE((c1 > 0.0 || c2 > 0.0 || c3 > 0.0 || c4 > 0.0)); + } else if (result.HasMember("model_metadata")) { + if (result["model_metadata"].HasMember("total_feature_importance")) { + hasTotalFeatureImportance = true; + } + // TODO reactivate once Java parsing is ready + c1FooTotalShapActual = readTotalShapValue(result, "c1", "foo"); + c2FooTotalShapActual = readTotalShapValue(result, "c2", "foo"); + 
c3FooTotalShapActual = readTotalShapValue(result, "c3", "foo"); + c4FooTotalShapActual = readTotalShapValue(result, "c4", "foo"); + c1BarTotalShapActual = readTotalShapValue(result, "c1", "bar"); + c2BarTotalShapActual = readTotalShapValue(result, "c2", "bar"); + c3BarTotalShapActual = readTotalShapValue(result, "c3", "bar"); + c4BarTotalShapActual = readTotalShapValue(result, "c4", "bar"); + c1BazTotalShapActual = readTotalShapValue(result, "c1", "baz"); + c2BazTotalShapActual = readTotalShapValue(result, "c2", "baz"); + c3BazTotalShapActual = readTotalShapValue(result, "c3", "baz"); + c4BazTotalShapActual = readTotalShapValue(result, "c4", "baz"); } } + // TODO reactivate once Java parsing is ready + // BOOST_TEST_REQUIRE(hasTotalFeatureImportance); + // BOOST_REQUIRE_CLOSE(c1FooTotalShapActual, + // maths::CBasicStatistics::mean(c1FooTotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c2FooTotalShapActual, + // maths::CBasicStatistics::mean(c2FooTotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c3FooTotalShapActual, + // maths::CBasicStatistics::mean(c3FooTotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c4FooTotalShapActual, + // maths::CBasicStatistics::mean(c4FooTotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c1BarTotalShapActual, + // maths::CBasicStatistics::mean(c1BarTotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c2BarTotalShapActual, + // maths::CBasicStatistics::mean(c2BarTotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c3BarTotalShapActual, + // maths::CBasicStatistics::mean(c3BarTotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c4BarTotalShapActual, + // maths::CBasicStatistics::mean(c4BarTotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c1BazTotalShapActual, + // maths::CBasicStatistics::mean(c1BazTotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c2BazTotalShapActual, + // maths::CBasicStatistics::mean(c2BazTotalShapExpected), 1.0); + // BOOST_REQUIRE_CLOSE(c3BazTotalShapActual, + // maths::CBasicStatistics::mean(c3BazTotalShapExpected), 
1.0); + // BOOST_REQUIRE_CLOSE(c4BazTotalShapActual, + // maths::CBasicStatistics::mean(c4BazTotalShapExpected), 1.0); } BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoShap, SFixture) { diff --git a/lib/maths/CTreeShapFeatureImportance.cc b/lib/maths/CTreeShapFeatureImportance.cc index 12eeef6039..a8b2b2a5a3 100644 --- a/lib/maths/CTreeShapFeatureImportance.cc +++ b/lib/maths/CTreeShapFeatureImportance.cc @@ -362,5 +362,9 @@ void CTreeShapFeatureImportance::unwindPath(CSplitPath& path, int pathIndex, int } --nextIndex; } + +const CTreeShapFeatureImportance::TStrVec& CTreeShapFeatureImportance::columnNames() const { + return m_ColumnNames; +} } }