From 811a3b5e4370e3298f91848e0758726837e13fbf Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Thu, 19 Sep 2019 16:57:11 -0700 Subject: [PATCH 01/15] Add entrypoint for PFI --- .../PermutationFeatureImportance.cs | 342 ++++++++++++++++++ .../Common/EntryPoints/core_ep-list.tsv | 1 + .../Common/EntryPoints/core_manifest.json | 85 +++++ 3 files changed, 428 insertions(+) create mode 100644 src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs new file mode 100644 index 0000000000..a05ace83ea --- /dev/null +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -0,0 +1,342 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; + +[assembly: LoadableClass(typeof(void), typeof(PermutationFeatureImportanceEntryPoints), null, typeof(SignatureEntryPointModule), "PermutationFeatureImportance")] + +namespace Microsoft.ML.Transforms +{ + internal static class PermutationFeatureImportanceEntryPoints + { + [TlcModule.EntryPoint(Name = "Transforms.PermutationFeatureImportance", Desc = "Permutation Feature Importance (PFI)", UserName = "PFI", ShortName = "PFI")] + public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IHostEnvironment env, PermutationFeatureImportanceArguments input) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register("Pfi"); + host.CheckValue(input, nameof(input)); + EntryPointUtils.CheckInputArgs(host, input); + + var mlContext = new MLContext(); + + var model = mlContext.Model.Load(input.ModelPath.OpenReadStream(), out DataViewSchema schema); + var chain = model as TransformerChain; + var predictor = chain.LastTransformer as ISingleFeaturePredictionTransformer; + + var transformedData = model.Transform(input.Data); + + IDataView result = PermutationFeatureImportanceUtils.GetMetrics(mlContext, predictor, transformedData, input); + + return new PermutationFeatureImportanceOutput { Metrics = result }; + } + } + + internal sealed class PermutationFeatureImportanceOutput + { + [TlcModule.Output(Desc = "The PFI metrics")] + public IDataView Metrics; + } + + internal sealed class PermutationFeatureImportanceArguments : TransformInputBase + { + [Argument(ArgumentType.Required, HelpText = "The path to the model file", ShortName = "path", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public IFileHandle ModelPath; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Label column name", ShortName = "label", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public string LabelColumnName = "Label"; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Group ID column", ShortName = "groupId", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public string RowGroupColumnName = "GroupId"; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Use feature weights to pre-filter features", ShortName = "usefw", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public bool UseFeatureWeightFilter = false; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of examples to evaluate on", ShortName = "numexamples", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public int? NumberOfExamplesToUse = null; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The number of permutations to perform", ShortName = "permutations", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public int PermutationCount = 1; + } + + internal static class PermutationFeatureImportanceUtils + { + private static string[] GetSlotNames(IDataView data) + { + VBuffer> slots = default; + data.Schema["Features"].GetSlotNames(ref slots); + + var column = data.GetColumn>( + data.Schema["Features"]); + + List slotNames = new List(); + + foreach (var item in column.First>().Items(all: true)) + { + slotNames.Add(slots.GetValues()[item.Key].ToString()); + }; + + return slotNames.ToArray(); + } + + internal static IDataView GetMetrics( + MLContext mlContext, + ISingleFeaturePredictionTransformer predictor, + IDataView data, + PermutationFeatureImportanceArguments input) + { + IDataView result; + if (predictor is BinaryPredictionTransformer>) + result = GetBinaryMetrics(mlContext, predictor, data, input); + else if (predictor is MulticlassPredictionTransformer>>) + result = GetMulticlassMetrics(mlContext, predictor, data, input); + else if (predictor is RegressionPredictionTransformer>) + result = GetRegressionMetrics(mlContext, predictor, data, input); + else if (predictor is RankingPredictionTransformer>) + result = GetRankingMetrics(mlContext, predictor, data, input); + else + throw Contracts.Except( + "Unsupported predictor type. Predictor must be binary classifier," + + "multiclass classifier, regressor, or ranker."); + + return result; + } + + private static IDataView GetBinaryMetrics( + MLContext mlContext, + ISingleFeaturePredictionTransformer predictor, + IDataView data, + PermutationFeatureImportanceArguments input) + { + var slotNames = GetSlotNames(data); + + var permutationMetrics = mlContext.BinaryClassification + .PermutationFeatureImportance(predictor, + data, + labelColumnName: input.LabelColumnName, + useFeatureWeightFilter: input.UseFeatureWeightFilter, + numberOfExamplesToUse: input.NumberOfExamplesToUse, + permutationCount: input.PermutationCount); + + Contracts.Assert(slotNames.Length == permutationMetrics.Length, + "Mismatch between number of feature slots and number of features permuted."); + + IEnumerable metrics = Enumerable.Empty(); + for (int i = 0; i < permutationMetrics.Length; i++) + { + var pMetric = permutationMetrics[i]; + metrics = metrics.Append(new BinaryMetrics + { + FeatureName = slotNames[i], + AreaUnderRocCurve = pMetric.AreaUnderRocCurve.Mean, + Accuracy = pMetric.Accuracy.Mean, + PositivePrecision = pMetric.PositivePrecision.Mean, + PositiveRecall = pMetric.PositiveRecall.Mean, + NegativePrecision = pMetric.NegativePrecision.Mean, + NegativeRecall = pMetric.NegativeRecall.Mean, + F1Score = pMetric.F1Score.Mean, + AreaUnderPrecisionRecallCurve = pMetric.AreaUnderPrecisionRecallCurve.Mean + }); + } + + var result = mlContext.Data.LoadFromEnumerable(metrics); + return result; + } + + private static IDataView GetMulticlassMetrics( + MLContext mlContext, + ISingleFeaturePredictionTransformer predictor, + IDataView data, + PermutationFeatureImportanceArguments input) + { + var slotNames = GetSlotNames(data); + + var permutationMetrics = mlContext.MulticlassClassification + .PermutationFeatureImportance(predictor, + data, + labelColumnName: input.LabelColumnName, + useFeatureWeightFilter: input.UseFeatureWeightFilter, + numberOfExamplesToUse: input.NumberOfExamplesToUse, + permutationCount: input.PermutationCount); + + Contracts.Assert(slotNames.Length == permutationMetrics.Length, + "Mismatch between number of feature slots and number of features permuted."); + + IEnumerable metrics = Enumerable.Empty(); + for (int i = 0; i < permutationMetrics.Length; i++) + { + var pMetric = permutationMetrics[i]; + metrics = metrics.Append(new MulticlassMetrics + { + FeatureName = slotNames[i], + MacroAccuracy = pMetric.MacroAccuracy.Mean, + MicroAccuracy = pMetric.MicroAccuracy.Mean, + LogLoss = pMetric.LogLoss.Mean, + LogLossReduction = pMetric.LogLossReduction.Mean, + TopKAccuracy = pMetric.TopKAccuracy.Mean, + PerClassLogLoss = pMetric.PerClassLogLoss.Select(x => x.Mean).ToArray() + }); ; + } + + // Convert unknown size vectors to known size. + var metric = metrics.First(); + int perClassLogLossDimension = metric.PerClassLogLoss.Length; + SchemaDefinition schema = SchemaDefinition.Create(typeof(MulticlassMetrics)); + var perClassLogLossType = ((VectorDataViewType)schema[nameof(metric.PerClassLogLoss)].ColumnType).ItemType; + schema[nameof(metric.PerClassLogLoss)].ColumnType = new VectorDataViewType(perClassLogLossType, perClassLogLossDimension); + + var result = mlContext.Data.LoadFromEnumerable(metrics, schema); + return result; + } + + private static IDataView GetRegressionMetrics( + MLContext mlContext, + ISingleFeaturePredictionTransformer predictor, + IDataView data, + PermutationFeatureImportanceArguments input) + { + var slotNames = GetSlotNames(data); + + var permutationMetrics = mlContext.Regression + .PermutationFeatureImportance(predictor, + data, + labelColumnName: input.LabelColumnName, + useFeatureWeightFilter: input.UseFeatureWeightFilter, + numberOfExamplesToUse: input.NumberOfExamplesToUse, + permutationCount: input.PermutationCount); + + Contracts.Assert(slotNames.Length == permutationMetrics.Length, + "Mismatch between number of feature slots and number of features permuted."); + + IEnumerable metrics = Enumerable.Empty(); + for (int i = 0; i < permutationMetrics.Length; i++) + { + var pMetric = permutationMetrics[i]; + metrics = metrics.Append(new RegressionMetrics + { + FeatureName = slotNames[i], + MeanAbsoluteError = pMetric.MeanAbsoluteError.Mean, + MeanSquaredError = pMetric.MeanSquaredError.Mean, + RootMeanSquaredError = pMetric.RootMeanSquaredError.Mean, + LossFunction = pMetric.LossFunction.Mean, + RSquared = pMetric.RSquared.Mean + }); + } + + var result = mlContext.Data.LoadFromEnumerable(metrics); + return result; + } + + private static IDataView GetRankingMetrics( + MLContext mlContext, + ISingleFeaturePredictionTransformer predictor, + IDataView data, + PermutationFeatureImportanceArguments input) + { + var slotNames = GetSlotNames(data); + + var permutationMetrics = mlContext.Ranking + .PermutationFeatureImportance(predictor, + data, + labelColumnName: input.LabelColumnName, + rowGroupColumnName: input.RowGroupColumnName, + useFeatureWeightFilter: input.UseFeatureWeightFilter, + numberOfExamplesToUse: input.NumberOfExamplesToUse, + permutationCount: input.PermutationCount); + + Contracts.Assert(slotNames.Length == permutationMetrics.Length, + "Mismatch between number of feature slots and number of features permuted."); + + IEnumerable metrics = Enumerable.Empty(); + for (int i = 0; i < permutationMetrics.Length; i++) + { + var pMetric = permutationMetrics[i]; + metrics = metrics.Append(new RankingMetrics + { + FeatureName = slotNames[i], + DiscountedCumulativeGains = pMetric.DiscountedCumulativeGains.Select(x => x.Mean).ToArray(), + NormalizedDiscountedCumulativeGains = pMetric.NormalizedDiscountedCumulativeGains.Select(x => x.Mean).ToArray() + }); + } + + // Convert unknown size vectors to known size. + var metric = metrics.First(); + int dcgDimension = metric.DiscountedCumulativeGains.Length; + int ndcgDimension = metric.NormalizedDiscountedCumulativeGains.Length; + SchemaDefinition schema = SchemaDefinition.Create(typeof(RankingMetrics)); + var dcgType = ((VectorDataViewType)schema[nameof(metric.DiscountedCumulativeGains)].ColumnType).ItemType; + var ndcgType = ((VectorDataViewType)schema[nameof(metric.NormalizedDiscountedCumulativeGains)].ColumnType).ItemType; + schema[nameof(metric.DiscountedCumulativeGains)].ColumnType = new VectorDataViewType(dcgType, dcgDimension); + schema[nameof(metric.NormalizedDiscountedCumulativeGains)].ColumnType = new VectorDataViewType(ndcgType, ndcgDimension); + + var result = mlContext.Data.LoadFromEnumerable(metrics, schema); + return result; + } + } + + internal class BinaryMetrics + { + public string FeatureName { get; set; } + + public double AreaUnderRocCurve { get; set; } + + public double Accuracy { get; set; } + + public double PositivePrecision { get; set; } + + public double PositiveRecall { get; set; } + + public double NegativePrecision { get; set; } + + public double NegativeRecall { get; set; } + + public double F1Score { get; set; } + + public double AreaUnderPrecisionRecallCurve { get; set; } + } + + internal class MulticlassMetrics + { + public string FeatureName { get; set; } + + public double MacroAccuracy { get; set; } + + public double MicroAccuracy { get; set; } + + public double LogLoss { get; set; } + + public double LogLossReduction { get; set; } + + public double TopKAccuracy { get; set; } + + public double[] PerClassLogLoss { get; set; } + } + + internal class RegressionMetrics + { + public string FeatureName { get; set; } + + public double MeanAbsoluteError { get; set; } + + public double MeanSquaredError { get; set; } + + public double RootMeanSquaredError { get; set; } + + public double LossFunction { get; set; } + + public double RSquared { get; set; } + } + + internal class RankingMetrics + { + public string FeatureName { get; set; } + + public double[] DiscountedCumulativeGains { get; set; } + + public double[] NormalizedDiscountedCumulativeGains { get; set; } + } +} diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 6288a254ae..00a790252a 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -116,6 +116,7 @@ Transforms.NGramTranslator Produces a bag of counts of n-grams (sequences of con Transforms.NoOperation Does nothing. Microsoft.ML.Data.NopTransform Nop Microsoft.ML.Data.NopTransform+NopInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.OptionalColumnCreator If the source column does not exist after deserialization, create a column with the right type and default values. Microsoft.ML.Transforms.OptionalColumnTransform MakeOptional Microsoft.ML.Transforms.OptionalColumnTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.PcaCalculator PCA is a dimensionality-reduction transform which computes the projection of a numeric vector onto a low-rank subspace. Microsoft.ML.Transforms.PrincipalComponentAnalysisTransformer Calculate Microsoft.ML.Transforms.PrincipalComponentAnalysisTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.PermutationFeatureImportance Microsoft.ML.Transforms.PermutationFeatureImportanceEntryPoints PermutationFeatureImportance Microsoft.ML.Transforms.PermutationFeatureImportanceArguments Microsoft.ML.Transforms.PermutationFeatureImportanceOutput Transforms.PredictedLabelColumnOriginalValueConverter Transforms a predicted label column to its original values, unless it is of type bool. Microsoft.ML.EntryPoints.FeatureCombiner ConvertPredictedLabel Microsoft.ML.EntryPoints.FeatureCombiner+PredictedLabelInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.RandomNumberGenerator Adds a column with a generated number sequence. Microsoft.ML.Transforms.RandomNumberGenerator Generate Microsoft.ML.Transforms.GenerateNumberTransform+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.RowRangeFilter Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. If the input is a Key type, the min/max are considered percentages of the number of values. Microsoft.ML.EntryPoints.SelectRows FilterByRange Microsoft.ML.Transforms.RangeFilter+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index a68890fe6f..316166fecc 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -21710,6 +21710,91 @@ "ITransformOutput" ] }, + { + "Name": "Transforms.PermutationFeatureImportance", + "Desc": "", + "FriendlyName": "BinaryPFI", + "ShortName": "BinaryPFI", + "Inputs": [ + { + "Name": "ModelPath", + "Type": "FileHandle", + "Desc": "", + "Aliases": [ + "path" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "LabelColumnName", + "Type": "String", + "Desc": "", + "Aliases": [ + "label" + ], + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": "Label" + }, + { + "Name": "UseFeatureWeightFilter", + "Type": "Bool", + "Desc": "", + "Aliases": [ + "usefw" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": false + }, + { + "Name": "NumberOfExamplesToUse", + "Type": "Int", + "Desc": "", + "Aliases": [ + "numexamples" + ], + "Required": false, + "SortOrder": 4.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "PermutationCount", + "Type": "Int", + "Desc": "", + "Aliases": [ + "permutations" + ], + "Required": false, + "SortOrder": 5.0, + "IsNullable": false, + "Default": 1 + } + ], + "Outputs": [ + { + "Name": "Metrics", + "Type": "DataView", + "Desc": "The PFI metrics" + } + ], + "InputKind": [ + "ITransformInput" + ] + }, { "Name": "Transforms.PredictedLabelColumnOriginalValueConverter", "Desc": "Transforms a predicted label column to its original values, unless it is of type bool.", From a31efedb8f7075a4364a01d6870ab2c0249016b4 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Thu, 19 Sep 2019 17:01:57 -0700 Subject: [PATCH 02/15] Regenerate EP catalog --- .../Common/EntryPoints/core_ep-list.tsv | 2 +- .../Common/EntryPoints/core_manifest.json | 54 +++++++++++-------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 00a790252a..4f2bcc426a 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -116,7 +116,7 @@ Transforms.NGramTranslator Produces a bag of counts of n-grams (sequences of con Transforms.NoOperation Does nothing. Microsoft.ML.Data.NopTransform Nop Microsoft.ML.Data.NopTransform+NopInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.OptionalColumnCreator If the source column does not exist after deserialization, create a column with the right type and default values. Microsoft.ML.Transforms.OptionalColumnTransform MakeOptional Microsoft.ML.Transforms.OptionalColumnTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.PcaCalculator PCA is a dimensionality-reduction transform which computes the projection of a numeric vector onto a low-rank subspace. Microsoft.ML.Transforms.PrincipalComponentAnalysisTransformer Calculate Microsoft.ML.Transforms.PrincipalComponentAnalysisTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.PermutationFeatureImportance Microsoft.ML.Transforms.PermutationFeatureImportanceEntryPoints PermutationFeatureImportance Microsoft.ML.Transforms.PermutationFeatureImportanceArguments Microsoft.ML.Transforms.PermutationFeatureImportanceOutput +Transforms.PermutationFeatureImportance Permutation Feature Importance (PFI) Microsoft.ML.Transforms.PermutationFeatureImportanceEntryPoints PermutationFeatureImportance Microsoft.ML.Transforms.PermutationFeatureImportanceArguments Microsoft.ML.Transforms.PermutationFeatureImportanceOutput Transforms.PredictedLabelColumnOriginalValueConverter Transforms a predicted label column to its original values, unless it is of type bool. Microsoft.ML.EntryPoints.FeatureCombiner ConvertPredictedLabel Microsoft.ML.EntryPoints.FeatureCombiner+PredictedLabelInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.RandomNumberGenerator Adds a column with a generated number sequence. Microsoft.ML.Transforms.RandomNumberGenerator Generate Microsoft.ML.Transforms.GenerateNumberTransform+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.RowRangeFilter Filters a dataview on a column of type Single, Double or Key (contiguous). Keeps the values that are in the specified min/max range. NaNs are always filtered out. If the input is a Key type, the min/max are considered percentages of the number of values. Microsoft.ML.EntryPoints.SelectRows FilterByRange Microsoft.ML.Transforms.RangeFilter+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 316166fecc..ec7ddbb497 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -21712,74 +21712,86 @@ }, { "Name": "Transforms.PermutationFeatureImportance", - "Desc": "", - "FriendlyName": "BinaryPFI", - "ShortName": "BinaryPFI", + "Desc": "Permutation Feature Importance (PFI)", + "FriendlyName": "PFI", + "ShortName": "PFI", "Inputs": [ { - "Name": "ModelPath", - "Type": "FileHandle", - "Desc": "", - "Aliases": [ - "path" - ], + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", "Required": true, "SortOrder": 1.0, "IsNullable": false }, { - "Name": "Data", - "Type": "DataView", - "Desc": "Input dataset", + "Name": "ModelPath", + "Type": "FileHandle", + "Desc": "The path to the model file", + "Aliases": [ + "path" + ], "Required": true, - "SortOrder": 1.0, + "SortOrder": 150.0, "IsNullable": false }, { "Name": "LabelColumnName", "Type": "String", - "Desc": "", + "Desc": "Label column name", "Aliases": [ "label" ], "Required": false, - "SortOrder": 2.0, + "SortOrder": 150.0, "IsNullable": false, "Default": "Label" }, + { + "Name": "RowGroupColumnName", + "Type": "String", + "Desc": "Group ID column", + "Aliases": [ + "groupId" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": "GroupId" + }, { "Name": "UseFeatureWeightFilter", "Type": "Bool", - "Desc": "", + "Desc": "Use feature weights to pre-filter features", "Aliases": [ "usefw" ], "Required": false, - "SortOrder": 3.0, + "SortOrder": 150.0, "IsNullable": false, "Default": false }, { "Name": "NumberOfExamplesToUse", "Type": "Int", - "Desc": "", + "Desc": "Limit the number of examples to evaluate on", "Aliases": [ "numexamples" ], "Required": false, - "SortOrder": 4.0, + "SortOrder": 150.0, "IsNullable": true, "Default": null }, { "Name": "PermutationCount", "Type": "Int", - "Desc": "", + "Desc": "The number of permutations to perform", "Aliases": [ "permutations" ], "Required": false, - "SortOrder": 5.0, + "SortOrder": 150.0, "IsNullable": false, "Default": 1 } From 35fdaa883615af91d1e2de0fbad12336aee8afd7 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Fri, 20 Sep 2019 12:57:06 -0700 Subject: [PATCH 03/15] Add tests --- .../PermutationFeatureImportance.cs | 7 +- .../UnitTests/TestEntryPoints.cs | 686 ++++++++++++++++++ 2 files changed, 692 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index a05ace83ea..33fcb21993 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Linq; using Microsoft.ML; @@ -27,6 +31,7 @@ public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IH var model = mlContext.Model.Load(input.ModelPath.OpenReadStream(), out DataViewSchema schema); var chain = model as TransformerChain; var predictor = chain.LastTransformer as ISingleFeaturePredictionTransformer; + Contracts.Assert(!(predictor is null), "Model does not have a predictor or the predictor is not supported."); var transformedData = model.Transform(input.Data); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index e9727d13a7..0473018b76 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -612,6 +612,692 @@ public void EntryPointExecGraphCommand() cmd.Run(); } + [Fact] + public void BinaryPermutationFeatureImportance() + { + var dataPath = GetDataPath("adult.tiny.with-schema.txt"); + var modelPath = DeleteOutputPath("model.zip"); + var outputDataPath = DeleteOutputPath("metrics.idv"); + + string trainingGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file': '{0}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'education', + 'Source': 'education' + }} + ], + 'Data': '$data', + }}, + 'Name': 'Transforms.CategoricalOneHotVectorizer', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'LabelColumn': 'Label', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.LabelColumnKeyBooleanConverter', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$output_data3' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data3', + 'Features': [ + 'age', + 'education' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$output_data4' + }} + }}, + {{ + 'Inputs': {{ + 'MaximumNumberOfIterations': 1, + 'NumThreads': 1, + 'TrainingData': '$output_data4' + }}, + 'Name': 'Trainers.LogisticRegressionBinaryClassifier', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$predictor_model', + 'TransformModels': [ + '$output_model1', + '$output_model2', + '$output_model3', + '$output_model4' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }} + ], + 'Outputs': {{ + 'output_model': '{1}' + }} + }}", EscapePath(dataPath), EscapePath(modelPath)); + + var trainingJsonPath = DeleteOutputPath("trainingGraph.json"); + File.WriteAllLines(trainingJsonPath, new[] { trainingGraph }); + + var trainingArgs = new ExecuteGraphCommand.Arguments() { GraphPath = trainingJsonPath }; + var trainingCmd = new ExecuteGraphCommand(Env, trainingArgs); + trainingCmd.Run(); + + string pfiGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file1': '{0}', + 'file2': '{1}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file1' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$data', + 'ModelPath': '$file2', + 'PermutationCount': 5 + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{2}' + }} + }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); + + var pfiJsonPath = DeleteOutputPath("pfiGraph.json"); + File.WriteAllLines(pfiJsonPath, new[] { pfiGraph }); + + var pfiArgs = new ExecuteGraphCommand.Arguments() { GraphPath = pfiJsonPath }; + var pfiCmd = new ExecuteGraphCommand(Env, pfiArgs); + pfiCmd.Run(); + + var mlContext = new MLContext(); + + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + + Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("AreaUnderRocCurve")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("Accuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PositivePrecision")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PositiveRecall")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NegativePrecision")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NegativeRecall")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("F1Score")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("AreaUnderPrecisionRecallCurve")); + } + + [Fact] + public void MulticlassPermutationFeatureImportance() + { + var dataPath = GetDataPath("adult.tiny.with-schema.txt"); + var modelPath = DeleteOutputPath("model.zip"); + var outputDataPath = DeleteOutputPath("metrics.idv"); + + string trainingGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file': '{0}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'education', + 'Source': 'education' + }} + ], + 'Data': '$data', + }}, + 'Name': 'Transforms.CategoricalOneHotVectorizer', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'LabelColumn': 'Label', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.LabelColumnKeyBooleanConverter', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$output_data3' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data3', + 'Features': [ + 'age', + 'education' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$output_data4' + }} + }}, + {{ + 'Inputs': {{ + 'TrainingData': '$output_data4', + 'NumThreads': 1, + 'MaxIterations': 1 + }}, + 'Name': 'Trainers.StochasticDualCoordinateAscentClassifier', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$predictor_model', + 'TransformModels': [ + '$output_model1', + '$output_model2', + '$output_model3', + '$output_model4' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }} + ], + 'Outputs': {{ + 'output_model': '{1}' + }} + }}", EscapePath(dataPath), EscapePath(modelPath)); + + var trainingJsonPath = DeleteOutputPath("trainingGraph.json"); + File.WriteAllLines(trainingJsonPath, new[] { trainingGraph }); + + var trainingArgs = new ExecuteGraphCommand.Arguments() { GraphPath = trainingJsonPath }; + var trainingCmd = new ExecuteGraphCommand(Env, trainingArgs); + trainingCmd.Run(); + + string pfiGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file1': '{0}', + 'file2': '{1}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file1' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$data', + 'ModelPath': '$file2', + 'PermutationCount': 5 + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{2}' + }} + }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); + + var pfiJsonPath = DeleteOutputPath("pfiGraph.json"); + File.WriteAllLines(pfiJsonPath, new[] { pfiGraph }); + + var pfiArgs = new ExecuteGraphCommand.Arguments() { GraphPath = pfiJsonPath }; + var pfiCmd = new ExecuteGraphCommand(Env, pfiArgs); + pfiCmd.Run(); + + var mlContext = new MLContext(); + + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + + Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MacroAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MicroAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLoss")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossReduction")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("TopKAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PerClassLogLoss")); + } + + [Fact] + public void RegressionPermutationFeatureImportance() + { + var dataPath = GetDataPath("adult.tiny.with-schema.txt"); + var modelPath = DeleteOutputPath("model.zip"); + var outputDataPath = DeleteOutputPath("metrics.idv"); + + string trainingGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file': '{0}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'education', + 'Source': 'education' + }} + ], + 'Data': '$data', + }}, + 'Name': 'Transforms.CategoricalOneHotVectorizer', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'LabelColumn': 'Label' + }}, + 'Name': 'Transforms.LabelToFloatConverter', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$output_data3' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data3', + 'Features': [ + 'age', + 'education' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$output_data4' + }} + }}, + {{ + 'Inputs': {{ + 'TrainingData': '$output_data4', + 'NumThreads': 1, + 'MaxIterations': 1 + }}, + 'Name': 'Trainers.StochasticDualCoordinateAscentRegressor', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$predictor_model', + 'TransformModels': [ + '$output_model1', + '$output_model2', + '$output_model3', + '$output_model4' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }} + ], + 'Outputs': {{ + 'output_model': '{1}' + }} + }}", EscapePath(dataPath), EscapePath(modelPath)); + + var trainingJsonPath = DeleteOutputPath("trainingGraph.json"); + File.WriteAllLines(trainingJsonPath, new[] { trainingGraph }); + + var trainingArgs = new ExecuteGraphCommand.Arguments() { GraphPath = trainingJsonPath }; + var trainingCmd = new ExecuteGraphCommand(Env, trainingArgs); + trainingCmd.Run(); + + string pfiGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file1': '{0}', + 'file2': '{1}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file1' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$data', + 'ModelPath': '$file2', + 'PermutationCount': 5 + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{2}' + }} + }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); + + var pfiJsonPath = DeleteOutputPath("pfiGraph.json"); + File.WriteAllLines(pfiJsonPath, new[] { pfiGraph }); + + var pfiArgs = new ExecuteGraphCommand.Arguments() { GraphPath = pfiJsonPath }; + var pfiCmd = new ExecuteGraphCommand(Env, pfiArgs); + pfiCmd.Run(); + + var mlContext = new MLContext(); + + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + + Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MeanAbsoluteError")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MeanSquaredError")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("RootMeanSquaredError")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LossFunction")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("RSquared")); + } + + [Fact] + public void RankingPermutationFeatureImportance() + { + var dataPath = GetDataPath("adult.tiny.with-schema.txt"); + var modelPath = DeleteOutputPath("model.zip"); + var outputDataPath = DeleteOutputPath("metrics.idv"); + + string trainingGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file': '{0}' + }}, + 'Nodes': [ + {{ + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Name': 'Data.CustomTextLoader', + 'Outputs': {{ + 'Data': '$input_data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'Workclass', + 'Source': 'Workclass' + }} + ], + 'Data': '$input_data', + 'MaxNumTerms': 1000000, + 'Sort': 'ByOccurrence', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.TextToKeyConverter', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'LabelColumn': 'Label', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.LabelColumnKeyBooleanConverter', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$output_data3' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data3', + 'Features': [ + 'age', + 'education-num', + 'capital-gain' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$output_data4' + }} + }}, + {{ + 'Inputs': {{ + 'NumberOfTrees': 1, + 'RowGroupColumnName': 'Workclass', + 'TrainingData': '$output_data4', + 'NumberOfLeaves': 2 + }}, + 'Name': 'Trainers.FastTreeRanker', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$predictor_model', + 'TransformModels': [ + '$output_model1', + '$output_model2', + '$output_model3', + '$output_model4' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }} + ], + 'Outputs': {{ + 'output_model': '{1}' + }} + }}", EscapePath(dataPath), EscapePath(modelPath)); + + var trainingJsonPath = DeleteOutputPath("trainingGraph.json"); + File.WriteAllLines(trainingJsonPath, new[] { trainingGraph }); + + var trainingArgs = new ExecuteGraphCommand.Arguments() { GraphPath = trainingJsonPath }; + var trainingCmd = new ExecuteGraphCommand(Env, trainingArgs); + trainingCmd.Run(); + + string pfiGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file1': '{0}', + 'file2': '{1}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file1' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$data', + 'ModelPath': '$file2', + 'PermutationCount': 20, + 'RowGroupColumnName': 'Workclass' + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{2}' + }} + }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); + + var pfiJsonPath = DeleteOutputPath("pfiGraph.json"); + File.WriteAllLines(pfiJsonPath, new[] { pfiGraph }); + + var pfiArgs = new ExecuteGraphCommand.Arguments() { GraphPath = pfiJsonPath }; + var pfiCmd = new ExecuteGraphCommand(Env, pfiArgs); + pfiCmd.Run(); + + var mlContext = new MLContext(); + + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + + Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("DiscountedCumulativeGains")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NormalizedDiscountedCumulativeGains")); + } + [Fact] public void ScoreTransformerChainModel() { From 72991f31cf534a7a018c70499f58ede1dfeae1bd Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Fri, 20 Sep 2019 15:10:03 -0700 Subject: [PATCH 04/15] Adding Standard Error of Mean to PFI Metrics in EntryPoint --- .../PermutationFeatureImportance.cs | 72 +++++++++++++++++-- .../UnitTests/TestEntryPoints.cs | 21 ++++++ 2 files changed, 89 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index 33fcb21993..6107f80cf1 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -138,13 +138,21 @@ private static IDataView GetBinaryMetrics( { FeatureName = slotNames[i], AreaUnderRocCurve = pMetric.AreaUnderRocCurve.Mean, + AreaUnderRocCurveStdErr = pMetric.AreaUnderRocCurve.StandardError, Accuracy = pMetric.Accuracy.Mean, + AccuracyStdErr = pMetric.Accuracy.StandardError, PositivePrecision = pMetric.PositivePrecision.Mean, + PositivePrecisionStdErr = pMetric.PositivePrecision.StandardError, PositiveRecall = pMetric.PositiveRecall.Mean, + PositiveRecallStdErr = pMetric.PositiveRecall.StandardError, NegativePrecision = pMetric.NegativePrecision.Mean, + NegativePrecisionStdErr = pMetric.NegativePrecision.StandardError, NegativeRecall = pMetric.NegativeRecall.Mean, + NegativeRecallStdErr = pMetric.NegativeRecall.StandardError, F1Score = pMetric.F1Score.Mean, - AreaUnderPrecisionRecallCurve = pMetric.AreaUnderPrecisionRecallCurve.Mean + F1ScoreStdErr = pMetric.F1Score.StandardError, + AreaUnderPrecisionRecallCurve = pMetric.AreaUnderPrecisionRecallCurve.Mean, + AreaUnderPrecisionRecallCurveStdErr = pMetric.AreaUnderPrecisionRecallCurve.StandardError }); } @@ -179,11 +187,17 @@ private static IDataView GetMulticlassMetrics( { FeatureName = slotNames[i], MacroAccuracy = pMetric.MacroAccuracy.Mean, + MacroAccuracyStdErr = pMetric.MacroAccuracy.StandardError, MicroAccuracy = pMetric.MicroAccuracy.Mean, + MicroAccuracyStdErr = pMetric.MicroAccuracy.StandardError, LogLoss = pMetric.LogLoss.Mean, + LogLossStdErr = pMetric.LogLoss.StandardError, LogLossReduction = pMetric.LogLossReduction.Mean, + LogLossReductionStdErr = pMetric.LogLossReduction.StandardError, TopKAccuracy = pMetric.TopKAccuracy.Mean, - PerClassLogLoss = pMetric.PerClassLogLoss.Select(x => x.Mean).ToArray() + TopKAccuracyStdErr = pMetric.TopKAccuracy.StandardError, + PerClassLogLoss = pMetric.PerClassLogLoss.Select(x => x.Mean).ToArray(), + PerClassLogLossStdErr = pMetric.PerClassLogLoss.Select(x => x.StandardError).ToArray() }); ; } @@ -225,10 +239,15 @@ private static IDataView GetRegressionMetrics( { FeatureName = slotNames[i], MeanAbsoluteError = pMetric.MeanAbsoluteError.Mean, + MeanAbsoluteErrorStdErr = pMetric.MeanAbsoluteError.StandardError, MeanSquaredError = pMetric.MeanSquaredError.Mean, + MeanSquaredErrorStdErr = pMetric.MeanSquaredError.StandardError, RootMeanSquaredError = pMetric.RootMeanSquaredError.Mean, + RootMeanSquaredErrorStdErr = pMetric.RootMeanSquaredError.StandardError, LossFunction = pMetric.LossFunction.Mean, - RSquared = pMetric.RSquared.Mean + LossFunctionStdErr = pMetric.LossFunction.StandardError, + RSquared = pMetric.RSquared.Mean, + RSquaredStdErr = pMetric.RSquared.StandardError }); } @@ -264,7 +283,9 @@ private static IDataView GetRankingMetrics( { FeatureName = slotNames[i], DiscountedCumulativeGains = pMetric.DiscountedCumulativeGains.Select(x => x.Mean).ToArray(), - NormalizedDiscountedCumulativeGains = pMetric.NormalizedDiscountedCumulativeGains.Select(x => x.Mean).ToArray() + DiscountedCumulativeGainsStdErr = pMetric.DiscountedCumulativeGains.Select(x => x.StandardError).ToArray(), + NormalizedDiscountedCumulativeGains = pMetric.NormalizedDiscountedCumulativeGains.Select(x => x.Mean).ToArray(), + NormalizedDiscountedCumulativeGainsStdErr = pMetric.NormalizedDiscountedCumulativeGains.Select(x => x.StandardError).ToArray() }); } @@ -289,19 +310,36 @@ internal class BinaryMetrics public double AreaUnderRocCurve { get; set; } + public double AreaUnderRocCurveStdErr { get; set; } + public double Accuracy { get; set; } + public double AccuracyStdErr { get; set; } + public double PositivePrecision { get; set; } + public double PositivePrecisionStdErr { get; set; } + public double PositiveRecall { get; set; } + public double PositiveRecallStdErr { get; set; } + public double NegativePrecision { get; set; } + public double NegativePrecisionStdErr { get; set; } + public double NegativeRecall { get; set; } + public double NegativeRecallStdErr { get; set; } + public double F1Score { get; set; } + public double F1ScoreStdErr { get; set; } + public double AreaUnderPrecisionRecallCurve { get; set; } + + public double AreaUnderPrecisionRecallCurveStdErr { get; set; } + } internal class MulticlassMetrics @@ -310,15 +348,27 @@ internal class MulticlassMetrics public double MacroAccuracy { get; set; } + public double MacroAccuracyStdErr { get; set; } + public double MicroAccuracy { get; set; } + public double MicroAccuracyStdErr { get; set; } + public double LogLoss { get; set; } + public double LogLossStdErr { get; set; } + public double LogLossReduction { get; set; } + public double LogLossReductionStdErr { get; set; } + public double TopKAccuracy { get; set; } + public double TopKAccuracyStdErr { get; set; } + public double[] PerClassLogLoss { get; set; } + + public double[] PerClassLogLossStdErr { get; set; } } internal class RegressionMetrics @@ -327,13 +377,23 @@ internal class RegressionMetrics public double MeanAbsoluteError { get; set; } + public double MeanAbsoluteErrorStdErr { get; set; } + public double MeanSquaredError { get; set; } + public double MeanSquaredErrorStdErr { get; set; } + public double RootMeanSquaredError { get; set; } + public double RootMeanSquaredErrorStdErr { get; set; } + public double LossFunction { get; set; } + public double LossFunctionStdErr { get; set; } + public double RSquared { get; set; } + + public double RSquaredStdErr { get; set; } } internal class RankingMetrics @@ -342,6 +402,10 @@ internal class RankingMetrics public double[] DiscountedCumulativeGains { get; set; } + public double[] DiscountedCumulativeGainsStdErr { get; set; } + public double[] NormalizedDiscountedCumulativeGains { get; set; } + + public double[] NormalizedDiscountedCumulativeGainsStdErr { get; set; } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 0473018b76..f5b5025065 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -783,6 +783,14 @@ public void BinaryPermutationFeatureImportance() Assert.NotNull(loadedData.Schema.GetColumnOrNull("NegativeRecall")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("F1Score")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("AreaUnderPrecisionRecallCurve")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("AreaUnderRocCurveStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("AccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PositivePrecisionStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PositiveRecallStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NegativePrecisionStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NegativeRecallStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("F1ScoreStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("AreaUnderPrecisionRecallCurveStdErr")); } [Fact] @@ -954,6 +962,12 @@ public void MulticlassPermutationFeatureImportance() Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossReduction")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("TopKAccuracy")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("PerClassLogLoss")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MacroAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MicroAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossReductionStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("TopKAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PerClassLogLossStdErr")); } [Fact] @@ -1123,6 +1137,11 @@ public void RegressionPermutationFeatureImportance() Assert.NotNull(loadedData.Schema.GetColumnOrNull("RootMeanSquaredError")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("LossFunction")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("RSquared")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MeanAbsoluteErrorStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MeanSquaredErrorStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("RootMeanSquaredErrorStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LossFunctionStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("RSquaredStdErr")); } [Fact] @@ -1296,6 +1315,8 @@ public void RankingPermutationFeatureImportance() Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("DiscountedCumulativeGains")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("NormalizedDiscountedCumulativeGains")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("DiscountedCumulativeGainsStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("NormalizedDiscountedCumulativeGainsStdErr")); } [Fact] From 163292db757980e66a5c335407d1b4c548a416c2 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Fri, 20 Sep 2019 19:05:45 -0700 Subject: [PATCH 05/15] PR Feedback --- .../PermutationFeatureImportance.cs | 194 +++++++++--------- 1 file changed, 99 insertions(+), 95 deletions(-) diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index 6107f80cf1..01c124b0f3 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -31,7 +31,9 @@ public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IH var model = mlContext.Model.Load(input.ModelPath.OpenReadStream(), out DataViewSchema schema); var chain = model as TransformerChain; var predictor = chain.LastTransformer as ISingleFeaturePredictionTransformer; - Contracts.Assert(!(predictor is null), "Model does not have a predictor or the predictor is not supported."); + Contracts.Assert(!(predictor is null), "The last transformer in the model is not a predictor, or Permutation " + + "Feature Importance (PFI) is not supported for the predictor. The last transformer in the model must be a " + + "predictor, as PFI is calculated for a predictor model."); var transformedData = model.Transform(input.Data); @@ -70,24 +72,6 @@ internal sealed class PermutationFeatureImportanceArguments : TransformInputBase internal static class PermutationFeatureImportanceUtils { - private static string[] GetSlotNames(IDataView data) - { - VBuffer> slots = default; - data.Schema["Features"].GetSlotNames(ref slots); - - var column = data.GetColumn>( - data.Schema["Features"]); - - List slotNames = new List(); - - foreach (var item in column.First>().Items(all: true)) - { - slotNames.Add(slots.GetValues()[item.Key].ToString()); - }; - - return slotNames.ToArray(); - } - internal static IDataView GetMetrics( MLContext mlContext, ISingleFeaturePredictionTransformer predictor, @@ -105,7 +89,7 @@ internal static IDataView GetMetrics( result = GetRankingMetrics(mlContext, predictor, data, input); else throw Contracts.Except( - "Unsupported predictor type. Predictor must be binary classifier," + + "Unsupported predictor type. Predictor must be binary classifier, " + "multiclass classifier, regressor, or ranker."); return result; @@ -130,11 +114,11 @@ private static IDataView GetBinaryMetrics( Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); - IEnumerable metrics = Enumerable.Empty(); + List metrics = new List(); for (int i = 0; i < permutationMetrics.Length; i++) { var pMetric = permutationMetrics[i]; - metrics = metrics.Append(new BinaryMetrics + metrics.Add(new BinaryMetrics { FeatureName = slotNames[i], AreaUnderRocCurve = pMetric.AreaUnderRocCurve.Mean, @@ -179,11 +163,11 @@ private static IDataView GetMulticlassMetrics( Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); - IEnumerable metrics = Enumerable.Empty(); + List metrics = new List(); for (int i = 0; i < permutationMetrics.Length; i++) { var pMetric = permutationMetrics[i]; - metrics = metrics.Append(new MulticlassMetrics + metrics.Add(new MulticlassMetrics { FeatureName = slotNames[i], MacroAccuracy = pMetric.MacroAccuracy.Mean, @@ -203,10 +187,9 @@ private static IDataView GetMulticlassMetrics( // Convert unknown size vectors to known size. var metric = metrics.First(); - int perClassLogLossDimension = metric.PerClassLogLoss.Length; SchemaDefinition schema = SchemaDefinition.Create(typeof(MulticlassMetrics)); - var perClassLogLossType = ((VectorDataViewType)schema[nameof(metric.PerClassLogLoss)].ColumnType).ItemType; - schema[nameof(metric.PerClassLogLoss)].ColumnType = new VectorDataViewType(perClassLogLossType, perClassLogLossDimension); + ConvertVectorToKnownSize(nameof(metric.PerClassLogLoss), metric.PerClassLogLoss.Length, ref schema); + ConvertVectorToKnownSize(nameof(metric.PerClassLogLossStdErr), metric.PerClassLogLossStdErr.Length, ref schema); var result = mlContext.Data.LoadFromEnumerable(metrics, schema); return result; @@ -231,11 +214,11 @@ private static IDataView GetRegressionMetrics( Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); - IEnumerable metrics = Enumerable.Empty(); + List metrics = new List(); for (int i = 0; i < permutationMetrics.Length; i++) { var pMetric = permutationMetrics[i]; - metrics = metrics.Append(new RegressionMetrics + metrics.Add(new RegressionMetrics { FeatureName = slotNames[i], MeanAbsoluteError = pMetric.MeanAbsoluteError.Mean, @@ -275,11 +258,11 @@ private static IDataView GetRankingMetrics( Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); - IEnumerable metrics = Enumerable.Empty(); + List metrics = new List(); for (int i = 0; i < permutationMetrics.Length; i++) { var pMetric = permutationMetrics[i]; - metrics = metrics.Append(new RankingMetrics + metrics.Add(new RankingMetrics { FeatureName = slotNames[i], DiscountedCumulativeGains = pMetric.DiscountedCumulativeGains.Select(x => x.Mean).ToArray(), @@ -291,121 +274,142 @@ private static IDataView GetRankingMetrics( // Convert unknown size vectors to known size. var metric = metrics.First(); - int dcgDimension = metric.DiscountedCumulativeGains.Length; - int ndcgDimension = metric.NormalizedDiscountedCumulativeGains.Length; SchemaDefinition schema = SchemaDefinition.Create(typeof(RankingMetrics)); - var dcgType = ((VectorDataViewType)schema[nameof(metric.DiscountedCumulativeGains)].ColumnType).ItemType; - var ndcgType = ((VectorDataViewType)schema[nameof(metric.NormalizedDiscountedCumulativeGains)].ColumnType).ItemType; - schema[nameof(metric.DiscountedCumulativeGains)].ColumnType = new VectorDataViewType(dcgType, dcgDimension); - schema[nameof(metric.NormalizedDiscountedCumulativeGains)].ColumnType = new VectorDataViewType(ndcgType, ndcgDimension); + ConvertVectorToKnownSize(nameof(metric.DiscountedCumulativeGains), metric.DiscountedCumulativeGains.Length, ref schema); + ConvertVectorToKnownSize(nameof(metric.NormalizedDiscountedCumulativeGains), metric.NormalizedDiscountedCumulativeGains.Length, ref schema); + ConvertVectorToKnownSize(nameof(metric.DiscountedCumulativeGainsStdErr), metric.DiscountedCumulativeGainsStdErr.Length, ref schema); + ConvertVectorToKnownSize(nameof(metric.NormalizedDiscountedCumulativeGainsStdErr), metric.NormalizedDiscountedCumulativeGainsStdErr.Length, ref schema); var result = mlContext.Data.LoadFromEnumerable(metrics, schema); return result; } - } - internal class BinaryMetrics - { - public string FeatureName { get; set; } + private static string[] GetSlotNames(IDataView data) + { + VBuffer> slots = default; + data.Schema["Features"].GetSlotNames(ref slots); + + var column = data.GetColumn>( + data.Schema["Features"]); - public double AreaUnderRocCurve { get; set; } + List slotNames = new List(); - public double AreaUnderRocCurveStdErr { get; set; } + foreach (var item in column.First>().Items(all: true)) + { + slotNames.Add(slots.GetValues()[item.Key].ToString()); + }; - public double Accuracy { get; set; } + return slotNames.ToArray(); + } - public double AccuracyStdErr { get; set; } + private static void ConvertVectorToKnownSize(string metricName, int size, ref SchemaDefinition schema) + { + var type = ((VectorDataViewType)schema[metricName].ColumnType).ItemType; + schema[metricName].ColumnType = new VectorDataViewType(type, size); + } - public double PositivePrecision { get; set; } + private class BinaryMetrics + { + public string FeatureName { get; set; } - public double PositivePrecisionStdErr { get; set; } + public double AreaUnderRocCurve { get; set; } - public double PositiveRecall { get; set; } + public double AreaUnderRocCurveStdErr { get; set; } - public double PositiveRecallStdErr { get; set; } + public double Accuracy { get; set; } - public double NegativePrecision { get; set; } + public double AccuracyStdErr { get; set; } - public double NegativePrecisionStdErr { get; set; } + public double PositivePrecision { get; set; } - public double NegativeRecall { get; set; } + public double PositivePrecisionStdErr { get; set; } - public double NegativeRecallStdErr { get; set; } + public double PositiveRecall { get; set; } - public double F1Score { get; set; } + public double PositiveRecallStdErr { get; set; } - public double F1ScoreStdErr { get; set; } + public double NegativePrecision { get; set; } - public double AreaUnderPrecisionRecallCurve { get; set; } + public double NegativePrecisionStdErr { get; set; } - public double AreaUnderPrecisionRecallCurveStdErr { get; set; } + public double NegativeRecall { get; set; } - } + public double NegativeRecallStdErr { get; set; } - internal class MulticlassMetrics - { - public string FeatureName { get; set; } + public double F1Score { get; set; } - public double MacroAccuracy { get; set; } + public double F1ScoreStdErr { get; set; } - public double MacroAccuracyStdErr { get; set; } + public double AreaUnderPrecisionRecallCurve { get; set; } - public double MicroAccuracy { get; set; } + public double AreaUnderPrecisionRecallCurveStdErr { get; set; } + } - public double MicroAccuracyStdErr { get; set; } + private class MulticlassMetrics + { + public string FeatureName { get; set; } - public double LogLoss { get; set; } + public double MacroAccuracy { get; set; } - public double LogLossStdErr { get; set; } + public double MacroAccuracyStdErr { get; set; } - public double LogLossReduction { get; set; } + public double MicroAccuracy { get; set; } - public double LogLossReductionStdErr { get; set; } + public double MicroAccuracyStdErr { get; set; } - public double TopKAccuracy { get; set; } + public double LogLoss { get; set; } - public double TopKAccuracyStdErr { get; set; } + public double LogLossStdErr { get; set; } - public double[] PerClassLogLoss { get; set; } + public double LogLossReduction { get; set; } - public double[] PerClassLogLossStdErr { get; set; } - } + public double LogLossReductionStdErr { get; set; } - internal class RegressionMetrics - { - public string FeatureName { get; set; } + public double TopKAccuracy { get; set; } - public double MeanAbsoluteError { get; set; } + public double TopKAccuracyStdErr { get; set; } - public double MeanAbsoluteErrorStdErr { get; set; } + public double[] PerClassLogLoss { get; set; } - public double MeanSquaredError { get; set; } + public double[] PerClassLogLossStdErr { get; set; } + } - public double MeanSquaredErrorStdErr { get; set; } + private class RegressionMetrics + { + public string FeatureName { get; set; } - public double RootMeanSquaredError { get; set; } + public double MeanAbsoluteError { get; set; } - public double RootMeanSquaredErrorStdErr { get; set; } + public double MeanAbsoluteErrorStdErr { get; set; } - public double LossFunction { get; set; } + public double MeanSquaredError { get; set; } - public double LossFunctionStdErr { get; set; } + public double MeanSquaredErrorStdErr { get; set; } - public double RSquared { get; set; } + public double RootMeanSquaredError { get; set; } - public double RSquaredStdErr { get; set; } - } + public double RootMeanSquaredErrorStdErr { get; set; } - internal class RankingMetrics - { - public string FeatureName { get; set; } + public double LossFunction { get; set; } - public double[] DiscountedCumulativeGains { get; set; } + public double LossFunctionStdErr { get; set; } - public double[] DiscountedCumulativeGainsStdErr { get; set; } + public double RSquared { get; set; } - public double[] NormalizedDiscountedCumulativeGains { get; set; } + public double RSquaredStdErr { get; set; } + } - public double[] NormalizedDiscountedCumulativeGainsStdErr { get; set; } + private class RankingMetrics + { + public string FeatureName { get; set; } + + public double[] DiscountedCumulativeGains { get; set; } + + public double[] DiscountedCumulativeGainsStdErr { get; set; } + + public double[] NormalizedDiscountedCumulativeGains { get; set; } + + public double[] NormalizedDiscountedCumulativeGainsStdErr { get; set; } + } } } From ad491eaa60117b9958a5d49bd420278cd0c83c97 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Mon, 23 Sep 2019 11:16:36 -0700 Subject: [PATCH 06/15] Remove MLContext from EntryPoint --- .../PermutationFeatureImportance.cs | 63 ++++++++++--------- .../PermutationFeatureImportanceExtensions.cs | 4 +- 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index 01c124b0f3..d0b0b6f52c 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -26,9 +26,8 @@ public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IH host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - var mlContext = new MLContext(); - - var model = mlContext.Model.Load(input.ModelPath.OpenReadStream(), out DataViewSchema schema); + var modelOps = new ModelOperationsCatalog(env); + var model = modelOps.Load(input.ModelPath.OpenReadStream(), out DataViewSchema schema); var chain = model as TransformerChain; var predictor = chain.LastTransformer as ISingleFeaturePredictionTransformer; Contracts.Assert(!(predictor is null), "The last transformer in the model is not a predictor, or Permutation " + @@ -36,9 +35,7 @@ public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IH "predictor, as PFI is calculated for a predictor model."); var transformedData = model.Transform(input.Data); - - IDataView result = PermutationFeatureImportanceUtils.GetMetrics(mlContext, predictor, transformedData, input); - + IDataView result = PermutationFeatureImportanceUtils.GetMetrics(env, predictor, transformedData, input); return new PermutationFeatureImportanceOutput { Metrics = result }; } } @@ -73,20 +70,20 @@ internal sealed class PermutationFeatureImportanceArguments : TransformInputBase internal static class PermutationFeatureImportanceUtils { internal static IDataView GetMetrics( - MLContext mlContext, + IHostEnvironment env, ISingleFeaturePredictionTransformer predictor, IDataView data, PermutationFeatureImportanceArguments input) { IDataView result; if (predictor is BinaryPredictionTransformer>) - result = GetBinaryMetrics(mlContext, predictor, data, input); + result = GetBinaryMetrics(env, predictor, data, input); else if (predictor is MulticlassPredictionTransformer>>) - result = GetMulticlassMetrics(mlContext, predictor, data, input); + result = GetMulticlassMetrics(env, predictor, data, input); else if (predictor is RegressionPredictionTransformer>) - result = GetRegressionMetrics(mlContext, predictor, data, input); + result = GetRegressionMetrics(env, predictor, data, input); else if (predictor is RankingPredictionTransformer>) - result = GetRankingMetrics(mlContext, predictor, data, input); + result = GetRankingMetrics(env, predictor, data, input); else throw Contracts.Except( "Unsupported predictor type. Predictor must be binary classifier, " + @@ -96,14 +93,13 @@ internal static IDataView GetMetrics( } private static IDataView GetBinaryMetrics( - MLContext mlContext, + IHostEnvironment env, ISingleFeaturePredictionTransformer predictor, IDataView data, PermutationFeatureImportanceArguments input) { - var slotNames = GetSlotNames(data); - - var permutationMetrics = mlContext.BinaryClassification + var binaryCatalog = new BinaryClassificationCatalog(env); + var permutationMetrics = binaryCatalog .PermutationFeatureImportance(predictor, data, labelColumnName: input.LabelColumnName, @@ -111,6 +107,7 @@ private static IDataView GetBinaryMetrics( numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); + var slotNames = GetSlotNames(data); Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); @@ -140,19 +137,19 @@ private static IDataView GetBinaryMetrics( }); } - var result = mlContext.Data.LoadFromEnumerable(metrics); + var dataOps = new DataOperationsCatalog(env); + var result = dataOps.LoadFromEnumerable(metrics); return result; } private static IDataView GetMulticlassMetrics( - MLContext mlContext, + IHostEnvironment env, ISingleFeaturePredictionTransformer predictor, IDataView data, PermutationFeatureImportanceArguments input) { - var slotNames = GetSlotNames(data); - - var permutationMetrics = mlContext.MulticlassClassification + var multiclassCatalog = new MulticlassClassificationCatalog(env); + var permutationMetrics = multiclassCatalog .PermutationFeatureImportance(predictor, data, labelColumnName: input.LabelColumnName, @@ -160,6 +157,7 @@ private static IDataView GetMulticlassMetrics( numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); + var slotNames = GetSlotNames(data); Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); @@ -191,19 +189,19 @@ private static IDataView GetMulticlassMetrics( ConvertVectorToKnownSize(nameof(metric.PerClassLogLoss), metric.PerClassLogLoss.Length, ref schema); ConvertVectorToKnownSize(nameof(metric.PerClassLogLossStdErr), metric.PerClassLogLossStdErr.Length, ref schema); - var result = mlContext.Data.LoadFromEnumerable(metrics, schema); + var dataOps = new DataOperationsCatalog(env); + var result = dataOps.LoadFromEnumerable(metrics); return result; } private static IDataView GetRegressionMetrics( - MLContext mlContext, + IHostEnvironment env, ISingleFeaturePredictionTransformer predictor, IDataView data, PermutationFeatureImportanceArguments input) { - var slotNames = GetSlotNames(data); - - var permutationMetrics = mlContext.Regression + var regressionCatalog = new RegressionCatalog(env); + var permutationMetrics = regressionCatalog .PermutationFeatureImportance(predictor, data, labelColumnName: input.LabelColumnName, @@ -211,6 +209,7 @@ private static IDataView GetRegressionMetrics( numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); + var slotNames = GetSlotNames(data); Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); @@ -234,19 +233,19 @@ private static IDataView GetRegressionMetrics( }); } - var result = mlContext.Data.LoadFromEnumerable(metrics); + var dataOps = new DataOperationsCatalog(env); + var result = dataOps.LoadFromEnumerable(metrics); return result; } private static IDataView GetRankingMetrics( - MLContext mlContext, + IHostEnvironment env, ISingleFeaturePredictionTransformer predictor, IDataView data, PermutationFeatureImportanceArguments input) { - var slotNames = GetSlotNames(data); - - var permutationMetrics = mlContext.Ranking + var rankingCatalog = new RankingCatalog(env); + var permutationMetrics = rankingCatalog .PermutationFeatureImportance(predictor, data, labelColumnName: input.LabelColumnName, @@ -255,6 +254,7 @@ private static IDataView GetRankingMetrics( numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); + var slotNames = GetSlotNames(data); Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); @@ -280,7 +280,8 @@ private static IDataView GetRankingMetrics( ConvertVectorToKnownSize(nameof(metric.DiscountedCumulativeGainsStdErr), metric.DiscountedCumulativeGainsStdErr.Length, ref schema); ConvertVectorToKnownSize(nameof(metric.NormalizedDiscountedCumulativeGainsStdErr), metric.NormalizedDiscountedCumulativeGainsStdErr.Length, ref schema); - var result = mlContext.Data.LoadFromEnumerable(metrics, schema); + var dataOps = new DataOperationsCatalog(env); + var result = dataOps.LoadFromEnumerable(metrics); return result; } diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs index a860a3023b..07b9c8f435 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs @@ -206,7 +206,7 @@ private static BinaryClassificationMetrics BinaryClassifierDelta( /// ]]> /// /// - /// The clustering catalog. + /// The multiclass classification catalog. /// The model on which to evaluate feature importance. /// The evaluation data set. /// Label column name. The column data must be . @@ -291,7 +291,7 @@ private static MulticlassClassificationMetrics MulticlassClassificationDelta( /// ]]> /// /// - /// The clustering catalog. + /// The ranking catalog. /// The model on which to evaluate feature importance. /// The evaluation data set. /// Label column name. The column data must be or . From 626914c8130edb3294f2d6665e2d9f9a660f6aab Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Mon, 23 Sep 2019 19:28:05 -0700 Subject: [PATCH 07/15] Use last predictor in the model if model.LastTransformer is not a predictor --- .../DataLoadSave/TransformerChain.cs | 21 +++++ .../PermutationFeatureImportance.cs | 15 +-- .../UnitTests/TestEntryPoints.cs | 92 +++++++++++++++++++ 3 files changed, 122 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index f9f7a1f413..09c2e55c9d 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -128,6 +128,27 @@ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) return s; } + [BestFriend] + internal TransformerChain RewindToLastPredictionTransformer() + { + int lastPredictorIndex = 0; + for (int i = _transformers.Length; i > 0; i--) + { + var current = _transformers[i - 1]; + if (current is IPredictionTransformer> || + current is IPredictionTransformer>>) + { + lastPredictorIndex = i; + break; + } + } + + Contracts.Assert(lastPredictorIndex != 0, "No predictor found in the model."); + + var predictorChain = _transformers.Take(lastPredictorIndex).ToArray(); + return new TransformerChain(predictorChain); + } + public IDataView Transform(IDataView input) { Contracts.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index d0b0b6f52c..2ea6f2f1bb 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -27,12 +27,15 @@ public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IH EntryPointUtils.CheckInputArgs(host, input); var modelOps = new ModelOperationsCatalog(env); - var model = modelOps.Load(input.ModelPath.OpenReadStream(), out DataViewSchema schema); - var chain = model as TransformerChain; - var predictor = chain.LastTransformer as ISingleFeaturePredictionTransformer; - Contracts.Assert(!(predictor is null), "The last transformer in the model is not a predictor, or Permutation " + - "Feature Importance (PFI) is not supported for the predictor. The last transformer in the model must be a " + - "predictor, as PFI is calculated for a predictor model."); + var model = modelOps.Load(input.ModelPath.OpenReadStream(), out DataViewSchema schema) as TransformerChain; + + // If model.LastTransformer is not an IPredictionTransformer, get the part of the TransformerChain + // up to the last ITransformer that is indeed an IPredictionTransformer. This piece of the TransformerChain + // is used to extract the IPredictionTransformer and also to transform the input data. + // Will throw if there is no IPredictionTransformer in the TransformerChain. + model = model.RewindToLastPredictionTransformer(); + var predictor = model.LastTransformer as ISingleFeaturePredictionTransformer; + Contracts.Assert(predictor != null, "Permutation Feature Importance (PFI) is not supported for the predictor."); var transformedData = model.Transform(input.Data); IDataView result = PermutationFeatureImportanceUtils.GetMetrics(env, predictor, transformedData, input); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index f5b5025065..c481f673f5 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -970,6 +970,98 @@ public void MulticlassPermutationFeatureImportance() Assert.NotNull(loadedData.Schema.GetColumnOrNull("PerClassLogLossStdErr")); } + [Fact] + public void MulticlassPermutationFeatureImportanceWithKeyToValue() + { + var dataPath = GetDataPath("adult.tiny.with-schema.txt"); + var modelPath = DeleteOutputPath("model.zip"); + var outputDataPath = DeleteOutputPath("metrics.idv"); + + var mlContext = new MLContext(); + + var data = new TextLoader(mlContext, + new TextLoader.Options() + { + AllowQuoting = true, + Separator = "\t", + HasHeader = true, + Columns = new[] + { + new TextLoader.Column("Label", DataKind.String, 0), + new TextLoader.Column("education", DataKind.String, 2), + new TextLoader.Column("age", DataKind.Single, 9) + } + }).Load(dataPath); + + var pipeline = mlContext.Transforms.Categorical.OneHotEncoding("education") + .Append(mlContext.Transforms.Concatenate("Features", new[] { "education", "age" })) + .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) + .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(maximumNumberOfIterations: 1)) + .Append(mlContext.Transforms.Conversion.MapKeyToValue("Label")); + + var model = pipeline.Fit(data); + + using (FileStream stream = new FileStream(modelPath, FileMode.Create)) + mlContext.Model.Save(model, data.Schema, stream); + + string pfiGraph = string.Format(@" + {{ + 'Inputs': {{ + 'file1': '{0}', + 'file2': '{1}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:TX:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:R4:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file1' + }}, + 'Outputs': {{ + 'Data': '$data' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$data', + 'ModelPath': '$file2', + 'PermutationCount': 5 + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} + }} + ], + 'Outputs': {{ + 'output_data': '{2}' + }} + }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); + + var pfiJsonPath = DeleteOutputPath("pfiGraph.json"); + File.WriteAllLines(pfiJsonPath, new[] { pfiGraph }); + + var pfiArgs = new ExecuteGraphCommand.Arguments() { GraphPath = pfiJsonPath }; + var pfiCmd = new ExecuteGraphCommand(Env, pfiArgs); + pfiCmd.Run(); + + var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); + + Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MacroAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MicroAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLoss")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossReduction")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("TopKAccuracy")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PerClassLogLoss")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MacroAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("MicroAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("LogLossReductionStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("TopKAccuracyStdErr")); + Assert.NotNull(loadedData.Schema.GetColumnOrNull("PerClassLogLossStdErr")); + } + [Fact] public void RegressionPermutationFeatureImportance() { From c726e21fdf5c47b913074cd452200e7aa8a26066 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Tue, 24 Sep 2019 13:11:49 -0700 Subject: [PATCH 08/15] nit --- src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs | 6 ++---- .../PermutationFeatureImportance.cs | 5 +++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 09c2e55c9d..97681df42f 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -135,16 +135,14 @@ internal TransformerChain RewindToLastPredictionTransformer() for (int i = _transformers.Length; i > 0; i--) { var current = _transformers[i - 1]; - if (current is IPredictionTransformer> || - current is IPredictionTransformer>>) + if (current is IPredictionTransformer> + || current is IPredictionTransformer>>) { lastPredictorIndex = i; break; } } - Contracts.Assert(lastPredictorIndex != 0, "No predictor found in the model."); - var predictorChain = _transformers.Take(lastPredictorIndex).ToArray(); return new TransformerChain(predictorChain); } diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index 2ea6f2f1bb..4b72383797 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -28,12 +28,13 @@ public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IH var modelOps = new ModelOperationsCatalog(env); var model = modelOps.Load(input.ModelPath.OpenReadStream(), out DataViewSchema schema) as TransformerChain; - // If model.LastTransformer is not an IPredictionTransformer, get the part of the TransformerChain // up to the last ITransformer that is indeed an IPredictionTransformer. This piece of the TransformerChain // is used to extract the IPredictionTransformer and also to transform the input data. // Will throw if there is no IPredictionTransformer in the TransformerChain. - model = model.RewindToLastPredictionTransformer(); + if (!(model.LastTransformer is IPredictionTransformer> + || model.LastTransformer is IPredictionTransformer>>)) + model = model.RewindToLastPredictionTransformer(); var predictor = model.LastTransformer as ISingleFeaturePredictionTransformer; Contracts.Assert(predictor != null, "Permutation Feature Importance (PFI) is not supported for the predictor."); From 54e51b1890f17b2d9e5afb6e6202732b7c5f6ad2 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Tue, 24 Sep 2019 14:58:37 -0700 Subject: [PATCH 09/15] Model file path conflicts in tests --- .../UnitTests/TestEntryPoints.cs | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index c481f673f5..b6d4d90920 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -616,8 +616,8 @@ public void EntryPointExecGraphCommand() public void BinaryPermutationFeatureImportance() { var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var modelPath = DeleteOutputPath("model.zip"); - var outputDataPath = DeleteOutputPath("metrics.idv"); + var modelPath = DeleteOutputPath("binary_pfi_model.zip"); + var outputDataPath = DeleteOutputPath("binary_pfi_metrics.idv"); string trainingGraph = string.Format(@" {{ @@ -797,8 +797,8 @@ public void BinaryPermutationFeatureImportance() public void MulticlassPermutationFeatureImportance() { var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var modelPath = DeleteOutputPath("model.zip"); - var outputDataPath = DeleteOutputPath("metrics.idv"); + var modelPath = DeleteOutputPath("mc_pfi_model.zip"); + var outputDataPath = DeleteOutputPath("mc_pfi_metrics.idv"); string trainingGraph = string.Format(@" {{ @@ -974,8 +974,8 @@ public void MulticlassPermutationFeatureImportance() public void MulticlassPermutationFeatureImportanceWithKeyToValue() { var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var modelPath = DeleteOutputPath("model.zip"); - var outputDataPath = DeleteOutputPath("metrics.idv"); + var modelPath = DeleteOutputPath("mc_ktv_pfi_model.zip"); + var outputDataPath = DeleteOutputPath("mc_ktv_pfi_metrics.idv"); var mlContext = new MLContext(); @@ -1066,8 +1066,8 @@ public void MulticlassPermutationFeatureImportanceWithKeyToValue() public void RegressionPermutationFeatureImportance() { var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var modelPath = DeleteOutputPath("model.zip"); - var outputDataPath = DeleteOutputPath("metrics.idv"); + var modelPath = DeleteOutputPath("reg_pfi_model.zip"); + var outputDataPath = DeleteOutputPath("reg_pfi_metrics.idv"); string trainingGraph = string.Format(@" {{ @@ -1240,8 +1240,8 @@ public void RegressionPermutationFeatureImportance() public void RankingPermutationFeatureImportance() { var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var modelPath = DeleteOutputPath("model.zip"); - var outputDataPath = DeleteOutputPath("metrics.idv"); + var modelPath = DeleteOutputPath("rank_pfi_model.zip"); + var outputDataPath = DeleteOutputPath("rank_pfi_metrics.idv"); string trainingGraph = string.Format(@" {{ @@ -1415,7 +1415,7 @@ public void RankingPermutationFeatureImportance() public void ScoreTransformerChainModel() { var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); - var modelPath = DeleteOutputPath("model.zip"); + var modelPath = DeleteOutputPath("score_model.zip"); var outputDataPath = DeleteOutputPath("scored.idv"); var mlContext = new MLContext(); From d9f522f5af853984c52c16fad590e777e4ffd948 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Tue, 24 Sep 2019 18:27:48 -0700 Subject: [PATCH 10/15] nit --- src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index 4b72383797..a34fff3187 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -194,7 +194,7 @@ private static IDataView GetMulticlassMetrics( ConvertVectorToKnownSize(nameof(metric.PerClassLogLossStdErr), metric.PerClassLogLossStdErr.Length, ref schema); var dataOps = new DataOperationsCatalog(env); - var result = dataOps.LoadFromEnumerable(metrics); + var result = dataOps.LoadFromEnumerable(metrics, schema); return result; } @@ -285,7 +285,7 @@ private static IDataView GetRankingMetrics( ConvertVectorToKnownSize(nameof(metric.NormalizedDiscountedCumulativeGainsStdErr), metric.NormalizedDiscountedCumulativeGainsStdErr.Length, ref schema); var dataOps = new DataOperationsCatalog(env); - var result = dataOps.LoadFromEnumerable(metrics); + var result = dataOps.LoadFromEnumerable(metrics, schema); return result; } From ab0893b895e2be247f94b137b83a181dc6d53bc4 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Wed, 25 Sep 2019 14:16:26 -0700 Subject: [PATCH 11/15] PR Feedback --- src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 97681df42f..a430f704d7 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -131,19 +131,19 @@ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) [BestFriend] internal TransformerChain RewindToLastPredictionTransformer() { - int lastPredictorIndex = 0; + int numTransformersToTake = 0; for (int i = _transformers.Length; i > 0; i--) { var current = _transformers[i - 1]; if (current is IPredictionTransformer> || current is IPredictionTransformer>>) { - lastPredictorIndex = i; + numTransformersToTake = i; break; } } - Contracts.Assert(lastPredictorIndex != 0, "No predictor found in the model."); - var predictorChain = _transformers.Take(lastPredictorIndex).ToArray(); + Contracts.Check(numTransformersToTake != 0, "No predictor found in the model."); + var predictorChain = _transformers.Take(numTransformersToTake).ToArray(); return new TransformerChain(predictorChain); } From 32eec60240f41bb56b9cc65477e6dc64aafa68d9 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Thu, 26 Sep 2019 17:08:06 -0700 Subject: [PATCH 12/15] Pass in model as PredictorModel --- .../DataLoadSave/TransformerChain.cs | 19 - .../PermutationFeatureImportance.cs | 89 ++-- .../UnitTests/TestEntryPoints.cs | 417 ++++++++---------- 3 files changed, 228 insertions(+), 297 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index a430f704d7..f9f7a1f413 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -128,25 +128,6 @@ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) return s; } - [BestFriend] - internal TransformerChain RewindToLastPredictionTransformer() - { - int numTransformersToTake = 0; - for (int i = _transformers.Length; i > 0; i--) - { - var current = _transformers[i - 1]; - if (current is IPredictionTransformer> - || current is IPredictionTransformer>>) - { - numTransformersToTake = i; - break; - } - } - Contracts.Check(numTransformersToTake != 0, "No predictor found in the model."); - var predictorChain = _transformers.Take(numTransformersToTake).ToArray(); - return new TransformerChain(predictorChain); - } - public IDataView Transform(IDataView input) { Contracts.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index a34fff3187..40781cc30b 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -26,20 +26,10 @@ public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IH host.CheckValue(input, nameof(input)); EntryPointUtils.CheckInputArgs(host, input); - var modelOps = new ModelOperationsCatalog(env); - var model = modelOps.Load(input.ModelPath.OpenReadStream(), out DataViewSchema schema) as TransformerChain; - // If model.LastTransformer is not an IPredictionTransformer, get the part of the TransformerChain - // up to the last ITransformer that is indeed an IPredictionTransformer. This piece of the TransformerChain - // is used to extract the IPredictionTransformer and also to transform the input data. - // Will throw if there is no IPredictionTransformer in the TransformerChain. - if (!(model.LastTransformer is IPredictionTransformer> - || model.LastTransformer is IPredictionTransformer>>)) - model = model.RewindToLastPredictionTransformer(); - var predictor = model.LastTransformer as ISingleFeaturePredictionTransformer; - Contracts.Assert(predictor != null, "Permutation Feature Importance (PFI) is not supported for the predictor."); - - var transformedData = model.Transform(input.Data); - IDataView result = PermutationFeatureImportanceUtils.GetMetrics(env, predictor, transformedData, input); + input.PredictorModel.PrepareData(env, input.Data, out RoleMappedData roleMappedData, out IPredictor predictor); + Contracts.Assert(predictor != null, "No predictor found in model"); + var transformedData = input.PredictorModel.TransformModel.Apply(env, input.Data); + IDataView result = PermutationFeatureImportanceUtils.GetMetrics(env, predictor, roleMappedData.Schema, transformedData, input); return new PermutationFeatureImportanceOutput { Metrics = result }; } } @@ -53,7 +43,7 @@ internal sealed class PermutationFeatureImportanceOutput internal sealed class PermutationFeatureImportanceArguments : TransformInputBase { [Argument(ArgumentType.Required, HelpText = "The path to the model file", ShortName = "path", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public IFileHandle ModelPath; + public PredictorModel PredictorModel; [Argument(ArgumentType.AtMostOnce, HelpText = "Label column name", ShortName = "label", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public string LabelColumnName = "Label"; @@ -75,19 +65,20 @@ internal static class PermutationFeatureImportanceUtils { internal static IDataView GetMetrics( IHostEnvironment env, - ISingleFeaturePredictionTransformer predictor, + IPredictor predictor, + RoleMappedSchema roleMappedSchema, IDataView data, PermutationFeatureImportanceArguments input) { IDataView result; - if (predictor is BinaryPredictionTransformer>) - result = GetBinaryMetrics(env, predictor, data, input); - else if (predictor is MulticlassPredictionTransformer>>) - result = GetMulticlassMetrics(env, predictor, data, input); - else if (predictor is RegressionPredictionTransformer>) - result = GetRegressionMetrics(env, predictor, data, input); - else if (predictor is RankingPredictionTransformer>) - result = GetRankingMetrics(env, predictor, data, input); + if (predictor.PredictionKind == PredictionKind.BinaryClassification) + result = GetBinaryMetrics(env, predictor, roleMappedSchema, data, input); + else if (predictor.PredictionKind == PredictionKind.MulticlassClassification) + result = GetMulticlassMetrics(env, predictor, roleMappedSchema, data, input); + else if (predictor.PredictionKind == PredictionKind.Regression) + result = GetRegressionMetrics(env, predictor, roleMappedSchema, data, input); + else if (predictor.PredictionKind == PredictionKind.Ranking) + result = GetRankingMetrics(env, predictor, roleMappedSchema, data, input); else throw Contracts.Except( "Unsupported predictor type. Predictor must be binary classifier, " + @@ -98,13 +89,17 @@ internal static IDataView GetMetrics( private static IDataView GetBinaryMetrics( IHostEnvironment env, - ISingleFeaturePredictionTransformer predictor, + IPredictor predictor, + RoleMappedSchema roleMappedSchema, IDataView data, PermutationFeatureImportanceArguments input) { + var pred = new BinaryPredictionTransformer>( + env, predictor as IPredictorProducing, data.Schema, + roleMappedSchema.GetColumnRoleNames().Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); var binaryCatalog = new BinaryClassificationCatalog(env); var permutationMetrics = binaryCatalog - .PermutationFeatureImportance(predictor, + .PermutationFeatureImportance(pred, data, labelColumnName: input.LabelColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, @@ -118,6 +113,8 @@ private static IDataView GetBinaryMetrics( List metrics = new List(); for (int i = 0; i < permutationMetrics.Length; i++) { + if (string.IsNullOrWhiteSpace(slotNames[i])) + continue; var pMetric = permutationMetrics[i]; metrics.Add(new BinaryMetrics { @@ -148,13 +145,18 @@ private static IDataView GetBinaryMetrics( private static IDataView GetMulticlassMetrics( IHostEnvironment env, - ISingleFeaturePredictionTransformer predictor, + IPredictor predictor, + RoleMappedSchema roleMappedSchema, IDataView data, PermutationFeatureImportanceArguments input) { + var pred = new MulticlassPredictionTransformer>>( + env, predictor as IPredictorProducing>, data.Schema, + roleMappedSchema.GetColumnRoleNames().Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value, + roleMappedSchema.GetColumnRoleNames().Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value); var multiclassCatalog = new MulticlassClassificationCatalog(env); var permutationMetrics = multiclassCatalog - .PermutationFeatureImportance(predictor, + .PermutationFeatureImportance(pred, data, labelColumnName: input.LabelColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, @@ -168,6 +170,8 @@ private static IDataView GetMulticlassMetrics( List metrics = new List(); for (int i = 0; i < permutationMetrics.Length; i++) { + if (string.IsNullOrWhiteSpace(slotNames[i])) + continue; var pMetric = permutationMetrics[i]; metrics.Add(new MulticlassMetrics { @@ -200,13 +204,17 @@ private static IDataView GetMulticlassMetrics( private static IDataView GetRegressionMetrics( IHostEnvironment env, - ISingleFeaturePredictionTransformer predictor, + IPredictor predictor, + RoleMappedSchema roleMappedSchema, IDataView data, PermutationFeatureImportanceArguments input) { + var pred = new RegressionPredictionTransformer>( + env, predictor as IPredictorProducing, data.Schema, + roleMappedSchema.GetColumnRoleNames().Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); var regressionCatalog = new RegressionCatalog(env); var permutationMetrics = regressionCatalog - .PermutationFeatureImportance(predictor, + .PermutationFeatureImportance(pred, data, labelColumnName: input.LabelColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, @@ -220,6 +228,8 @@ private static IDataView GetRegressionMetrics( List metrics = new List(); for (int i = 0; i < permutationMetrics.Length; i++) { + if (string.IsNullOrWhiteSpace(slotNames[i])) + continue; var pMetric = permutationMetrics[i]; metrics.Add(new RegressionMetrics { @@ -244,13 +254,17 @@ private static IDataView GetRegressionMetrics( private static IDataView GetRankingMetrics( IHostEnvironment env, - ISingleFeaturePredictionTransformer predictor, + IPredictor predictor, + RoleMappedSchema roleMappedSchema, IDataView data, PermutationFeatureImportanceArguments input) { + var pred = new RankingPredictionTransformer>( + env, predictor as IPredictorProducing, data.Schema, + roleMappedSchema.GetColumnRoleNames().Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); var rankingCatalog = new RankingCatalog(env); var permutationMetrics = rankingCatalog - .PermutationFeatureImportance(predictor, + .PermutationFeatureImportance(pred, data, labelColumnName: input.LabelColumnName, rowGroupColumnName: input.RowGroupColumnName, @@ -265,6 +279,8 @@ private static IDataView GetRankingMetrics( List metrics = new List(); for (int i = 0; i < permutationMetrics.Length; i++) { + if (string.IsNullOrWhiteSpace(slotNames[i])) + continue; var pMetric = permutationMetrics[i]; metrics.Add(new RankingMetrics { @@ -293,15 +309,12 @@ private static string[] GetSlotNames(IDataView data) { VBuffer> slots = default; data.Schema["Features"].GetSlotNames(ref slots); - - var column = data.GetColumn>( - data.Schema["Features"]); + var slotValues = slots.DenseValues(); List slotNames = new List(); - - foreach (var item in column.First>().Items(all: true)) + foreach (var value in slotValues) { - slotNames.Add(slots.GetValues()[item.Key].ToString()); + slotNames.Add(value.ToString()); }; return slotNames.ToArray(); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index b6d4d90920..63179e4f41 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -615,11 +615,10 @@ public void EntryPointExecGraphCommand() [Fact] public void BinaryPermutationFeatureImportance() { - var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var modelPath = DeleteOutputPath("binary_pfi_model.zip"); + var inputDataPath = GetDataPath("adult.tiny.with-schema.txt"); var outputDataPath = DeleteOutputPath("binary_pfi_metrics.idv"); - string trainingGraph = string.Format(@" + string inputGraph = string.Format(@" {{ 'Inputs': {{ 'file': '{0}' @@ -643,7 +642,7 @@ public void BinaryPermutationFeatureImportance() 'Source': 'education' }} ], - 'Data': '$data', + 'Data': '$data' }}, 'Name': 'Transforms.CategoricalOneHotVectorizer', 'Outputs': {{ @@ -715,42 +714,12 @@ public void BinaryPermutationFeatureImportance() 'Outputs': {{ 'PredictorModel': '$output_model' }} - }} - ], - 'Outputs': {{ - 'output_model': '{1}' - }} - }}", EscapePath(dataPath), EscapePath(modelPath)); - - var trainingJsonPath = DeleteOutputPath("trainingGraph.json"); - File.WriteAllLines(trainingJsonPath, new[] { trainingGraph }); - - var trainingArgs = new ExecuteGraphCommand.Arguments() { GraphPath = trainingJsonPath }; - var trainingCmd = new ExecuteGraphCommand(Env, trainingArgs); - trainingCmd.Run(); - - string pfiGraph = string.Format(@" - {{ - 'Inputs': {{ - 'file1': '{0}', - 'file2': '{1}' - }}, - 'Nodes': [ - {{ - 'Name': 'Data.CustomTextLoader', - 'Inputs': {{ - 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', - 'InputFile': '$file1' - }}, - 'Outputs': {{ - 'Data': '$data' - }} }}, {{ 'Name': 'Transforms.PermutationFeatureImportance', 'Inputs': {{ 'Data': '$data', - 'ModelPath': '$file2', + 'PredictorModel': '$output_model', 'PermutationCount': 5 }}, 'Outputs': {{ @@ -759,21 +728,19 @@ public void BinaryPermutationFeatureImportance() }} ], 'Outputs': {{ - 'output_data': '{2}' + 'output_data': '{1}' }} - }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); + }}", EscapePath(inputDataPath), EscapePath(outputDataPath)); - var pfiJsonPath = DeleteOutputPath("pfiGraph.json"); - File.WriteAllLines(pfiJsonPath, new[] { pfiGraph }); + var jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); - var pfiArgs = new ExecuteGraphCommand.Arguments() { GraphPath = pfiJsonPath }; - var pfiCmd = new ExecuteGraphCommand(Env, pfiArgs); - pfiCmd.Run(); + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); var mlContext = new MLContext(); - var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); - Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("AreaUnderRocCurve")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("Accuracy")); @@ -796,11 +763,10 @@ public void BinaryPermutationFeatureImportance() [Fact] public void MulticlassPermutationFeatureImportance() { - var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var modelPath = DeleteOutputPath("mc_pfi_model.zip"); + var inputDataPath = GetDataPath("adult.tiny.with-schema.txt"); var outputDataPath = DeleteOutputPath("mc_pfi_metrics.idv"); - string trainingGraph = string.Format(@" + string inputGraph = string.Format(@" {{ 'Inputs': {{ 'file': '{0}' @@ -896,42 +862,12 @@ public void MulticlassPermutationFeatureImportance() 'Outputs': {{ 'PredictorModel': '$output_model' }} - }} - ], - 'Outputs': {{ - 'output_model': '{1}' - }} - }}", EscapePath(dataPath), EscapePath(modelPath)); - - var trainingJsonPath = DeleteOutputPath("trainingGraph.json"); - File.WriteAllLines(trainingJsonPath, new[] { trainingGraph }); - - var trainingArgs = new ExecuteGraphCommand.Arguments() { GraphPath = trainingJsonPath }; - var trainingCmd = new ExecuteGraphCommand(Env, trainingArgs); - trainingCmd.Run(); - - string pfiGraph = string.Format(@" - {{ - 'Inputs': {{ - 'file1': '{0}', - 'file2': '{1}' - }}, - 'Nodes': [ - {{ - 'Name': 'Data.CustomTextLoader', - 'Inputs': {{ - 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', - 'InputFile': '$file1' - }}, - 'Outputs': {{ - 'Data': '$data' - }} }}, {{ 'Name': 'Transforms.PermutationFeatureImportance', 'Inputs': {{ 'Data': '$data', - 'ModelPath': '$file2', + 'PredictorModel': '$output_model', 'PermutationCount': 5 }}, 'Outputs': {{ @@ -940,21 +876,19 @@ public void MulticlassPermutationFeatureImportance() }} ], 'Outputs': {{ - 'output_data': '{2}' + 'output_data': '{1}' }} - }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); + }}", EscapePath(inputDataPath), EscapePath(outputDataPath)); - var pfiJsonPath = DeleteOutputPath("pfiGraph.json"); - File.WriteAllLines(pfiJsonPath, new[] { pfiGraph }); + var jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); - var pfiArgs = new ExecuteGraphCommand.Arguments() { GraphPath = pfiJsonPath }; - var pfiCmd = new ExecuteGraphCommand(Env, pfiArgs); - pfiCmd.Run(); + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); var mlContext = new MLContext(); - var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); - Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("MacroAccuracy")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("MicroAccuracy")); @@ -973,80 +907,149 @@ public void MulticlassPermutationFeatureImportance() [Fact] public void MulticlassPermutationFeatureImportanceWithKeyToValue() { - var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var modelPath = DeleteOutputPath("mc_ktv_pfi_model.zip"); + var inputData = GetDataPath("adult.tiny.with-schema.txt"); var outputDataPath = DeleteOutputPath("mc_ktv_pfi_metrics.idv"); - var mlContext = new MLContext(); - - var data = new TextLoader(mlContext, - new TextLoader.Options() - { - AllowQuoting = true, - Separator = "\t", - HasHeader = true, - Columns = new[] - { - new TextLoader.Column("Label", DataKind.String, 0), - new TextLoader.Column("education", DataKind.String, 2), - new TextLoader.Column("age", DataKind.Single, 9) - } - }).Load(dataPath); - - var pipeline = mlContext.Transforms.Categorical.OneHotEncoding("education") - .Append(mlContext.Transforms.Concatenate("Features", new[] { "education", "age" })) - .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")) - .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(maximumNumberOfIterations: 1)) - .Append(mlContext.Transforms.Conversion.MapKeyToValue("Label")); - - var model = pipeline.Fit(data); - - using (FileStream stream = new FileStream(modelPath, FileMode.Create)) - mlContext.Model.Save(model, data.Schema, stream); - - string pfiGraph = string.Format(@" - {{ - 'Inputs': {{ - 'file1': '{0}', - 'file2': '{1}' - }}, - 'Nodes': [ + string inputGraph = string.Format(@" {{ - 'Name': 'Data.CustomTextLoader', - 'Inputs': {{ - 'CustomSchema': 'col=Label:TX:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:R4:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', - 'InputFile': '$file1' + 'Inputs': {{ + 'file': '{0}' + }}, + 'Nodes': [ + {{ + 'Name': 'Data.CustomTextLoader', + 'Inputs': {{ + 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', + 'InputFile': '$file' + }}, + 'Outputs': {{ + 'Data': '$data' + }} }}, - 'Outputs': {{ - 'Data': '$data' - }} - }}, - {{ - 'Name': 'Transforms.PermutationFeatureImportance', - 'Inputs': {{ - 'Data': '$data', - 'ModelPath': '$file2', - 'PermutationCount': 5 + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'education', + 'Source': 'education' + }} + ], + 'Data': '$data', + }}, + 'Name': 'Transforms.CategoricalOneHotVectorizer', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} }}, - 'Outputs': {{ - 'Metrics': '$output_data' + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'LabelColumn': 'Label', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.LabelColumnKeyBooleanConverter', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$output_data3' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data3', + 'Features': [ + 'age', + 'education' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$output_data4' + }} + }}, + {{ + 'Inputs': {{ + 'TrainingData': '$output_data4', + 'NumThreads': 1, + 'MaxIterations': 1 + }}, + 'Name': 'Trainers.StochasticDualCoordinateAscentClassifier', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data4', + 'Column': [ + {{ + 'Name': 'Label', + 'Source': 'Label' + }} + ], + }}, + 'Name': 'Transforms.TextToKeyConverter', + 'Outputs': {{ + 'Model': '$output_model5', + 'OutputData': '$output_data5' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$predictor_model', + 'TransformModels': [ + '$output_model1', + '$output_model2', + '$output_model3', + '$output_model4', + '$output_model5' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$output_model' + }} + }}, + {{ + 'Name': 'Transforms.PermutationFeatureImportance', + 'Inputs': {{ + 'Data': '$data', + 'PredictorModel': '$output_model', + 'PermutationCount': 5 + }}, + 'Outputs': {{ + 'Metrics': '$output_data' + }} }} - }} - ], - 'Outputs': {{ - 'output_data': '{2}' - }} - }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); + ], + 'Outputs': {{ + 'output_data': '{1}' + }} + }}", EscapePath(inputData), EscapePath(outputDataPath)); - var pfiJsonPath = DeleteOutputPath("pfiGraph.json"); - File.WriteAllLines(pfiJsonPath, new[] { pfiGraph }); + var jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); - var pfiArgs = new ExecuteGraphCommand.Arguments() { GraphPath = pfiJsonPath }; - var pfiCmd = new ExecuteGraphCommand(Env, pfiArgs); - pfiCmd.Run(); + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); + var mlContext = new MLContext(); var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); - Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("MacroAccuracy")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("MicroAccuracy")); @@ -1065,11 +1068,10 @@ public void MulticlassPermutationFeatureImportanceWithKeyToValue() [Fact] public void RegressionPermutationFeatureImportance() { - var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var modelPath = DeleteOutputPath("reg_pfi_model.zip"); + var inputDataPath = GetDataPath("adult.tiny.with-schema.txt"); var outputDataPath = DeleteOutputPath("reg_pfi_metrics.idv"); - string trainingGraph = string.Format(@" + string inputGraph = string.Format(@" {{ 'Inputs': {{ 'file': '{0}' @@ -1164,42 +1166,12 @@ public void RegressionPermutationFeatureImportance() 'Outputs': {{ 'PredictorModel': '$output_model' }} - }} - ], - 'Outputs': {{ - 'output_model': '{1}' - }} - }}", EscapePath(dataPath), EscapePath(modelPath)); - - var trainingJsonPath = DeleteOutputPath("trainingGraph.json"); - File.WriteAllLines(trainingJsonPath, new[] { trainingGraph }); - - var trainingArgs = new ExecuteGraphCommand.Arguments() { GraphPath = trainingJsonPath }; - var trainingCmd = new ExecuteGraphCommand(Env, trainingArgs); - trainingCmd.Run(); - - string pfiGraph = string.Format(@" - {{ - 'Inputs': {{ - 'file1': '{0}', - 'file2': '{1}' - }}, - 'Nodes': [ - {{ - 'Name': 'Data.CustomTextLoader', - 'Inputs': {{ - 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', - 'InputFile': '$file1' - }}, - 'Outputs': {{ - 'Data': '$data' - }} }}, {{ 'Name': 'Transforms.PermutationFeatureImportance', 'Inputs': {{ 'Data': '$data', - 'ModelPath': '$file2', + 'PredictorModel': '$output_model', 'PermutationCount': 5 }}, 'Outputs': {{ @@ -1208,21 +1180,19 @@ public void RegressionPermutationFeatureImportance() }} ], 'Outputs': {{ - 'output_data': '{2}' + 'output_data': '{1}' }} - }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); - - var pfiJsonPath = DeleteOutputPath("pfiGraph.json"); - File.WriteAllLines(pfiJsonPath, new[] { pfiGraph }); + }}", EscapePath(inputDataPath), EscapePath(outputDataPath)); - var pfiArgs = new ExecuteGraphCommand.Arguments() { GraphPath = pfiJsonPath }; - var pfiCmd = new ExecuteGraphCommand(Env, pfiArgs); - pfiCmd.Run(); + var jsonPath = DeleteOutputPath("graph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); + var mlContext = new MLContext(); - var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); - Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("MeanAbsoluteError")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("MeanSquaredError")); @@ -1239,11 +1209,10 @@ public void RegressionPermutationFeatureImportance() [Fact] public void RankingPermutationFeatureImportance() { - var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - var modelPath = DeleteOutputPath("rank_pfi_model.zip"); + var inputData = GetDataPath("adult.tiny.with-schema.txt"); var outputDataPath = DeleteOutputPath("rank_pfi_metrics.idv"); - string trainingGraph = string.Format(@" + string inputGraph = string.Format(@" {{ 'Inputs': {{ 'file': '{0}' @@ -1344,43 +1313,13 @@ public void RankingPermutationFeatureImportance() 'Outputs': {{ 'PredictorModel': '$output_model' }} - }} - ], - 'Outputs': {{ - 'output_model': '{1}' - }} - }}", EscapePath(dataPath), EscapePath(modelPath)); - - var trainingJsonPath = DeleteOutputPath("trainingGraph.json"); - File.WriteAllLines(trainingJsonPath, new[] { trainingGraph }); - - var trainingArgs = new ExecuteGraphCommand.Arguments() { GraphPath = trainingJsonPath }; - var trainingCmd = new ExecuteGraphCommand(Env, trainingArgs); - trainingCmd.Run(); - - string pfiGraph = string.Format(@" - {{ - 'Inputs': {{ - 'file1': '{0}', - 'file2': '{1}' - }}, - 'Nodes': [ - {{ - 'Name': 'Data.CustomTextLoader', - 'Inputs': {{ - 'CustomSchema': 'col=Label:I8:0 col=Workclass:TX:1 col=education:TX:2 col=marital-status:TX:3 col=occupation:TX:4 col=relationship:TX:5 col=ethnicity:TX:6 col=sex:TX:7 col=native-country-region:TX:8 col=age:I8:9 col=fnlwgt:I8:10 col=education-num:I8:11 col=capital-gain:I8:12 col=capital-loss:I8:13 col=hours-per-week:I8:14 quote+ header=+ sep=tab', - 'InputFile': '$file1' - }}, - 'Outputs': {{ - 'Data': '$data' - }} }}, {{ 'Name': 'Transforms.PermutationFeatureImportance', 'Inputs': {{ - 'Data': '$data', - 'ModelPath': '$file2', - 'PermutationCount': 20, + 'Data': '$input_data', + 'PredictorModel': '$output_model', + 'PermutationCount': 5, 'RowGroupColumnName': 'Workclass' }}, 'Outputs': {{ @@ -1389,21 +1328,19 @@ public void RankingPermutationFeatureImportance() }} ], 'Outputs': {{ - 'output_data': '{2}' + 'output_data': '{1}' }} - }}", EscapePath(dataPath), EscapePath(modelPath), EscapePath(outputDataPath)); + }}", EscapePath(inputData), EscapePath(outputDataPath)); - var pfiJsonPath = DeleteOutputPath("pfiGraph.json"); - File.WriteAllLines(pfiJsonPath, new[] { pfiGraph }); + var jsonPath = DeleteOutputPath("trainingGraph.json"); + File.WriteAllLines(jsonPath, new[] { inputGraph }); - var pfiArgs = new ExecuteGraphCommand.Arguments() { GraphPath = pfiJsonPath }; - var pfiCmd = new ExecuteGraphCommand(Env, pfiArgs); - pfiCmd.Run(); + var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath }; + var cmd = new ExecuteGraphCommand(Env, args); + cmd.Run(); var mlContext = new MLContext(); - var loadedData = mlContext.Data.LoadFromBinary(outputDataPath); - Assert.NotNull(loadedData.Schema.GetColumnOrNull("FeatureName")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("DiscountedCumulativeGains")); Assert.NotNull(loadedData.Schema.GetColumnOrNull("NormalizedDiscountedCumulativeGains")); From de67439c27bfb724bcf1d661cbe3f36d6f86c793 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Thu, 26 Sep 2019 18:50:47 -0700 Subject: [PATCH 13/15] Remove label column and group ID column from entrypoint input arguments --- .../PermutationFeatureImportance.cs | 42 ++++++++++--------- .../Common/EntryPoints/core_manifest.json | 28 +------------ .../UnitTests/TestEntryPoints.cs | 6 +-- 3 files changed, 27 insertions(+), 49 deletions(-) diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index 40781cc30b..f8bcb12ed9 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -45,12 +45,6 @@ internal sealed class PermutationFeatureImportanceArguments : TransformInputBase [Argument(ArgumentType.Required, HelpText = "The path to the model file", ShortName = "path", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public PredictorModel PredictorModel; - [Argument(ArgumentType.AtMostOnce, HelpText = "Label column name", ShortName = "label", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public string LabelColumnName = "Label"; - - [Argument(ArgumentType.AtMostOnce, HelpText = "Group ID column", ShortName = "groupId", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public string RowGroupColumnName = "GroupId"; - [Argument(ArgumentType.AtMostOnce, HelpText = "Use feature weights to pre-filter features", ShortName = "usefw", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public bool UseFeatureWeightFilter = false; @@ -94,14 +88,16 @@ private static IDataView GetBinaryMetrics( IDataView data, PermutationFeatureImportanceArguments input) { + var roles = roleMappedSchema.GetColumnRoleNames(); + var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; + var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; var pred = new BinaryPredictionTransformer>( - env, predictor as IPredictorProducing, data.Schema, - roleMappedSchema.GetColumnRoleNames().Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); + env, predictor as IPredictorProducing, data.Schema, featureColumnName); var binaryCatalog = new BinaryClassificationCatalog(env); var permutationMetrics = binaryCatalog .PermutationFeatureImportance(pred, data, - labelColumnName: input.LabelColumnName, + labelColumnName: labelColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); @@ -150,15 +146,16 @@ private static IDataView GetMulticlassMetrics( IDataView data, PermutationFeatureImportanceArguments input) { + var roles = roleMappedSchema.GetColumnRoleNames(); + var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; + var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; var pred = new MulticlassPredictionTransformer>>( - env, predictor as IPredictorProducing>, data.Schema, - roleMappedSchema.GetColumnRoleNames().Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value, - roleMappedSchema.GetColumnRoleNames().Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value); + env, predictor as IPredictorProducing>, data.Schema, featureColumnName, labelColumnName); var multiclassCatalog = new MulticlassClassificationCatalog(env); var permutationMetrics = multiclassCatalog .PermutationFeatureImportance(pred, data, - labelColumnName: input.LabelColumnName, + labelColumnName: labelColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); @@ -209,14 +206,16 @@ private static IDataView GetRegressionMetrics( IDataView data, PermutationFeatureImportanceArguments input) { + var roles = roleMappedSchema.GetColumnRoleNames(); + var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; + var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; var pred = new RegressionPredictionTransformer>( - env, predictor as IPredictorProducing, data.Schema, - roleMappedSchema.GetColumnRoleNames().Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); + env, predictor as IPredictorProducing, data.Schema, featureColumnName); var regressionCatalog = new RegressionCatalog(env); var permutationMetrics = regressionCatalog .PermutationFeatureImportance(pred, data, - labelColumnName: input.LabelColumnName, + labelColumnName: labelColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); @@ -259,15 +258,18 @@ private static IDataView GetRankingMetrics( IDataView data, PermutationFeatureImportanceArguments input) { + var roles = roleMappedSchema.GetColumnRoleNames(); + var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; + var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; + var groupIdColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Group.Value).First().Value; var pred = new RankingPredictionTransformer>( - env, predictor as IPredictorProducing, data.Schema, - roleMappedSchema.GetColumnRoleNames().Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); + env, predictor as IPredictorProducing, data.Schema, featureColumnName); var rankingCatalog = new RankingCatalog(env); var permutationMetrics = rankingCatalog .PermutationFeatureImportance(pred, data, - labelColumnName: input.LabelColumnName, - rowGroupColumnName: input.RowGroupColumnName, + labelColumnName: labelColumnName, + rowGroupColumnName: groupIdColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index ec7ddbb497..c8e6d6e55c 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -21725,8 +21725,8 @@ "IsNullable": false }, { - "Name": "ModelPath", - "Type": "FileHandle", + "Name": "PredictorModel", + "Type": "PredictorModel", "Desc": "The path to the model file", "Aliases": [ "path" @@ -21735,30 +21735,6 @@ "SortOrder": 150.0, "IsNullable": false }, - { - "Name": "LabelColumnName", - "Type": "String", - "Desc": "Label column name", - "Aliases": [ - "label" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": "Label" - }, - { - "Name": "RowGroupColumnName", - "Type": "String", - "Desc": "Group ID column", - "Aliases": [ - "groupId" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": "GroupId" - }, { "Name": "UseFeatureWeightFilter", "Type": "Bool", diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 63179e4f41..b5b74cb915 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -224,7 +224,8 @@ private string GetBuildPrefix() #endif } - [Fact(Skip = "Execute this test if you want to regenerate the core_manifest and core_ep_list files")] + //[Fact(Skip = "Execute this test if you want to regenerate the core_manifest and core_ep_list files")] + [Fact] public void RegenerateEntryPointCatalog() { var (epListContents, jObj) = BuildManifests(); @@ -1319,8 +1320,7 @@ public void RankingPermutationFeatureImportance() 'Inputs': {{ 'Data': '$input_data', 'PredictorModel': '$output_model', - 'PermutationCount': 5, - 'RowGroupColumnName': 'Workclass' + 'PermutationCount': 5 }}, 'Outputs': {{ 'Metrics': '$output_data' From 7f729f1dcaec6735457698a219e92263fa76c933 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Thu, 26 Sep 2019 18:52:06 -0700 Subject: [PATCH 14/15] nit --- test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index b5b74cb915..87a0c61f1c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -224,8 +224,7 @@ private string GetBuildPrefix() #endif } - //[Fact(Skip = "Execute this test if you want to regenerate the core_manifest and core_ep_list files")] - [Fact] + [Fact(Skip = "Execute this test if you want to regenerate the core_manifest and core_ep_list files")] public void RegenerateEntryPointCatalog() { var (epListContents, jObj) = BuildManifests(); From 3211365b8479a1bed29fc7237b5434e8fb8d7f3e Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Fri, 27 Sep 2019 14:58:13 -0700 Subject: [PATCH 15/15] Simplify data prep --- .../PermutationFeatureImportance.cs | 62 +++++++++---------- 1 file changed, 28 insertions(+), 34 deletions(-) diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index f8bcb12ed9..135ad3b333 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -28,8 +28,7 @@ public static PermutationFeatureImportanceOutput PermutationFeatureImportance(IH input.PredictorModel.PrepareData(env, input.Data, out RoleMappedData roleMappedData, out IPredictor predictor); Contracts.Assert(predictor != null, "No predictor found in model"); - var transformedData = input.PredictorModel.TransformModel.Apply(env, input.Data); - IDataView result = PermutationFeatureImportanceUtils.GetMetrics(env, predictor, roleMappedData.Schema, transformedData, input); + IDataView result = PermutationFeatureImportanceUtils.GetMetrics(env, predictor, roleMappedData, input); return new PermutationFeatureImportanceOutput { Metrics = result }; } } @@ -60,19 +59,18 @@ internal static class PermutationFeatureImportanceUtils internal static IDataView GetMetrics( IHostEnvironment env, IPredictor predictor, - RoleMappedSchema roleMappedSchema, - IDataView data, + RoleMappedData roleMappedData, PermutationFeatureImportanceArguments input) { IDataView result; if (predictor.PredictionKind == PredictionKind.BinaryClassification) - result = GetBinaryMetrics(env, predictor, roleMappedSchema, data, input); + result = GetBinaryMetrics(env, predictor, roleMappedData, input); else if (predictor.PredictionKind == PredictionKind.MulticlassClassification) - result = GetMulticlassMetrics(env, predictor, roleMappedSchema, data, input); + result = GetMulticlassMetrics(env, predictor, roleMappedData, input); else if (predictor.PredictionKind == PredictionKind.Regression) - result = GetRegressionMetrics(env, predictor, roleMappedSchema, data, input); + result = GetRegressionMetrics(env, predictor, roleMappedData, input); else if (predictor.PredictionKind == PredictionKind.Ranking) - result = GetRankingMetrics(env, predictor, roleMappedSchema, data, input); + result = GetRankingMetrics(env, predictor, roleMappedData, input); else throw Contracts.Except( "Unsupported predictor type. Predictor must be binary classifier, " + @@ -84,25 +82,24 @@ internal static IDataView GetMetrics( private static IDataView GetBinaryMetrics( IHostEnvironment env, IPredictor predictor, - RoleMappedSchema roleMappedSchema, - IDataView data, + RoleMappedData roleMappedData, PermutationFeatureImportanceArguments input) { - var roles = roleMappedSchema.GetColumnRoleNames(); + var roles = roleMappedData.Schema.GetColumnRoleNames(); var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; var pred = new BinaryPredictionTransformer>( - env, predictor as IPredictorProducing, data.Schema, featureColumnName); + env, predictor as IPredictorProducing, roleMappedData.Data.Schema, featureColumnName); var binaryCatalog = new BinaryClassificationCatalog(env); var permutationMetrics = binaryCatalog .PermutationFeatureImportance(pred, - data, + roleMappedData.Data, labelColumnName: labelColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); - var slotNames = GetSlotNames(data); + var slotNames = GetSlotNames(roleMappedData.Schema); Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); @@ -142,25 +139,24 @@ private static IDataView GetBinaryMetrics( private static IDataView GetMulticlassMetrics( IHostEnvironment env, IPredictor predictor, - RoleMappedSchema roleMappedSchema, - IDataView data, + RoleMappedData roleMappedData, PermutationFeatureImportanceArguments input) { - var roles = roleMappedSchema.GetColumnRoleNames(); + var roles = roleMappedData.Schema.GetColumnRoleNames(); var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; var pred = new MulticlassPredictionTransformer>>( - env, predictor as IPredictorProducing>, data.Schema, featureColumnName, labelColumnName); + env, predictor as IPredictorProducing>, roleMappedData.Data.Schema, featureColumnName, labelColumnName); var multiclassCatalog = new MulticlassClassificationCatalog(env); var permutationMetrics = multiclassCatalog .PermutationFeatureImportance(pred, - data, + roleMappedData.Data, labelColumnName: labelColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); - var slotNames = GetSlotNames(data); + var slotNames = GetSlotNames(roleMappedData.Schema); Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); @@ -202,25 +198,24 @@ private static IDataView GetMulticlassMetrics( private static IDataView GetRegressionMetrics( IHostEnvironment env, IPredictor predictor, - RoleMappedSchema roleMappedSchema, - IDataView data, + RoleMappedData roleMappedData, PermutationFeatureImportanceArguments input) { - var roles = roleMappedSchema.GetColumnRoleNames(); + var roles = roleMappedData.Schema.GetColumnRoleNames(); var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; var pred = new RegressionPredictionTransformer>( - env, predictor as IPredictorProducing, data.Schema, featureColumnName); + env, predictor as IPredictorProducing, roleMappedData.Data.Schema, featureColumnName); var regressionCatalog = new RegressionCatalog(env); var permutationMetrics = regressionCatalog .PermutationFeatureImportance(pred, - data, + roleMappedData.Data, labelColumnName: labelColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); - var slotNames = GetSlotNames(data); + var slotNames = GetSlotNames(roleMappedData.Schema); Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); @@ -254,27 +249,26 @@ private static IDataView GetRegressionMetrics( private static IDataView GetRankingMetrics( IHostEnvironment env, IPredictor predictor, - RoleMappedSchema roleMappedSchema, - IDataView data, + RoleMappedData roleMappedData, PermutationFeatureImportanceArguments input) { - var roles = roleMappedSchema.GetColumnRoleNames(); + var roles = roleMappedData.Schema.GetColumnRoleNames(); var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value; var groupIdColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Group.Value).First().Value; var pred = new RankingPredictionTransformer>( - env, predictor as IPredictorProducing, data.Schema, featureColumnName); + env, predictor as IPredictorProducing, roleMappedData.Data.Schema, featureColumnName); var rankingCatalog = new RankingCatalog(env); var permutationMetrics = rankingCatalog .PermutationFeatureImportance(pred, - data, + roleMappedData.Data, labelColumnName: labelColumnName, rowGroupColumnName: groupIdColumnName, useFeatureWeightFilter: input.UseFeatureWeightFilter, numberOfExamplesToUse: input.NumberOfExamplesToUse, permutationCount: input.PermutationCount); - var slotNames = GetSlotNames(data); + var slotNames = GetSlotNames(roleMappedData.Schema); Contracts.Assert(slotNames.Length == permutationMetrics.Length, "Mismatch between number of feature slots and number of features permuted."); @@ -307,10 +301,10 @@ private static IDataView GetRankingMetrics( return result; } - private static string[] GetSlotNames(IDataView data) + private static string[] GetSlotNames(RoleMappedSchema schema) { VBuffer> slots = default; - data.Schema["Features"].GetSlotNames(ref slots); + schema.Feature.Value.GetSlotNames(ref slots); var slotValues = slots.DenseValues(); List slotNames = new List();