diff --git a/src/Microsoft.ML/Models/ClassificationEvaluator.cs b/src/Microsoft.ML/Models/ClassificationEvaluator.cs
index 8fedc3fb4f..bc97a372a0 100644
--- a/src/Microsoft.ML/Models/ClassificationEvaluator.cs
+++ b/src/Microsoft.ML/Models/ClassificationEvaluator.cs
@@ -57,13 +57,13 @@ public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLo
                 IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
                 if (overallMetrics == null)
                 {
-                    throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
+                    throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClassificationEvaluator)} Evaluate.");
                 }

                 IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix);
                 if (confusionMatrix == null)
                 {
-                    throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
+                    throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(ClassificationEvaluator)} Evaluate.");
                 }

                 var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
diff --git a/src/Microsoft.ML/Models/ClusterEvaluator.cs b/src/Microsoft.ML/Models/ClusterEvaluator.cs
new file mode 100644
index 0000000000..5aceca16f7
--- /dev/null
+++ b/src/Microsoft.ML/Models/ClusterEvaluator.cs
@@ -0,0 +1,71 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Transforms;
+
+namespace Microsoft.ML.Models
+{
+    public sealed partial class ClusterEvaluator
+    {
+        /// <summary>
+        /// Computes the quality metrics for the PredictionModel using the specified data set.
+        /// </summary>
+        /// <param name="model">
+        /// The trained PredictionModel to be evaluated.
+        /// </param>
+        /// <param name="testData">
+        /// The test data that will be predicted and used to evaluate the model.
+        /// </param>
+        /// <returns>
+        /// A ClusterMetrics instance that describes how well the model performed against the test data.
+        /// </returns>
+        public ClusterMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
+        {
+            using (var environment = new TlcEnvironment())
+            {
+                environment.CheckValue(model, nameof(model));
+                environment.CheckValue(testData, nameof(testData));
+
+                Experiment experiment = environment.CreateExperiment();
+
+                ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
+                if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
+                {
+                    throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
+                }
+
+                var datasetScorer = new DatasetTransformScorer
+                {
+                    Data = testDataOutput.Data,
+                };
+                DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
+
+                Data = scoreOutput.ScoredData;
+                Output evaluteOutput = experiment.Add(this);
+
+                experiment.Compile();
+
+                experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
+                testData.SetInput(environment, experiment);
+
+                experiment.Run();
+
+                IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
+
+                if (overallMetrics == null)
+                {
+                    throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClusterEvaluator)} Evaluate.");
+                }
+
+                var metric = ClusterMetrics.FromOverallMetrics(environment, overallMetrics);
+
+                Contracts.Assert(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics");
+
+                return metric[0];
+            }
+        }
+    }
+}
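For orientation, the intended call pattern matches what the updated ClusteringTests.cs at the end of this diff exercises. A minimal sketch, assuming `model` is a clustering `PredictionModel` trained with the `LearningPipeline` API and `clusters` is the in-memory test set, as in that test:

```csharp
// Evaluate a trained clustering model against in-memory test data.
// `model` and `clusters` are assumed to exist (see PredictClusters() below).
var evaluator = new ClusterEvaluator();
var testData = CollectionDataSource.Create(clusters);
ClusterMetrics metrics = evaluator.Evaluate(model, testData);

// AvgMinScore: average squared distance of each example to its assigned centroid.
// Nmi: NaN unless the test data carries a Label column.
// Dbi: 0 unless DBI computation is enabled on the evaluator.
Console.WriteLine($"AvgMinScore={metrics.AvgMinScore}, NMI={metrics.Nmi}, DBI={metrics.Dbi}");
```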
diff --git a/src/Microsoft.ML/Models/ClusterMetrics.cs b/src/Microsoft.ML/Models/ClusterMetrics.cs
new file mode 100644
index 0000000000..7f88784ef8
--- /dev/null
+++ b/src/Microsoft.ML/Models/ClusterMetrics.cs
@@ -0,0 +1,94 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Models
+{
+    /// <summary>
+    /// This class contains the overall metrics computed by cluster evaluators.
+    /// </summary>
+    public sealed class ClusterMetrics
+    {
+        private ClusterMetrics()
+        {
+        }
+
+        internal static List<ClusterMetrics> FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics)
+        {
+            Contracts.AssertValue(env);
+            env.AssertValue(overallMetrics);
+
+            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
+            if (!metricsEnumerable.GetEnumerator().MoveNext())
+            {
+                throw env.Except("The overall ClusteringMetrics didn't have any rows.");
+            }
+
+            var metrics = new List<ClusterMetrics>();
+            foreach (var metric in metricsEnumerable)
+            {
+                metrics.Add(new ClusterMetrics()
+                {
+                    AvgMinScore = metric.AvgMinScore,
+                    Nmi = metric.Nmi,
+                    Dbi = metric.Dbi,
+                });
+            }
+
+            return metrics;
+        }
+
+        /// <summary>
+        /// Davies-Bouldin Index.
+        /// </summary>
+        /// <remarks>
+        /// DBI is a measure of how much scatter is within each cluster and how well the clusters are separated.
+        /// </remarks>
+        public double Dbi { get; private set; }
+
+        /// <summary>
+        /// Normalized Mutual Information.
+        /// </summary>
+        /// <remarks>
+        /// NMI is a measure of the mutual dependence between the true and predicted cluster labels for instances in the dataset.
+        /// NMI ranges between 0 and 1, where "0" indicates the clustering is random and "1" indicates the clustering is perfect w.r.t. the true labels.
+        /// </remarks>
+        public double Nmi { get; private set; }
+
+        /// <summary>
+        /// Average minimum score.
+        /// </summary>
+        /// <remarks>
+        /// AvgMinScore is the average squared-distance of examples from the respective cluster centroids.
+        /// It is defined as
+        /// AvgMinScore = (1/m) * sum ((xi - c(xi))^2)
+        /// where m is the number of instances in the dataset,
+        /// xi is the i'th instance, and c(xi) is the centroid of the predicted cluster for xi.
+        /// </remarks>
+        public double AvgMinScore { get; private set; }
+
+        /// <summary>
+        /// This class contains the public fields necessary to deserialize from IDataView.
+        /// </summary>
+        private sealed class SerializationClass
+        {
+#pragma warning disable 649 // never assigned
+            [ColumnName(Runtime.Data.ClusteringEvaluator.Dbi)]
+            public Double Dbi;
+
+            [ColumnName(Runtime.Data.ClusteringEvaluator.Nmi)]
+            public Double Nmi;
+
+            [ColumnName(Runtime.Data.ClusteringEvaluator.AvgMinScore)]
+            public Double AvgMinScore;
+
+#pragma warning restore 649 // never assigned
+        }
+    }
+}
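Written out, the AvgMinScore formula from the remarks above is the mean squared distance of each example to its assigned centroid; for m instances with feature vectors x_i and predicted-cluster centroids c(x_i):

```latex
\mathrm{AvgMinScore} = \frac{1}{m}\sum_{i=1}^{m}\bigl\lVert x_i - c(x_i)\bigr\rVert^{2}
```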
diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs
index 173e03916c..ab84f8a715 100644
--- a/src/Microsoft.ML/Models/CrossValidator.cs
+++ b/src/Microsoft.ML/Models/CrossValidator.cs
@@ -19,7 +19,7 @@ public sealed partial class CrossValidator
         /// <typeparam name="TOutput">Class type that represents prediction schema.</typeparam>
         /// <param name="pipeline">Machine learning pipeline may contain loader, transforms and at least one trainer.</param>
         /// <returns>List containing metrics and predictor model for each fold</returns>
-        public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(LearningPipeline pipeline) 
+        public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(LearningPipeline pipeline)
             where TInput : class
             where TOutput : class, new()
         {
@@ -76,7 +76,7 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
                     {
                         PredictorModel = predictorModel
                     };
-                    
+
                     var scorerOutput = subGraph.Add(scorer);
                     lastTransformModel = scorerOutput.ScoringTransform;
                     step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform);
@@ -129,7 +129,7 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
                         experiment.GetOutput(crossValidateOutput.OverallMetrics),
                         experiment.GetOutput(crossValidateOutput.ConfusionMatrix), 2);
                 }
-                else if(Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer)
+                else if (Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer)
                 {
                     cvOutput.ClassificationMetrics = ClassificationMetrics.FromMetrics(
                         environment,
@@ -142,6 +142,12 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
                         environment,
                         experiment.GetOutput(crossValidateOutput.OverallMetrics));
                 }
+                else if (Kind == MacroUtilsTrainerKinds.SignatureClusteringTrainer)
+                {
+                    cvOutput.ClusterMetrics = ClusterMetrics.FromOverallMetrics(
+                        environment,
+                        experiment.GetOutput(crossValidateOutput.OverallMetrics));
+                }
                 else
                 {
                     //Implement metrics for ranking, clustering and anomaly detection.
@@ -174,6 +180,7 @@ public class CrossValidationOutput<TInput, TOutput>
         public List<BinaryClassificationMetrics> BinaryClassificationMetrics;
         public List<ClassificationMetrics> ClassificationMetrics;
         public List<RegressionMetrics> RegressionMetrics;
+        public List<ClusterMetrics> ClusterMetrics;
         public PredictionModel<TInput, TOutput>[] PredictorModels;

         //REVIEW: Add warnings and per instance results and implement
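With the new `SignatureClusteringTrainer` branch, cross-validation now yields one `ClusterMetrics` entry per fold. A hedged sketch of the round trip; the `ClusterData`/`ClusterPrediction` POCOs and the `KMeansPlusPlusClusterer` settings are illustrative assumptions, not part of this diff:

```csharp
var pipeline = new LearningPipeline();
pipeline.Add(CollectionDataSource.Create(examples));   // examples: IEnumerable<ClusterData>
pipeline.Add(new KMeansPlusPlusClusterer() { K = 4 });

var cv = new CrossValidator()
{
    Kind = MacroUtilsTrainerKinds.SignatureClusteringTrainer,
    NumFolds = 2
};
CrossValidationOutput<ClusterData, ClusterPrediction> cvOutput =
    cv.CrossValidate<ClusterData, ClusterPrediction>(pipeline);

// Populated by the new branch above: one entry per fold.
foreach (ClusterMetrics fold in cvOutput.ClusterMetrics)
    Console.WriteLine($"Fold AvgMinScore: {fold.AvgMinScore}");
```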
diff --git a/src/Microsoft.ML/Models/TrainTestEvaluator.cs b/src/Microsoft.ML/Models/TrainTestEvaluator.cs
index 19261e82de..ae00a34de6 100644
--- a/src/Microsoft.ML/Models/TrainTestEvaluator.cs
+++ b/src/Microsoft.ML/Models/TrainTestEvaluator.cs
@@ -102,7 +102,7 @@ public TrainTestEvaluatorOutput<TInput, TOutput> TrainTestEvaluate<TInput, TOut
                 {
                     PredictorModel = predictorModel
                 };
-                
+
                 var scorerOutput = subGraph.Add(scorer);
                 lastTransformModel = scorerOutput.ScoringTransform;
                 step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform);
@@ -142,6 +142,12 @@ public TrainTestEvaluatorOutput<TInput, TOutput> TrainTestEvaluate<TInput, TOut
                         environment,
                         experiment.GetOutput(trainTestEvaluatorOutput.OverallMetrics))[0];
                 }
+                else if (Kind == MacroUtilsTrainerKinds.SignatureClusteringTrainer)
+                {
+                    trainTestOutput.ClusterMetrics = ClusterMetrics.FromOverallMetrics(
+                        environment,
+                        experiment.GetOutput(trainTestEvaluatorOutput.OverallMetrics))[0];
+                }
                 else
                 {
                     //Implement metrics for ranking, clustering and anomaly detection.
@@ -162,6 +168,6 @@ public TrainTestEvaluatorOutput<TInput, TOutput> TrainTestEvaluate<TInput, TOut
                     trainTestOutput.PredictorModels = new PredictionModel<TInput, TOutput>(predictor, memoryStream);
                 }
-                
+
                 return trainTestOutput;
             }
         }
@@ -171,6 +177,7 @@ public class TrainTestEvaluatorOutput<TInput, TOutput>
         public BinaryClassificationMetrics BinaryClassificationMetrics;
         public ClassificationMetrics ClassificationMetrics;
         public RegressionMetrics RegressionMetrics;
+        public ClusterMetrics ClusterMetrics;
         public PredictionModel<TInput, TOutput> PredictorModels;

         //REVIEW: Add warnings and per instance results and implement
diff --git a/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs b/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs
index 560ee11d28..ad1027bb1a 100644
--- a/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs
@@ -1,4 +1,5 @@
 using Microsoft.ML.Data;
+using Microsoft.ML.Models;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Runtime.Api;
 using Microsoft.ML.Trainers;
@@ -116,6 +117,16 @@ public void PredictClusters()
                 Assert.True(!labels.Contains(scores.SelectedClusterId));
                 labels.Add(scores.SelectedClusterId);
             }
+
+            var evaluator = new ClusterEvaluator();
+            var testData = CollectionDataSource.Create(clusters);
+            ClusterMetrics metrics = evaluator.Evaluate(model, testData);
+
+            //Label is not specified, so NMI would be equal to NaN
+            Assert.Equal(metrics.Nmi, double.NaN);
+            //Calculate dbi is false by default so Dbi would be 0
+            Assert.Equal(metrics.Dbi, (double)0.0);
+            Assert.Equal(metrics.AvgMinScore, (double)0.0, 5);
         }
     }
}
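The assertions above pin down the defaults: without a Label column, NMI comes back as NaN, and DBI stays 0 because it is not computed unless requested. A hedged sketch of opting in; the `CalculateDbi` property name is inferred from the underlying ClusteringEvaluator options and is an assumption, as is a `testData` source that carries a Label column:

```csharp
// Assumed option name; mirrors the "Calculate dbi is false by default" comment above.
var evaluator = new ClusterEvaluator() { CalculateDbi = true };
ClusterMetrics metrics = evaluator.Evaluate(model, testData);
// With a Label column present, metrics.Nmi is a real value instead of NaN,
// and metrics.Dbi is now computed rather than defaulting to 0.
```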