diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs index 1bd034ff74..648e49f565 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs @@ -171,7 +171,7 @@ public static ImmutableArray int processedCnt = 0; int nextFeatureIndex = 0; var shuffleRand = RandomUtils.Create(host.Rand.Next()); - using (var pch = host.StartProgressChannel("SDCA preprocessing with lookup")) + using (var pch = host.StartProgressChannel("Calculating Permutation Feature Importance")) { pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt)); foreach (var workingIndx in workingFeatureIndices) diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs index 07b9c8f435..02e8832f0e 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs @@ -145,16 +145,16 @@ public static ImmutableArray int permutationCount = 1) where TModel : class { return PermutationFeatureImportance.GetImportanceMetricsMatrix( - catalog.GetEnvironment(), - predictionTransformer, - data, - () => new BinaryClassificationMetricsStatistics(), - idv => catalog.Evaluate(idv, labelColumnName), - BinaryClassifierDelta, - predictionTransformer.FeatureColumnName, - permutationCount, - useFeatureWeightFilter, - numberOfExamplesToUse); + catalog.GetEnvironment(), + predictionTransformer, + data, + () => new BinaryClassificationMetricsStatistics(), + idv => catalog.EvaluateNonCalibrated(idv, labelColumnName), + BinaryClassifierDelta, + predictionTransformer.FeatureColumnName, + permutationCount, + useFeatureWeightFilter, + numberOfExamplesToUse); } private static BinaryClassificationMetrics BinaryClassifierDelta( diff --git a/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs b/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs index 3feca9421b..5bd999f21b 100644 --- a/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs +++ b/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs @@ -305,6 +305,36 @@ public void TestPfiBinaryClassificationOnSparseFeatures(bool saveModel) Done(); } + + [Fact] + public void TestBinaryClassificationWithoutCalibrator() + { + var dataPath = GetDataPath("breast-cancer.txt"); + var ff = ML.BinaryClassification.Trainers.FastForest(); + var data = ML.Data.LoadFromTextFile(dataPath, + new[] { new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("Features", DataKind.Single, 1, 9) }); + var model = ff.Fit(data); + var pfi = ML.BinaryClassification.PermutationFeatureImportance(model, data); + + // For the following metrics higher is better, so minimum delta means more important feature, and vice versa + Assert.Equal(7, MaxDeltaIndex(pfi, m => m.AreaUnderRocCurve.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.AreaUnderRocCurve.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.Accuracy.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.Accuracy.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositivePrecision.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositivePrecision.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositiveRecall.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositiveRecall.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.NegativePrecision.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativePrecision.Mean)); + Assert.Equal(2, MaxDeltaIndex(pfi, m => m.NegativeRecall.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativeRecall.Mean)); + Assert.Equal(3, MaxDeltaIndex(pfi, m => m.F1Score.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.F1Score.Mean)); + Assert.Equal(7, MaxDeltaIndex(pfi, m => m.AreaUnderPrecisionRecallCurve.Mean)); + Assert.Equal(1, MinDeltaIndex(pfi, m => m.AreaUnderPrecisionRecallCurve.Mean)); + } #endregion #region Multiclass Classification Tests