diff --git a/src/Microsoft.ML.AutoML/API/RankingExperiment.cs b/src/Microsoft.ML.AutoML/API/RankingExperiment.cs index ea0ba735c4..7233dae9b4 100644 --- a/src/Microsoft.ML.AutoML/API/RankingExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/RankingExperiment.cs @@ -19,6 +19,12 @@ public sealed class RankingExperimentSettings : ExperimentSettings /// The default value is . public RankingMetric OptimizingMetric { get; set; } + /// + /// Name for the GroupId column. + /// + /// The default value is GroupId. + public string GroupIdColumnName { get; set; } + /// /// Collection of trainers the AutoML experiment can leverage. /// @@ -28,6 +34,7 @@ public sealed class RankingExperimentSettings : ExperimentSettings public ICollection Trainers { get; } public RankingExperimentSettings() { + GroupIdColumnName = "GroupId"; OptimizingMetric = RankingMetric.Ndcg; Trainers = Enum.GetValues(typeof(RankingTrainer)).OfType().ToList(); } @@ -68,10 +75,11 @@ public static class RankingExperimentResultExtensions /// /// Enumeration of AutoML experiment run results. /// Metric to consider when selecting the best run. + /// Name for the GroupId column. /// The best experiment run. - public static RunDetail Best(this IEnumerable> results, RankingMetric metric = RankingMetric.Ndcg) + public static RunDetail Best(this IEnumerable> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId") { - var metricsAgent = new RankingMetricsAgent(null, metric); + var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName); var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing; return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing); } @@ -81,10 +89,11 @@ public static RunDetail Best(this IEnumerable /// Enumeration of AutoML experiment cross validation run results. /// Metric to consider when selecting the best run. + /// Name for the GroupId column. /// The best experiment run. - public static CrossValidationRunDetail Best(this IEnumerable> results, RankingMetric metric = RankingMetric.Ndcg) + public static CrossValidationRunDetail Best(this IEnumerable> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId") { - var metricsAgent = new RankingMetricsAgent(null, metric); + var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName); var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing; return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing); } @@ -103,7 +112,7 @@ public sealed class RankingExperiment : ExperimentBase { private readonly MLContext _mlContext; private readonly RankingMetric _optimizingMetric; + private readonly string _groupIdColumnName; - public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric) + public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric, string groupIdColumnName) { _mlContext = mlContext; _optimizingMetric = optimizingMetric; + _groupIdColumnName = groupIdColumnName; } // Optimizing metric used: NDCG@10 and DCG@10 @@ -59,7 +61,7 @@ public bool IsModelPerfect(double score) public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn) { - return _mlContext.Ranking.Evaluate(data, labelColumn); + return _mlContext.Ranking.Evaluate(data, labelColumn, _groupIdColumnName); } } } diff --git a/src/Microsoft.ML.AutoML/TrainerExtensions/TrainerExtensionUtil.cs b/src/Microsoft.ML.AutoML/TrainerExtensions/TrainerExtensionUtil.cs index eab10e5def..7d8d55e49a 100644 --- a/src/Microsoft.ML.AutoML/TrainerExtensions/TrainerExtensionUtil.cs +++ b/src/Microsoft.ML.AutoML/TrainerExtensions/TrainerExtensionUtil.cs @@ -158,7 +158,7 @@ private static IDictionary BuildBasePipelineNodeProps(IEnumerabl } private static IDictionary BuildLightGbmPipelineNodeProps(IEnumerable sweepParams, - string labelColumn, string weightColumn, string groupColumn = null) + string labelColumn, string weightColumn, string groupColumn) { Dictionary props = null; if (sweepParams == null || !sweepParams.Any()) diff --git a/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs b/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs index a0bff55c9e..601fa212be 100644 --- a/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs +++ b/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs @@ -35,9 +35,9 @@ public static RunDetail GetBestRun(IEnumerable< } public static RunDetail GetBestRun(IEnumerable> results, - RankingMetric metric) + RankingMetric metric, string groupIdColumnName) { - var metricsAgent = new RankingMetricsAgent(null, metric); + var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName); var metricInfo = new OptimizingMetricInfo(metric); return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing); } diff --git a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs index b43acf7af1..97b3e07aa4 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs @@ -124,7 +124,7 @@ public void AutoFitRankingTest() { string labelColumnName = "Label"; string scoreColumnName = "Score"; - string groupIdColumnName = "GroupId"; + string groupIdColumnName = "CustomGroupId"; string featuresColumnVectorNameA = "FeatureVectorA"; string featuresColumnVectorNameB = "FeatureVectorB"; var mlContext = new MLContext(1); @@ -136,7 +136,7 @@ public void AutoFitRankingTest() trainDataView = mlContext.Data.SkipRows(trainDataView, 500); // STEP 2: Run AutoML experiment ExperimentResult experimentResult = mlContext.Auto() - .CreateRankingExperiment(5) + .CreateRankingExperiment(new RankingExperimentSettings() { GroupIdColumnName = "CustomGroupId", MaxExperimentTimeInSeconds = 5}) .Execute(trainDataView, testDataView, new ColumnInformation() { diff --git a/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs b/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs index 0f9b336d84..2f3745d937 100644 --- a/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs @@ -127,14 +127,14 @@ public void RankingMetricsGetScoreTest() double[] ndcg = { 0.2, 0.3, 0.4 }; double[] dcg = { 0.2, 0.3, 0.4 }; var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg); - Assert.Equal(0.4, GetScore(metrics, RankingMetric.Dcg)); - Assert.Equal(0.4, GetScore(metrics, RankingMetric.Ndcg)); + Assert.Equal(0.4, GetScore(metrics, RankingMetric.Dcg, "GroupId")); + Assert.Equal(0.4, GetScore(metrics, RankingMetric.Ndcg, "GroupId")); double[] largeNdcg = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95 }; double[] largeDcg = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95 }; metrics = MetricsUtil.CreateRankingMetrics(largeDcg, largeNdcg); - Assert.Equal(0.9, GetScore(metrics, RankingMetric.Dcg)); - Assert.Equal(0.9, GetScore(metrics, RankingMetric.Ndcg)); + Assert.Equal(0.9, GetScore(metrics, RankingMetric.Dcg, "GroupId")); + Assert.Equal(0.9, GetScore(metrics, RankingMetric.Ndcg, "GroupId")); } [Fact] @@ -143,8 +143,8 @@ public void RankingMetricsNonPerfectTest() double[] ndcg = { 0.2, 0.3, 0.4 }; double[] dcg = { 0.2, 0.3, 0.4 }; var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg); - Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg)); - Assert.False(IsPerfectModel(metrics, RankingMetric.Ndcg)); + Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg, "GroupId")); + Assert.False(IsPerfectModel(metrics, RankingMetric.Ndcg, "GroupId")); } [Fact] @@ -153,8 +153,8 @@ public void RankingMetricsPerfectTest() double[] ndcg = { 0.2, 0.3, 1 }; double[] dcg = { 0.2, 0.3, 1 }; var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg); - Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg)); //REVIEW: No true Perfect model - Assert.True(IsPerfectModel(metrics, RankingMetric.Ndcg)); + Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg, "GroupId")); //REVIEW: No true Perfect model + Assert.True(IsPerfectModel(metrics, RankingMetric.Ndcg, "GroupId")); } [Fact] @@ -179,9 +179,9 @@ private static double GetScore(RegressionMetrics metrics, RegressionMetric metri return new RegressionMetricsAgent(null, metric).GetScore(metrics); } - private static double GetScore(RankingMetrics metrics, RankingMetric metric) + private static double GetScore(RankingMetrics metrics, RankingMetric metric, string groupIdColumnName) { - return new RankingMetricsAgent(null, metric).GetScore(metrics); + return new RankingMetricsAgent(null, metric, groupIdColumnName).GetScore(metrics); } private static bool IsPerfectModel(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric) @@ -202,9 +202,9 @@ private static bool IsPerfectModel(RegressionMetrics metrics, RegressionMetric m return IsPerfectModel(metricsAgent, metrics); } - private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric) + private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric, string groupIdColumnName) { - var metricsAgent = new RankingMetricsAgent(null, metric); + var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName); return IsPerfectModel(metricsAgent, metrics); }