diff --git a/src/Microsoft.ML.AutoML/API/RankingExperiment.cs b/src/Microsoft.ML.AutoML/API/RankingExperiment.cs
index ea0ba735c4..7233dae9b4 100644
--- a/src/Microsoft.ML.AutoML/API/RankingExperiment.cs
+++ b/src/Microsoft.ML.AutoML/API/RankingExperiment.cs
@@ -19,6 +19,12 @@ public sealed class RankingExperimentSettings : ExperimentSettings
/// <value>The default value is <see cref="RankingMetric.Ndcg"/>.</value>
public RankingMetric OptimizingMetric { get; set; }
+ /// <summary>
+ /// Name for the GroupId column.
+ /// </summary>
+ /// <value>The default value is "GroupId".</value>
+ public string GroupIdColumnName { get; set; }
+
/// <summary>
/// Collection of trainers the AutoML experiment can leverage.
/// </summary>
@@ -28,6 +34,7 @@ public sealed class RankingExperimentSettings : ExperimentSettings
public ICollection<RankingTrainer> Trainers { get; }
public RankingExperimentSettings()
{
+ GroupIdColumnName = "GroupId";
OptimizingMetric = RankingMetric.Ndcg;
Trainers = Enum.GetValues(typeof(RankingTrainer)).OfType<RankingTrainer>().ToList();
}
@@ -68,10 +75,11 @@ public static class RankingExperimentResultExtensions
/// </summary>
/// <param name="results">Enumeration of AutoML experiment run results.</param>
/// <param name="metric">Metric to consider when selecting the best run.</param>
+ /// <param name="groupIdColumnName">Name for the GroupId column.</param>
/// <returns>The best experiment run.</returns>
- public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
+ public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId")
{
- var metricsAgent = new RankingMetricsAgent(null, metric);
+ var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
@@ -81,10 +89,11 @@ public static RunDetail Best(this IEnumerable
/// <param name="results">Enumeration of AutoML experiment cross validation run results.</param>
/// <param name="metric">Metric to consider when selecting the best run.</param>
+ /// <param name="groupIdColumnName">Name for the GroupId column.</param>
/// <returns>The best experiment run.</returns>
- public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
+ public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId")
{
- var metricsAgent = new RankingMetricsAgent(null, metric);
+ var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
}
@@ -103,7 +112,7 @@ public sealed class RankingExperiment : ExperimentBase
{
private readonly MLContext _mlContext;
private readonly RankingMetric _optimizingMetric;
+ private readonly string _groupIdColumnName;
- public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric)
+ public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric, string groupIdColumnName)
{
_mlContext = mlContext;
_optimizingMetric = optimizingMetric;
+ _groupIdColumnName = groupIdColumnName;
}
// Optimizing metric used: NDCG@10 and DCG@10
@@ -59,7 +61,7 @@ public bool IsModelPerfect(double score)
public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn)
{
- return _mlContext.Ranking.Evaluate(data, labelColumn);
+ return _mlContext.Ranking.Evaluate(data, labelColumn, _groupIdColumnName);
}
}
}
diff --git a/src/Microsoft.ML.AutoML/TrainerExtensions/TrainerExtensionUtil.cs b/src/Microsoft.ML.AutoML/TrainerExtensions/TrainerExtensionUtil.cs
index eab10e5def..7d8d55e49a 100644
--- a/src/Microsoft.ML.AutoML/TrainerExtensions/TrainerExtensionUtil.cs
+++ b/src/Microsoft.ML.AutoML/TrainerExtensions/TrainerExtensionUtil.cs
@@ -158,7 +158,7 @@ private static IDictionary BuildBasePipelineNodeProps(IEnumerabl
}
private static IDictionary<string, object> BuildLightGbmPipelineNodeProps(IEnumerable<SweepableParam> sweepParams,
- string labelColumn, string weightColumn, string groupColumn = null)
+ string labelColumn, string weightColumn, string groupColumn)
{
Dictionary props = null;
if (sweepParams == null || !sweepParams.Any())
diff --git a/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs b/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs
index a0bff55c9e..601fa212be 100644
--- a/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs
+++ b/src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs
@@ -35,9 +35,9 @@ public static RunDetail GetBestRun(IEnumerable<
}
public static RunDetail<RankingMetrics> GetBestRun(IEnumerable<RunDetail<RankingMetrics>> results,
- RankingMetric metric)
+ RankingMetric metric, string groupIdColumnName)
{
- var metricsAgent = new RankingMetricsAgent(null, metric);
+ var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
var metricInfo = new OptimizingMetricInfo(metric);
return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
}
diff --git a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
index b43acf7af1..97b3e07aa4 100644
--- a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
+++ b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
@@ -124,7 +124,7 @@ public void AutoFitRankingTest()
{
string labelColumnName = "Label";
string scoreColumnName = "Score";
- string groupIdColumnName = "GroupId";
+ string groupIdColumnName = "CustomGroupId";
string featuresColumnVectorNameA = "FeatureVectorA";
string featuresColumnVectorNameB = "FeatureVectorB";
var mlContext = new MLContext(1);
@@ -136,7 +136,7 @@ public void AutoFitRankingTest()
trainDataView = mlContext.Data.SkipRows(trainDataView, 500);
// STEP 2: Run AutoML experiment
ExperimentResult<RankingMetrics> experimentResult = mlContext.Auto()
- .CreateRankingExperiment(5)
+ .CreateRankingExperiment(new RankingExperimentSettings() { GroupIdColumnName = "CustomGroupId", MaxExperimentTimeInSeconds = 5})
.Execute(trainDataView, testDataView,
new ColumnInformation()
{
diff --git a/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs b/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs
index 0f9b336d84..2f3745d937 100644
--- a/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs
+++ b/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs
@@ -127,14 +127,14 @@ public void RankingMetricsGetScoreTest()
double[] ndcg = { 0.2, 0.3, 0.4 };
double[] dcg = { 0.2, 0.3, 0.4 };
var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
- Assert.Equal(0.4, GetScore(metrics, RankingMetric.Dcg));
- Assert.Equal(0.4, GetScore(metrics, RankingMetric.Ndcg));
+ Assert.Equal(0.4, GetScore(metrics, RankingMetric.Dcg, "GroupId"));
+ Assert.Equal(0.4, GetScore(metrics, RankingMetric.Ndcg, "GroupId"));
double[] largeNdcg = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95 };
double[] largeDcg = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95 };
metrics = MetricsUtil.CreateRankingMetrics(largeDcg, largeNdcg);
- Assert.Equal(0.9, GetScore(metrics, RankingMetric.Dcg));
- Assert.Equal(0.9, GetScore(metrics, RankingMetric.Ndcg));
+ Assert.Equal(0.9, GetScore(metrics, RankingMetric.Dcg, "GroupId"));
+ Assert.Equal(0.9, GetScore(metrics, RankingMetric.Ndcg, "GroupId"));
}
[Fact]
@@ -143,8 +143,8 @@ public void RankingMetricsNonPerfectTest()
double[] ndcg = { 0.2, 0.3, 0.4 };
double[] dcg = { 0.2, 0.3, 0.4 };
var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
- Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg));
- Assert.False(IsPerfectModel(metrics, RankingMetric.Ndcg));
+ Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg, "GroupId"));
+ Assert.False(IsPerfectModel(metrics, RankingMetric.Ndcg, "GroupId"));
}
[Fact]
@@ -153,8 +153,8 @@ public void RankingMetricsPerfectTest()
double[] ndcg = { 0.2, 0.3, 1 };
double[] dcg = { 0.2, 0.3, 1 };
var metrics = MetricsUtil.CreateRankingMetrics(dcg, ndcg);
- Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg)); //REVIEW: No true Perfect model
- Assert.True(IsPerfectModel(metrics, RankingMetric.Ndcg));
+ Assert.False(IsPerfectModel(metrics, RankingMetric.Dcg, "GroupId")); //REVIEW: No true Perfect model
+ Assert.True(IsPerfectModel(metrics, RankingMetric.Ndcg, "GroupId"));
}
[Fact]
@@ -179,9 +179,9 @@ private static double GetScore(RegressionMetrics metrics, RegressionMetric metri
return new RegressionMetricsAgent(null, metric).GetScore(metrics);
}
- private static double GetScore(RankingMetrics metrics, RankingMetric metric)
+ private static double GetScore(RankingMetrics metrics, RankingMetric metric, string groupIdColumnName)
{
- return new RankingMetricsAgent(null, metric).GetScore(metrics);
+ return new RankingMetricsAgent(null, metric, groupIdColumnName).GetScore(metrics);
}
private static bool IsPerfectModel(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric)
@@ -202,9 +202,9 @@ private static bool IsPerfectModel(RegressionMetrics metrics, RegressionMetric m
return IsPerfectModel(metricsAgent, metrics);
}
- private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric)
+ private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric, string groupIdColumnName)
{
- var metricsAgent = new RankingMetricsAgent(null, metric);
+ var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
return IsPerfectModel(metricsAgent, metrics);
}