diff --git a/src/Microsoft.ML/Models/ClassificationEvaluator.cs b/src/Microsoft.ML/Models/ClassificationEvaluator.cs
index 8fedc3fb4f..bc97a372a0 100644
--- a/src/Microsoft.ML/Models/ClassificationEvaluator.cs
+++ b/src/Microsoft.ML/Models/ClassificationEvaluator.cs
@@ -57,13 +57,13 @@ public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLo
IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
if (overallMetrics == null)
{
- throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
+ throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClassificationEvaluator)} Evaluate.");
}
IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix);
if (confusionMatrix == null)
{
- throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
+ throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(ClassificationEvaluator)} Evaluate.");
}
var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
diff --git a/src/Microsoft.ML/Models/ClusterEvaluator.cs b/src/Microsoft.ML/Models/ClusterEvaluator.cs
new file mode 100644
index 0000000000..5aceca16f7
--- /dev/null
+++ b/src/Microsoft.ML/Models/ClusterEvaluator.cs
@@ -0,0 +1,71 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Transforms;
+
+namespace Microsoft.ML.Models
+{
+ public sealed partial class ClusterEvaluator
+ {
+ /// <summary>
+ /// Computes the quality metrics for the PredictionModel using the specified data set.
+ /// </summary>
+ /// <param name="model">
+ /// The trained PredictionModel to be evaluated.
+ /// </param>
+ /// <param name="testData">
+ /// The test data that will be predicted and used to evaluate the model.
+ /// </param>
+ /// <returns>
+ /// A ClusterMetrics instance that describes how well the model performed against the test data.
+ /// </returns>
+ public ClusterMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
+ {
+ using (var environment = new TlcEnvironment())
+ {
+ environment.CheckValue(model, nameof(model));
+ environment.CheckValue(testData, nameof(testData));
+
+ Experiment experiment = environment.CreateExperiment();
+
+ ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
+ if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
+ {
+ throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
+ }
+
+ var datasetScorer = new DatasetTransformScorer
+ {
+ Data = testDataOutput.Data,
+ };
+ DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
+
+ Data = scoreOutput.ScoredData;
+ Output evaluteOutput = experiment.Add(this);
+
+ experiment.Compile();
+
+ experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
+ testData.SetInput(environment, experiment);
+
+ experiment.Run();
+
+ IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
+
+ if (overallMetrics == null)
+ {
+ throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClusterEvaluator)} Evaluate.");
+ }
+
+ var metric = ClusterMetrics.FromOverallMetrics(environment, overallMetrics);
+
+ Contracts.Assert(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics");
+
+ return metric[0];
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML/Models/ClusterMetrics.cs b/src/Microsoft.ML/Models/ClusterMetrics.cs
new file mode 100644
index 0000000000..7f88784ef8
--- /dev/null
+++ b/src/Microsoft.ML/Models/ClusterMetrics.cs
@@ -0,0 +1,94 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Models
+{
+ /// <summary>
+ /// This class contains the overall metrics computed by cluster evaluators.
+ /// </summary>
+ public sealed class ClusterMetrics
+ {
+ private ClusterMetrics()
+ {
+ }
+
+ internal static List<ClusterMetrics> FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics)
+ {
+ Contracts.AssertValue(env);
+ env.AssertValue(overallMetrics);
+
+ var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
+ if (!metricsEnumerable.GetEnumerator().MoveNext())
+ {
+ throw env.Except("The overall ClusteringMetrics didn't have any rows.");
+ }
+
+ var metrics = new List<ClusterMetrics>();
+ foreach (var metric in metricsEnumerable)
+ {
+ metrics.Add(new ClusterMetrics()
+ {
+ AvgMinScore = metric.AvgMinScore,
+ Nmi = metric.Nmi,
+ Dbi = metric.Dbi,
+ });
+ }
+
+ return metrics;
+ }
+
+ /// <summary>
+ /// Davies-Bouldin Index.
+ /// </summary>
+ /// <remarks>
+ /// DBI is a measure of the how much scatter is in the cluster and the cluster separation.
+ /// </remarks>
+ public double Dbi { get; private set; }
+
+ /// <summary>
+ /// Normalized Mutual Information
+ /// </summary>
+ /// <remarks>
+ /// NMI is a measure of the mutual dependence between the true and predicted cluster labels for instances in the dataset.
+ /// NMI ranges between 0 and 1 where "0" indicates clustering is random and "1" indicates clustering is perfect w.r.t true labels.
+ /// </remarks>
+ public double Nmi { get; private set; }
+
+ /// <summary>
+ /// Average minimum score.
+ /// </summary>
+ /// <remarks>
+ /// AvgMinScore is the average squared-distance of examples from the respective cluster centroids.
+ /// It is defined as
+ /// AvgMinScore = (1/m) * sum ((xi - c(xi))^2)
+ /// where m is the number of instances in the dataset.
+ /// xi is the i'th instance and c(xi) is the centroid of the predicted cluster for xi.
+ /// </remarks>
+ public double AvgMinScore { get; private set; }
+
+ /// <summary>
+ /// This class contains the public fields necessary to deserialize from IDataView.
+ /// </summary>
+ private sealed class SerializationClass
+ {
+#pragma warning disable 649 // never assigned
+ [ColumnName(Runtime.Data.ClusteringEvaluator.Dbi)]
+ public Double Dbi;
+
+ [ColumnName(Runtime.Data.ClusteringEvaluator.Nmi)]
+ public Double Nmi;
+
+ [ColumnName(Runtime.Data.ClusteringEvaluator.AvgMinScore)]
+ public Double AvgMinScore;
+
+#pragma warning restore 649 // never assigned
+ }
+ }
+}
diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs
index 173e03916c..ab84f8a715 100644
--- a/src/Microsoft.ML/Models/CrossValidator.cs
+++ b/src/Microsoft.ML/Models/CrossValidator.cs
@@ -19,7 +19,7 @@ public sealed partial class CrossValidator
/// Class type that represents prediction schema.
/// Machine learning pipeline may contain loader, transforms and at least one trainer.
/// List containing metrics and predictor model for each fold
- public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(LearningPipeline pipeline)
+ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(LearningPipeline pipeline)
where TInput : class
where TOutput : class, new()
{
@@ -76,7 +76,7 @@ public CrossValidationOutput CrossValidate(Lea
{
PredictorModel = predictorModel
};
-
+
var scorerOutput = subGraph.Add(scorer);
lastTransformModel = scorerOutput.ScoringTransform;
step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform);
@@ -129,7 +129,7 @@ public CrossValidationOutput CrossValidate(Lea
experiment.GetOutput(crossValidateOutput.OverallMetrics),
experiment.GetOutput(crossValidateOutput.ConfusionMatrix), 2);
}
- else if(Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer)
+ else if (Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer)
{
cvOutput.ClassificationMetrics = ClassificationMetrics.FromMetrics(
environment,
@@ -142,6 +142,12 @@ public CrossValidationOutput CrossValidate(Lea
environment,
experiment.GetOutput(crossValidateOutput.OverallMetrics));
}
+ else if (Kind == MacroUtilsTrainerKinds.SignatureClusteringTrainer)
+ {
+ cvOutput.ClusterMetrics = ClusterMetrics.FromOverallMetrics(
+ environment,
+ experiment.GetOutput(crossValidateOutput.OverallMetrics));
+ }
else
{
//Implement metrics for ranking, clustering and anomaly detection.
@@ -174,6 +180,7 @@ public class CrossValidationOutput<TInput, TOutput>
public List<BinaryClassificationMetrics> BinaryClassificationMetrics;
public List<ClassificationMetrics> ClassificationMetrics;
public List<RegressionMetrics> RegressionMetrics;
+ public List<ClusterMetrics> ClusterMetrics;
public PredictionModel<TInput, TOutput>[] PredictorModels;
//REVIEW: Add warnings and per instance results and implement
diff --git a/src/Microsoft.ML/Models/TrainTestEvaluator.cs b/src/Microsoft.ML/Models/TrainTestEvaluator.cs
index 19261e82de..ae00a34de6 100644
--- a/src/Microsoft.ML/Models/TrainTestEvaluator.cs
+++ b/src/Microsoft.ML/Models/TrainTestEvaluator.cs
@@ -102,7 +102,7 @@ public TrainTestEvaluatorOutput TrainTestEvaluate TrainTestEvaluate TrainTestEvaluate(predictor, memoryStream);
}
-
+
return trainTestOutput;
}
}
@@ -171,6 +177,7 @@ public class TrainTestEvaluatorOutput<TInput, TOutput>
public BinaryClassificationMetrics BinaryClassificationMetrics;
public ClassificationMetrics ClassificationMetrics;
public RegressionMetrics RegressionMetrics;
+ public ClusterMetrics ClusterMetrics;
public PredictionModel<TInput, TOutput> PredictorModels;
//REVIEW: Add warnings and per instance results and implement
diff --git a/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs b/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs
index 560ee11d28..ad1027bb1a 100644
--- a/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs
@@ -1,4 +1,5 @@
using Microsoft.ML.Data;
+using Microsoft.ML.Models;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Trainers;
@@ -116,6 +117,16 @@ public void PredictClusters()
Assert.True(!labels.Contains(scores.SelectedClusterId));
labels.Add(scores.SelectedClusterId);
}
+
+ var evaluator = new ClusterEvaluator();
+ var testData = CollectionDataSource.Create(clusters);
+ ClusterMetrics metrics = evaluator.Evaluate(model, testData);
+
+ //Label is not specified, so NMI would be equal to NaN
+ Assert.Equal(metrics.Nmi, double.NaN);
+ //Calculate dbi is false by default so Dbi would be 0
+ Assert.Equal(metrics.Dbi, (double)0.0);
+ Assert.Equal(metrics.AvgMinScore, (double)0.0, 5);
}
}
}