From d2fb318088a3415f70db5b8655abd071060a3864 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 10 Jun 2020 16:41:59 -0700 Subject: [PATCH 01/13] Fixed issue --- .../TrainTestSplit.cs | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs index 5d953668e9..a08e69cc6a 100644 --- a/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs +++ b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs @@ -67,19 +67,21 @@ public static Output Split(IHostEnvironment env, Input input) internal static class SplitUtils { + // Creates a new Stratification column to be used for splitting. + // Notice that the new column might be dropped elsewhere in the code + // Returns: the name of the new column. public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null) { host.CheckValue(data, nameof(data)); host.CheckValueOrNull(stratificationColumn); - // Pick a unique name for the stratificationColumn. + // Pick a unique name for the new stratificationColumn. const string stratColName = "StratificationKey"; string stratCol = data.Schema.GetTempColumnName(stratColName); - // Construct the stratification column. If user-provided stratification column exists, use HashJoin - // of it to construct the strat column, otherwise generate a random number and use it. 
if (stratificationColumn == null) { + // If the stratificationColumn wasn't provided by the user, simply create a new Random Number Generator data = new GenerateNumberTransform(host, new GenerateNumberTransform.Options { @@ -106,11 +108,15 @@ public static string CreateStratificationColumn(IHost host, ref IDataView data, else { if (data.Schema[stratificationColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double)) - return stratificationColumn; - - data = new NormalizingEstimator(host, - new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true)) - .Fit(data).Transform(data); + { + data = new ColumnCopyingEstimator(host,(stratCol,stratificationColumn)).Fit(data).Transform(data); + } + else + { + data = new NormalizingEstimator(host, + new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true)) + .Fit(data).Transform(data); + } } } From e9454a08ba26472b6a1a9afdf469a9d60a7320a5 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Tue, 16 Jun 2020 00:11:54 -0700 Subject: [PATCH 02/13] Added test --- .../UnitTests/TestEntryPoints.cs | 226 ++++++++++++++++++ 1 file changed, 226 insertions(+) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 22da23dd4e..a2a04f0d97 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -6613,5 +6613,231 @@ public void SummarizeEntryPointTest() Done(); } + + [Fact] + void RankingWithColumnIdEntryPoint() + { + Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryModelParameters).Assembly); + var dataPath = GetDataPath(@"adult.tiny.with-schema.txt"); + + string inputGraph = $@" + {{ + 'Nodes': [ + {{ + 'Inputs': {{ + 'CustomSchema': 'col=Label:R4:0 col=GroupId:TX:1 col=Features_1:R4:9-14 header+ sep=tab', + 'InputFile': '$file' + }}, + 
'Name': 'Data.CustomTextLoader', + 'Outputs': {{ + 'Data': '$input_data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'GroupId', + 'Source': 'GroupId' + }} + ], + 'Data': '$input_data', + 'MaxNumTerms': 1000000, + 'Sort': 'ByOccurrence', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.TextToKeyConverter', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'Features', + 'Source': [ + 'Features_1' + ] + }} + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.ColumnConcatenator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Models': [ + '$output_model1', + '$output_model2' + ] + }}, + 'Name': 'Transforms.ModelCombiner', + 'Outputs': {{ + 'OutputModel': '$output_model_combined_pre_split' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'GroupColumn': 'GroupId', + 'Inputs': {{ + 'Data': '$cv_subgraph_input_data' + }}, + 'Kind': 'SignatureRankerTrainer', + 'LabelColumn': 'Label', + 'NameColumn': 'Name', + 'Nodes': [ + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$cv_subgraph_input_data' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$optional_data' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$optional_data', + 'LabelColumn': 'Label', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.LabelColumnKeyBooleanConverter', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$label_data' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$label_data', + 'Features': [ + 'Features' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model5', + 'OutputData': '$output_data' + }} + }}, + {{ + 'Inputs': {{ + 'BatchSize': 1048576, + 'Caching': 'Auto', + 'CategoricalSmoothing': 10.0, + 'CustomGains': [ + 0, + 3, + 7, + 15, + 31, + 63, + 127, + 255, + 511, + 1023, + 2047, + 4095 + ], + 
'EarlyStoppingRound': 0, + 'EvaluationMetric': 'NormalizedDiscountedCumulativeGain', + 'FeatureColumnName': 'Features', + 'HandleMissingValue': true, + 'L2CategoricalRegularization': 10.0, + 'LabelColumnName': 'Label', + 'MaximumBinCountPerFeature': 255, + 'MaximumCategoricalSplitPointCount': 32, + 'MinimumExampleCountPerGroup': 100, + 'MinimumExampleCountPerLeaf': 1, + 'NormalizeFeatures': 'Auto', + 'NumberOfIterations': 100, + 'RowGroupColumnName': 'GroupId', + 'Sigmoid': 0.5, + 'Silent': true, + 'TrainingData': '$output_data', + 'UseZeroAsMissingValue': false, + 'Verbose': false + }}, + 'Name': 'Trainers.LightGbmRanker', + 'Outputs': {{ + 'PredictorModel': '$output_model_learner' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$output_model_learner', + 'TransformModels': [ + '$output_model3', + '$output_model4', + '$output_model5' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }} + ], + 'NumFolds': 2, + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }}, + 'StratificationColumn': 'GroupId', + 'TransformModel': '$output_model_combined_pre_split' + }}, + 'Name': 'Models.CrossValidator', + 'Outputs': {{ + 'OverallMetrics': '$overall_metrics', + 'PerInstanceMetrics': '$per_instance_metrics', + 'PredictorModel': '$predictor_model', + 'Warnings': '$warnings' + }} + }} + ], + 'Outputs': {{ + 'overall_metrics': '$outmetrics', + 'per_instance_metrics': '', + 'predictor_model': '$outModel', + 'warnings': '$outwarnings' + }} + }} + + "; + + JObject graph = JObject.Parse(inputGraph); + var runner = new GraphRunner(Env, graph[FieldNames.Nodes] as JArray); + var inputFile = new SimpleFileHandle(Env, dataPath, false, false); + runner.SetInput("file", inputFile); + runner.RunAll(); + + var data = runner.GetOutput("overall_metrics"); + using(var cursor = data.GetRowCursorForAllColumns()) + { + var ndcgGetter = cursor.GetGetter>(data.Schema["NDCG"]); + VBuffer ndcgBuffer = default; + + 
cursor.MoveNext(); + ndcgGetter(ref ndcgBuffer); + var ndcgArray = ndcgBuffer.DenseValues().ToArray(); + + // Since we used a toy dataset, we won't worry much about comparing actual + // Double values of the result. Simply check that we get results. + Assert.Equal(3, ndcgArray.Length); + Assert.True(ndcgArray[0] > 0); + Assert.True(ndcgArray[1] > 0); + Assert.True(ndcgArray[2] > 0); + } + } } } From 8a747e5aec5988e885426317d50e340212c09f6e Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Tue, 16 Jun 2020 09:56:24 -0700 Subject: [PATCH 03/13] Added LightGBM Fact Attribute to test --- test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index a2a04f0d97..132b62d937 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -6614,7 +6614,7 @@ public void SummarizeEntryPointTest() Done(); } - [Fact] + [LightGBMFact] void RankingWithColumnIdEntryPoint() { Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryModelParameters).Assembly); From 1f945844134e8398811dc3e4ee1b74e2b6be9e99 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 17 Jun 2020 18:54:12 -0700 Subject: [PATCH 04/13] Added test for Key type SamplingKeyColumn on TrainTestSplit --- test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index 0c8f8a7be0..e33a22aad0 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -426,6 +426,13 @@ public void TestTrainTestSplitWithStratification() ids = split.TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); Assert.Contains(1, ids); Assert.Contains(2, ids); + + var 
inputWithKey = mlContext.Transforms.Conversion.MapValueToKey("KeyStrat", "TextStrat").Fit(input).Transform(input); + split = mlContext.Data.TrainTestSplit(inputWithKey, 0.5, "KeyStrat"); + ids = split.TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); + Assert.Contains(1, ids); + Assert.Contains(5, ids); + Assert.NotNull(split.TrainSet.Schema.GetColumnOrNull("KeyStrat")); // Check that the original column wasn't deleted by the split } } } From 033ae83f2764e0500dfa955665ea1efb3e4d8bb2 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 17 Jun 2020 19:19:39 -0700 Subject: [PATCH 05/13] Added test for CV and that SamplingKeyColumn isn't removed by running a split --- .../Scenarios/Api/TestApi.cs | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index e33a22aad0..810447c183 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -363,7 +363,7 @@ private sealed class Input } [Fact] - public void TestTrainTestSplitWithStratification() + public void TestSplitsWithSamplingKeyColumn() { var mlContext = new MLContext(0); var input = mlContext.Data.LoadFromEnumerable(new[] @@ -402,6 +402,7 @@ public void TestTrainTestSplitWithStratification() }, }); + // TEST TRAINTESTSPLIT var split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.TextStrat)); var ids = split.TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); Assert.Contains(1, ids); @@ -432,7 +433,26 @@ public void TestTrainTestSplitWithStratification() ids = split.TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); Assert.Contains(1, ids); Assert.Contains(5, ids); - Assert.NotNull(split.TrainSet.Schema.GetColumnOrNull("KeyStrat")); // Check that the original column wasn't deleted by the split + Assert.NotNull(split.TrainSet.Schema.GetColumnOrNull("KeyStrat")); // Check that the key 
column used as SamplingKeyColumn wasn't deleted by the split + + // TEST CROSSVALIDATIONSPLIT + var colnames = new[] { + nameof(Input.TextStrat), + nameof(Input.FloatStrat), + nameof(Input.VectorStrat), + nameof(Input.DateTimeStrat), + nameof(Input.DateTimeOffsetStrat), + nameof(Input.TimeSpanStrat), + "KeyStrat" }; + + foreach(var colname in colnames) + { + var cvSplits = mlContext.Data.CrossValidationSplit(inputWithKey, numberOfFolds: 2, samplingKeyColumnName: colname); + var idsTest1 = cvSplits[0].TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); + var idsTest2 = cvSplits[1].TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); + Assert.True(Enumerable.Intersect(idsTest1, idsTest2).Count() == 0); + Assert.NotNull(split.TrainSet.Schema.GetColumnOrNull(colname)); // Check that using CV didn't remove the SamplingKeyColumn + } } } } From 16383414603e5e286d5404e3799f8ceb94e550b1 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 17 Jun 2020 19:59:41 -0700 Subject: [PATCH 06/13] Added assertion for non-emptiness and fixed mini-issue with the MinMaxNormalizer for this assertion --- .../DataLoadSave/DataOperationsCatalog.cs | 24 ++++++++++++------- .../Scenarios/Api/TestApi.cs | 12 ++++++---- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs index 0bd4075670..d203f95be8 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs @@ -491,13 +491,13 @@ internal static IEnumerable CrossValidationSplit(IHostEnvironment } /// - /// Ensures the provided is valid for , hashing it if necessary, or creates a new column is null. + /// Ensures the provided is valid for , hashing, copying, or normalizing it if necessary, + /// or creates a new column if is null. 
/// internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null) { Contracts.CheckValue(env, nameof(env)); - // We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to - // build a single hash of it. If it is not, we generate a random number. + // We need to handle two cases: if samplingKeyColumn is not provided, we generate a random number. if (samplingKeyColumn == null) { samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn"); @@ -505,14 +505,18 @@ internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDa } else { + // If samplingKeyColumn was provided we will make a new column based on it, but using a temporary + // name, as it might be dropped elsewhere in the code + if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol)) throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn); + var origStratCol = samplingKeyColumn; + samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn); + var type = data.Schema[stratCol].Type; if (!RangeFilter.IsValidRangeFilterColumnType(env, type)) { - var origStratCol = samplingKeyColumn; - samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn); // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan. 
var itemType = type.GetItemType(); if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType) @@ -527,11 +531,13 @@ internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDa } else { - if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double)) + if (data.Schema[origStratCol].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double)) + { + data = new ColumnCopyingEstimator(env, (samplingKeyColumn, origStratCol)).Fit(data).Transform(data); + } + else { - var origStratCol = samplingKeyColumn; - samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn); - data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data); + data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: false)).Fit(data).Transform(data); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index 810447c183..1f80233a4e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -408,7 +408,7 @@ public void TestSplitsWithSamplingKeyColumn() Assert.Contains(1, ids); Assert.Contains(5, ids); split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.FloatStrat)); - ids = split.TrainSet.GetColumn(split.TrainSet.Schema[nameof(Input.Id)]); + ids = split.TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); Assert.Contains(4, ids); Assert.Contains(5, ids); split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.VectorStrat)); @@ -448,10 +448,14 @@ public void TestSplitsWithSamplingKeyColumn() foreach(var colname in colnames) { var cvSplits = mlContext.Data.CrossValidationSplit(inputWithKey, 
numberOfFolds: 2, samplingKeyColumnName: colname); - var idsTest1 = cvSplits[0].TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); - var idsTest2 = cvSplits[1].TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); + var idsTest1 = cvSplits[0].TestSet.GetColumn(cvSplits[0].TestSet.Schema[nameof(Input.Id)]); + var idsTest2 = cvSplits[1].TestSet.GetColumn(cvSplits[1].TestSet.Schema[nameof(Input.Id)]); Assert.True(Enumerable.Intersect(idsTest1, idsTest2).Count() == 0); - Assert.NotNull(split.TrainSet.Schema.GetColumnOrNull(colname)); // Check that using CV didn't remove the SamplingKeyColumn + Assert.True(idsTest1.Count() > 0, $"CV Split 0 for Column {colname} was empty"); + Assert.True(idsTest2.Count() > 0, $"CV Split 1 for Column {colname} was empty"); + + // Check that using CV didn't remove the SamplingKeyColumn + Assert.NotNull(split.TrainSet.Schema.GetColumnOrNull(colname)); } } } From 5e9a8e4ab0a73eb8f0b587a8de8540534d560c02 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 17 Jun 2020 22:13:11 -0700 Subject: [PATCH 07/13] Removed CreateStratificationColumn and use EnsureGroupPreservationColumn instead --- .../DataLoadSave/DataOperationsCatalog.cs | 59 ++++++++++++------ src/Microsoft.ML.Data/TrainCatalog.cs | 4 +- src/Microsoft.ML.EntryPoints/CVSplit.cs | 3 +- .../TrainTestSplit.cs | 61 +------------------ .../UnitTests/TestEntryPoints.cs | 2 +- 5 files changed, 46 insertions(+), 83 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs index d203f95be8..9a7d3fc182 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs @@ -413,18 +413,18 @@ public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, s _env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); 
_env.CheckValueOrNull(samplingKeyColumnName); - EnsureGroupPreservationColumn(_env, ref data, ref samplingKeyColumnName, seed); + var newSamplingKeyColumn = EnsureGroupPreservationColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); var trainFilter = new RangeFilter(_env, new RangeFilter.Options() { - Column = samplingKeyColumnName, + Column = newSamplingKeyColumn, Min = 0, Max = testFraction, Complement = true }, data); var testFilter = new RangeFilter(_env, new RangeFilter.Options() { - Column = samplingKeyColumnName, + Column = newSamplingKeyColumn, Min = 0, Max = testFraction, Complement = false @@ -455,9 +455,9 @@ public IReadOnlyList CrossValidationSplit(IDataView data, int num _env.CheckValue(data, nameof(data)); _env.CheckParam(numberOfFolds > 1, nameof(numberOfFolds), "Must be more than 1"); _env.CheckValueOrNull(samplingKeyColumnName); - EnsureGroupPreservationColumn(_env, ref data, ref samplingKeyColumnName, seed); + var newSamplingKeyColumn = EnsureGroupPreservationColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); var result = new List(); - foreach (var split in CrossValidationSplit(_env, data, numberOfFolds, samplingKeyColumnName)) + foreach (var split in CrossValidationSplit(_env, data, numberOfFolds, newSamplingKeyColumn)) result.Add(split); return result; } @@ -494,14 +494,33 @@ internal static IEnumerable CrossValidationSplit(IHostEnvironment /// Ensures the provided is valid for , hashing, copying, or normalizing it if necessary, /// or creates a new column if is null. /// - internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null) + [BestFriend] + internal static string EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? 
seed = null, bool fallbackInEnvSeed = false) { Contracts.CheckValue(env, nameof(env)); + Contracts.CheckValueOrNull(samplingKeyColumn); + + var newSamplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn"); + int? seedToUse; + + if(seed.HasValue) + { + seedToUse = seed.Value; + } + else if(fallbackInEnvSeed) + { + ISeededEnvironment seededEnv = (ISeededEnvironment)env; + seedToUse = seededEnv.Seed; + } + else + { + seedToUse = null; + } + // We need to handle two cases: if samplingKeyColumn is not provided, we generate a random number. if (samplingKeyColumn == null) { - samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn"); - data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed)); + data = new GenerateNumberTransform(env, data, newSamplingKeyColumn, (uint?)seedToUse); } else { @@ -511,36 +530,38 @@ internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDa if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol)) throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn); - var origStratCol = samplingKeyColumn; - samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn); - var type = data.Schema[stratCol].Type; if (!RangeFilter.IsValidRangeFilterColumnType(env, type)) { + var hashInputColumnName = samplingKeyColumn; // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan. var itemType = type.GetItemType(); if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType) - data = new TypeConvertingTransformer(env, origStratCol, DataKind.Int64, origStratCol).Transform(data); + { + data = new TypeConvertingTransformer(env, newSamplingKeyColumn, DataKind.Int64, samplingKeyColumn).Transform(data); + hashInputColumnName = newSamplingKeyColumn; + } - var localSeed = seed.HasValue ? 
seed : ((ISeededEnvironment)env).Seed.HasValue ? ((ISeededEnvironment)env).Seed : null; var columnOptions = - localSeed.HasValue ? - new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, (uint)localSeed.Value, combine: true) : - new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, combine: true); + seedToUse.HasValue ? + new HashingEstimator.ColumnOptions(newSamplingKeyColumn, hashInputColumnName, 30, (uint)seedToUse.Value, combine: true) : + new HashingEstimator.ColumnOptions(newSamplingKeyColumn, hashInputColumnName, 30, combine: true); data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data); } else { - if (data.Schema[origStratCol].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double)) + if (data.Schema[samplingKeyColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double)) { - data = new ColumnCopyingEstimator(env, (samplingKeyColumn, origStratCol)).Fit(data).Transform(data); + data = new ColumnCopyingEstimator(env, (newSamplingKeyColumn, samplingKeyColumn)).Fit(data).Transform(data); } else { - data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: false)).Fit(data).Transform(data); + data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(newSamplingKeyColumn, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data); } } } + + return newSamplingKeyColumn; } } } diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index dd9ad3ac11..4694e6a93c 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -93,13 +93,13 @@ private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEs Environment.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); Environment.CheckValueOrNull(samplingKeyColumn); - 
DataOperationsCatalog.EnsureGroupPreservationColumn(Environment, ref data, ref samplingKeyColumn, seed); + var newSamplingKeyColumn = DataOperationsCatalog.EnsureGroupPreservationColumn(Environment, ref data, samplingKeyColumn, seed); var result = new CrossValidationResult[numFolds]; int fold = 0; // Sequential per-fold training. // REVIEW: we could have a parallel implementation here. We would need to // spawn off a separate host per fold in that case. - foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, numFolds, samplingKeyColumn)) + foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, numFolds, newSamplingKeyColumn)) { var model = estimator.Fit(split.TrainSet); var scoredTest = model.Transform(split.TestSet); diff --git a/src/Microsoft.ML.EntryPoints/CVSplit.cs b/src/Microsoft.ML.EntryPoints/CVSplit.cs index 05c259625b..2d4bab0c1d 100644 --- a/src/Microsoft.ML.EntryPoints/CVSplit.cs +++ b/src/Microsoft.ML.EntryPoints/CVSplit.cs @@ -4,6 +4,7 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -53,7 +54,7 @@ public static Output Split(IHostEnvironment env, Input input) var data = input.Data; - var stratCol = SplitUtils.CreateStratificationColumn(host, ref data, input.StratificationColumn); + var stratCol = DataOperationsCatalog.EnsureGroupPreservationColumn(env, ref data, input.StratificationColumn); int n = input.NumFolds; var output = new Output diff --git a/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs index a08e69cc6a..dc8ba258c3 100644 --- a/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs +++ b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs @@ -50,7 +50,7 @@ public static Output Split(IHostEnvironment env, Input input) EntryPointUtils.CheckInputArgs(host, input); var data = input.Data; - var stratCol = 
SplitUtils.CreateStratificationColumn(host, ref data, input.StratificationColumn); + var stratCol = DataOperationsCatalog.EnsureGroupPreservationColumn(env, ref data, input.StratificationColumn); IDataView trainData = new RangeFilter(host, new RangeFilter.Options { Column = stratCol, Min = 0, Max = input.Fraction, Complement = false }, data); @@ -64,63 +64,4 @@ public static Output Split(IHostEnvironment env, Input input) } } - - internal static class SplitUtils - { - // Creates a new Stratification column to be used for splitting. - // Notice that the new column might be dropped elsewhere in the code - // Returns: the name of the new column. - public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null) - { - host.CheckValue(data, nameof(data)); - host.CheckValueOrNull(stratificationColumn); - - // Pick a unique name for the new stratificationColumn. - const string stratColName = "StratificationKey"; - string stratCol = data.Schema.GetTempColumnName(stratColName); - - if (stratificationColumn == null) - { - // If the stratificationColumn wasn't provided by the user, simply create a new Random Number Generator - data = new GenerateNumberTransform(host, - new GenerateNumberTransform.Options - { - Columns = new[] { new GenerateNumberTransform.Column { Name = stratCol } } - }, data); - } - else - { - var col = data.Schema.GetColumnOrNull(stratificationColumn); - if (col == null) - throw host.ExceptSchemaMismatch(nameof(stratificationColumn), "Stratification", stratificationColumn); - - var type = col.Value.Type; - if (!RangeFilter.IsValidRangeFilterColumnType(host, type)) - { - // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan. 
- var itemType = type.GetItemType(); - if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType) - data = new TypeConvertingTransformer(host, stratificationColumn, DataKind.Int64, stratificationColumn).Transform(data); - - var columnOptions = new HashingEstimator.ColumnOptions(stratCol, stratificationColumn, 30, combine: true); - data = new HashingEstimator(host, columnOptions).Fit(data).Transform(data); - } - else - { - if (data.Schema[stratificationColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double)) - { - data = new ColumnCopyingEstimator(host,(stratCol,stratificationColumn)).Fit(data).Transform(data); - } - else - { - data = new NormalizingEstimator(host, - new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true)) - .Fit(data).Transform(data); - } - } - } - - return stratCol; - } - } } \ No newline at end of file diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 132b62d937..b0307a9df3 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -5734,7 +5734,7 @@ public void TestCrossValidationMacroWithStratification() getter(ref stdev); foldGetter(ref fold); Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold)); - Assert.Equal(0.0087, stdev, 5); + Assert.Equal(0.02582, stdev, 5); double sum = 0; double val = 0; From cc9f6ede74d3123f31b46926fcf1191125ab8149 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 17 Jun 2020 22:35:23 -0700 Subject: [PATCH 08/13] Change MAML to also use EnsureGroupPreservationColumn method --- .../Commands/CrossValidationCommand.cs | 37 +------------------ ...dationWithTextStratificationColumn-out.txt | 1 - 2 files changed, 2 insertions(+), 36 deletions(-) diff --git 
a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 7dadaf2b36..020d295e75 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -305,41 +305,8 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output } } - if (string.IsNullOrEmpty(stratificationColumn)) - { - stratificationColumn = "StratificationColumn"; - int tmp; - int inc = 0; - while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp)) - stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc); - var keyGenArgs = new GenerateNumberTransform.Options(); - var col = new GenerateNumberTransform.Column(); - col.Name = stratificationColumn; - keyGenArgs.Columns = new[] { col }; - output = new GenerateNumberTransform(Host, keyGenArgs, input); - } - else - { - int col; - if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col)) - throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn); - var type = input.Schema[col].Type; - if (!RangeFilter.IsValidRangeFilterColumnType(ch, type)) - { - ch.Info("Hashing the stratification column"); - var origStratCol = stratificationColumn; - stratificationColumn = input.Schema.GetTempColumnName("strat"); - - // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan. 
- var itemType = type.GetItemType(); - if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType) - input = new TypeConvertingTransformer(Host, origStratCol, DataKind.Int64, origStratCol).Transform(input); - - output = new HashingEstimator(Host, stratificationColumn, origStratCol, 30).Fit(input).Transform(input); - } - } - - return stratificationColumn; + var newStratificationColumn = DataOperationsCatalog.EnsureGroupPreservationColumn(Host, ref output, stratificationColumn); + return newStratificationColumn; } private bool TryGetOverallMetrics(Dictionary[] metrics, out List overallList) diff --git a/test/BaselineOutput/Common/Command/CommandCrossValidationWithTextStratificationColumn-out.txt b/test/BaselineOutput/Common/Command/CommandCrossValidationWithTextStratificationColumn-out.txt index 7412d1e4a2..453f27b324 100644 --- a/test/BaselineOutput/Common/Command/CommandCrossValidationWithTextStratificationColumn-out.txt +++ b/test/BaselineOutput/Common/Command/CommandCrossValidationWithTextStratificationColumn-out.txt @@ -1,5 +1,4 @@ maml.exe CV tr=lr{l1=1.0 l2=0.1 ot=1e-3 nt=1} strat=Strat threads=- norm=Warn loader=text{col=Features:R4:9-14 col=Label:R4:0 col=Strat:TX:1 header+} data=%Data% out=%Output% -Hashing the stratification column Warning: A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options. 
Beginning optimization num vars: 7 From 37332c1eb976f7eab5125ab905e0cb82b3c9c836 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 17 Jun 2020 23:11:09 -0700 Subject: [PATCH 09/13] Change name EnsureGroupPreservationColumn to CreateGroupPreservationColumn --- .../Commands/CrossValidationCommand.cs | 2 +- .../DataLoadSave/DataOperationsCatalog.cs | 20 ++++++++++++++----- src/Microsoft.ML.Data/TrainCatalog.cs | 2 +- src/Microsoft.ML.EntryPoints/CVSplit.cs | 2 +- .../TrainTestSplit.cs | 2 +- 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 020d295e75..3be53da3f3 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -305,7 +305,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output } } - var newStratificationColumn = DataOperationsCatalog.EnsureGroupPreservationColumn(Host, ref output, stratificationColumn); + var newStratificationColumn = DataOperationsCatalog.CreateGroupPreservationColumn(Host, ref output, stratificationColumn); return newStratificationColumn; } diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs index 9a7d3fc182..8427502fba 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs @@ -413,7 +413,7 @@ public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, s _env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); _env.CheckValueOrNull(samplingKeyColumnName); - var newSamplingKeyColumn = EnsureGroupPreservationColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); + var newSamplingKeyColumn = CreateGroupPreservationColumn(_env, ref 
data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); var trainFilter = new RangeFilter(_env, new RangeFilter.Options() { @@ -455,7 +455,7 @@ public IReadOnlyList CrossValidationSplit(IDataView data, int num _env.CheckValue(data, nameof(data)); _env.CheckParam(numberOfFolds > 1, nameof(numberOfFolds), "Must be more than 1"); _env.CheckValueOrNull(samplingKeyColumnName); - var newSamplingKeyColumn = EnsureGroupPreservationColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); + var newSamplingKeyColumn = CreateGroupPreservationColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); var result = new List(); foreach (var split in CrossValidationSplit(_env, data, numberOfFolds, newSamplingKeyColumn)) result.Add(split); @@ -491,11 +491,21 @@ internal static IEnumerable CrossValidationSplit(IHostEnvironment } /// - /// Ensures the provided is valid for , hashing, copying, or normalizing it if necessary, - /// or creates a new column if is null. + /// Based on the input samplingKeyColumn creates a new column that will be used by the callers to apply a RangeFilter that will produce train-test split + /// or cross-validation splits. + /// + /// Notice that the new column might get dropped by the callers of this method after using it. /// + /// IHostEnvironment of the caller + /// DataView that should contain the "samplingKeyColumn". The new column will be added to this DataView. + /// Name of the column that will be used as base of the new GroupPreservationColumn. + /// Notice that in other places in the code this column, and/or the column that this method creates, + /// are refered to as "SamplingKeyColumn", "GroupPreservationColumn" or "StratificationColumn". + /// The seed used by the transformers that will create the new column + /// If seed = null, then should we use the env seed? If seed = null, and this parameter is false, then we won't use a seed. 
+ /// The name of the new column [BestFriend] - internal static string EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false) + internal static string CreateGroupPreservationColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false) { Contracts.CheckValue(env, nameof(env)); Contracts.CheckValueOrNull(samplingKeyColumn); diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index 4694e6a93c..fe5b6ba7b7 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -93,7 +93,7 @@ private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEs Environment.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); Environment.CheckValueOrNull(samplingKeyColumn); - var newSamplingKeyColumn = DataOperationsCatalog.EnsureGroupPreservationColumn(Environment, ref data, samplingKeyColumn, seed); + var newSamplingKeyColumn = DataOperationsCatalog.CreateGroupPreservationColumn(Environment, ref data, samplingKeyColumn, seed); var result = new CrossValidationResult[numFolds]; int fold = 0; // Sequential per-fold training. 
diff --git a/src/Microsoft.ML.EntryPoints/CVSplit.cs b/src/Microsoft.ML.EntryPoints/CVSplit.cs index 2d4bab0c1d..4d925ce262 100644 --- a/src/Microsoft.ML.EntryPoints/CVSplit.cs +++ b/src/Microsoft.ML.EntryPoints/CVSplit.cs @@ -54,7 +54,7 @@ public static Output Split(IHostEnvironment env, Input input) var data = input.Data; - var stratCol = DataOperationsCatalog.EnsureGroupPreservationColumn(env, ref data, input.StratificationColumn); + var stratCol = DataOperationsCatalog.CreateGroupPreservationColumn(env, ref data, input.StratificationColumn); int n = input.NumFolds; var output = new Output diff --git a/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs index dc8ba258c3..04f44e42a6 100644 --- a/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs +++ b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs @@ -50,7 +50,7 @@ public static Output Split(IHostEnvironment env, Input input) EntryPointUtils.CheckInputArgs(host, input); var data = input.Data; - var stratCol = DataOperationsCatalog.EnsureGroupPreservationColumn(env, ref data, input.StratificationColumn); + var stratCol = DataOperationsCatalog.CreateGroupPreservationColumn(env, ref data, input.StratificationColumn); IDataView trainData = new RangeFilter(host, new RangeFilter.Options { Column = stratCol, Min = 0, Max = input.Fraction, Complement = false }, data); From 177fa04afcd0804c9d4d493d620ea2feec5269bf Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 17 Jun 2020 23:12:48 -0700 Subject: [PATCH 10/13] Fallback on EnvSeed when calling through TrainCatalog (fix Tensorflow failing test) --- src/Microsoft.ML.Data/TrainCatalog.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index fe5b6ba7b7..fdd7833285 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -93,7 +93,7 @@ private protected CrossValidationResult[] 
CrossValidateTrain(IDataView data, IEs Environment.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); Environment.CheckValueOrNull(samplingKeyColumn); - var newSamplingKeyColumn = DataOperationsCatalog.CreateGroupPreservationColumn(Environment, ref data, samplingKeyColumn, seed); + var newSamplingKeyColumn = DataOperationsCatalog.CreateGroupPreservationColumn(Environment, ref data, samplingKeyColumn, seed, fallbackInEnvSeed: true); var result = new CrossValidationResult[numFolds]; int fold = 0; // Sequential per-fold training. From e2a76c70d23d7661bb20be6c23e3631eaa89c3c6 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Mon, 22 Jun 2020 20:09:42 -0700 Subject: [PATCH 11/13] Normalize even if the column was already normalized --- src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs index 8427502fba..b70b9289fe 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs @@ -560,7 +560,7 @@ internal static string CreateGroupPreservationColumn(IHostEnvironment env, ref I } else { - if (data.Schema[samplingKeyColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double)) + if (type != NumberDataViewType.Single && type != NumberDataViewType.Double) { data = new ColumnCopyingEstimator(env, (newSamplingKeyColumn, samplingKeyColumn)).Fit(data).Transform(data); } From 2cf9d6a7b885eca9834ba8210c22162cf6eb5bb3 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Fri, 26 Jun 2020 21:49:42 -0700 Subject: [PATCH 12/13] * Drop newSamplingKeyColumn after splitting dataset to: ** Fix issue with AutoML ** Enforce having the same schema before and after splitting, avoiding future issues * Added tests for these * Enforce that samplingKeyColumnName shouldn't be 
null by the time we split stuff, since anyway it will throw in the RangeFilter if its null. Also change its name to tempSamplingKeyColumnName since it's going to be dropped. --- .../DataLoadSave/DataOperationsCatalog.cs | 24 ++++++-- src/Microsoft.ML.Data/TrainCatalog.cs | 2 +- .../Microsoft.ML.AutoML.Tests/AutoFitTests.cs | 57 +++++++++++++++++++ .../Scenarios/Api/TestApi.cs | 49 ++++++++++++++++ 4 files changed, 125 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs index b70b9289fe..6c18f83077 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs @@ -430,7 +430,10 @@ public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, s Complement = false }, data); - return new TrainTestData(trainFilter, testFilter); + var trainDV = ColumnSelectingTransformer.CreateDrop(_env, trainFilter, newSamplingKeyColumn); + var testDV = ColumnSelectingTransformer.CreateDrop(_env, testFilter, newSamplingKeyColumn); + + return new TrainTestData(trainDV, testDV); } /// @@ -457,18 +460,24 @@ public IReadOnlyList CrossValidationSplit(IDataView data, int num _env.CheckValueOrNull(samplingKeyColumnName); var newSamplingKeyColumn = CreateGroupPreservationColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); var result = new List(); - foreach (var split in CrossValidationSplit(_env, data, numberOfFolds, newSamplingKeyColumn)) + foreach (var split in CrossValidationSplit(_env, data, newSamplingKeyColumn, numberOfFolds)) result.Add(split); return result; } - internal static IEnumerable CrossValidationSplit(IHostEnvironment env, IDataView data, int numberOfFolds = 5, string samplingKeyColumnName = null) + /// + /// Splits the data based on the tempSamplingKeyColumnName, and drops that column as it is only + /// intended to be used for splitting the data + 
/// + internal static IEnumerable CrossValidationSplit(IHostEnvironment env, IDataView data, string tempSamplingKeyColumnName, int numberOfFolds = 5) { + env.CheckValue(tempSamplingKeyColumnName, nameof(tempSamplingKeyColumnName)); + for (int fold = 0; fold < numberOfFolds; fold++) { var trainFilter = new RangeFilter(env, new RangeFilter.Options { - Column = samplingKeyColumnName, + Column = tempSamplingKeyColumnName, Min = (double)fold / numberOfFolds, Max = (double)(fold + 1) / numberOfFolds, Complement=true, @@ -478,7 +487,7 @@ internal static IEnumerable CrossValidationSplit(IHostEnvironment var testFilter = new RangeFilter(env, new RangeFilter.Options { - Column = samplingKeyColumnName, + Column = tempSamplingKeyColumnName, Min = (double)fold / numberOfFolds, Max = (double)(fold + 1) / numberOfFolds, Complement = false, @@ -486,7 +495,10 @@ internal static IEnumerable CrossValidationSplit(IHostEnvironment IncludeMax = true }, data); - yield return new TrainTestData(trainFilter, testFilter); + var trainDV = ColumnSelectingTransformer.CreateDrop(env, trainFilter, tempSamplingKeyColumnName); + var testDV = ColumnSelectingTransformer.CreateDrop(env, testFilter, tempSamplingKeyColumnName); + + yield return new TrainTestData(trainDV, testDV); } } diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index fdd7833285..22c8fe86a2 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -99,7 +99,7 @@ private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEs // Sequential per-fold training. // REVIEW: we could have a parallel implementation here. We would need to // spawn off a separate host per fold in that case. 
- foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, numFolds, newSamplingKeyColumn)) + foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, newSamplingKeyColumn, numFolds)) { var model = estimator.Fit(split.TrainSet); var scoredTest = model.Transform(split.TestSet); diff --git a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs index b43acf7af1..988149cf93 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs @@ -202,6 +202,63 @@ public void AutoFitRecommendationTest() Assert.NotEqual(0, metrices.MeanSquaredError); } + [Fact] + public void AutoFitWithPresplittedData() + { + // Models created in AutoML should work over the same data, + // no matter how that data is split before passing it to the experiment execution + // or to the model for prediction + + var context = new MLContext(1); + var dataPath = DatasetUtil.GetUciAdultDataset(); + var columnInference = context.Auto().InferColumns(dataPath, DatasetUtil.UciAdultLabel); + var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions); + var dataFull = textLoader.Load(dataPath); + var dataTrainTest = context.Data.TrainTestSplit(dataFull); + var dataCV = context.Data.CrossValidationSplit(dataFull, numberOfFolds: 2); + + var modelFull = context.Auto() + .CreateBinaryClassificationExperiment(0) + .Execute(dataFull, + new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel }) + .BestRun + .Model; + + var modelTrainTest = context.Auto() + .CreateBinaryClassificationExperiment(0) + .Execute(dataTrainTest.TrainSet, + new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel }) + .BestRun + .Model; + + var modelCV = context.Auto() + .CreateBinaryClassificationExperiment(0) + .Execute(dataCV.First().TrainSet, + new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel }) + .BestRun + .Model; + + 
var models = new[] { modelFull, modelTrainTest, modelCV }; + + foreach(var model in models) + { + var resFull = model.Transform(dataFull); + var resTrainTest = model.Transform(dataTrainTest.TrainSet); + var resCV = model.Transform(dataCV.First().TrainSet); + + Assert.Equal(30, resFull.Schema.Count); + Assert.Equal(30, resTrainTest.Schema.Count); + Assert.Equal(30, resCV.Schema.Count); + + foreach (var col in resFull.Schema) + { + Assert.Equal(col.Name, resTrainTest.Schema[col.Index].Name); + Assert.Equal(col.Name, resCV.Schema[col.Index].Name); + } + } + + } + private TextLoader.Options GetLoaderArgs(string labelColumnName, string userIdColumnName, string itemIdColumnName) { return new TextLoader.Options() diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index 1f80233a4e..5b036639e2 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -290,6 +290,55 @@ private List ReadBreastCancerExamples() return data; } + [Fact] + public void TestSplitsSchema() + { + + var mlContext = new MLContext(0); + var dataPath = GetDataPath("adult.tiny.with-schema.txt"); + + var fullInput = mlContext.Data.LoadFromTextFile(dataPath, new[] { + new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("Workclass", DataKind.String, 1), + new TextLoader.Column("Education", DataKind.String,2), + new TextLoader.Column("Age", DataKind.Single,9) + }, hasHeader: true); + + var ttSplit = mlContext.Data.TrainTestSplit(fullInput); + var ttSplitWithSeed = mlContext.Data.TrainTestSplit(fullInput, seed: 10); + var ttSplitWithSeedAndSamplingKey = mlContext.Data.TrainTestSplit(fullInput, seed: 10, samplingKeyColumnName: "Workclass"); + + var cvSplit = mlContext.Data.CrossValidationSplit(fullInput); + var cvSplitWithSeed = mlContext.Data.CrossValidationSplit(fullInput, seed: 10); + var cvSplitWithSeedAndSamplingKey = 
mlContext.Data.CrossValidationSplit(fullInput, seed: 10, samplingKeyColumnName: "Workclass"); + + var splits = new[] + { + ttSplit.TrainSet, + ttSplit.TestSet, + ttSplitWithSeed.TrainSet, + ttSplitWithSeed.TestSet, + ttSplitWithSeedAndSamplingKey.TrainSet, + ttSplitWithSeedAndSamplingKey.TestSet, + cvSplit.First().TrainSet, + cvSplit.First().TestSet, + cvSplitWithSeed.First().TrainSet, + cvSplitWithSeed.First().TestSet, + cvSplitWithSeedAndSamplingKey.First().TrainSet, + cvSplitWithSeedAndSamplingKey.First().TestSet + }; + + // Splitting a dataset shouldn't affect its schema + foreach(var split in splits) + { + Assert.Equal(fullInput.Schema.Count, split.Schema.Count); + foreach(var col in fullInput.Schema) + { + Assert.Equal(col.Name, split.Schema[col.Index].Name); + } + } + } + [Fact] public void TestTrainTestSplit() { From 3656236e5d11a338f7cc981d412de57bed30697f Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Fri, 26 Jun 2020 23:10:43 -0700 Subject: [PATCH 13/13] Renamed CreateGroupPreservationColumn to CreateSplitColumn and other similar renames for consistency. 
Internally we'll call the new column SplitColumn, regardless if it's based off a "samplingKeyColumn" or a "stratificationColumn" --- .../Commands/CrossValidationCommand.cs | 4 +- .../DataLoadSave/DataOperationsCatalog.cs | 69 ++++++++++--------- src/Microsoft.ML.Data/TrainCatalog.cs | 4 +- src/Microsoft.ML.EntryPoints/CVSplit.cs | 10 +-- .../TrainTestSplit.cs | 10 +-- 5 files changed, 49 insertions(+), 48 deletions(-) diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 3be53da3f3..9b40b3e150 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -305,8 +305,8 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output } } - var newStratificationColumn = DataOperationsCatalog.CreateGroupPreservationColumn(Host, ref output, stratificationColumn); - return newStratificationColumn; + var splitColumn = DataOperationsCatalog.CreateSplitColumn(Host, ref output, stratificationColumn); + return splitColumn; } private bool TryGetOverallMetrics(Dictionary[] metrics, out List overallList) diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs index 6c18f83077..1cac67e05a 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs @@ -413,25 +413,25 @@ public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, s _env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); _env.CheckValueOrNull(samplingKeyColumnName); - var newSamplingKeyColumn = CreateGroupPreservationColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); + var splitColumn = CreateSplitColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); var 
trainFilter = new RangeFilter(_env, new RangeFilter.Options() { - Column = newSamplingKeyColumn, + Column = splitColumn, Min = 0, Max = testFraction, Complement = true }, data); var testFilter = new RangeFilter(_env, new RangeFilter.Options() { - Column = newSamplingKeyColumn, + Column = splitColumn, Min = 0, Max = testFraction, Complement = false }, data); - var trainDV = ColumnSelectingTransformer.CreateDrop(_env, trainFilter, newSamplingKeyColumn); - var testDV = ColumnSelectingTransformer.CreateDrop(_env, testFilter, newSamplingKeyColumn); + var trainDV = ColumnSelectingTransformer.CreateDrop(_env, trainFilter, splitColumn); + var testDV = ColumnSelectingTransformer.CreateDrop(_env, testFilter, splitColumn); return new TrainTestData(trainDV, testDV); } @@ -458,26 +458,26 @@ public IReadOnlyList CrossValidationSplit(IDataView data, int num _env.CheckValue(data, nameof(data)); _env.CheckParam(numberOfFolds > 1, nameof(numberOfFolds), "Must be more than 1"); _env.CheckValueOrNull(samplingKeyColumnName); - var newSamplingKeyColumn = CreateGroupPreservationColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); + var splitColumn = CreateSplitColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); var result = new List(); - foreach (var split in CrossValidationSplit(_env, data, newSamplingKeyColumn, numberOfFolds)) + foreach (var split in CrossValidationSplit(_env, data, splitColumn, numberOfFolds)) result.Add(split); return result; } /// - /// Splits the data based on the tempSamplingKeyColumnName, and drops that column as it is only - /// intended to be used for splitting the data + /// Splits the data based on the splitColumn, and drops that column as it is only + /// intended to be used for splitting the data, and shouldn't be part of the output schema. 
/// - internal static IEnumerable CrossValidationSplit(IHostEnvironment env, IDataView data, string tempSamplingKeyColumnName, int numberOfFolds = 5) + internal static IEnumerable CrossValidationSplit(IHostEnvironment env, IDataView data, string splitColumn, int numberOfFolds = 5) { - env.CheckValue(tempSamplingKeyColumnName, nameof(tempSamplingKeyColumnName)); + env.CheckValue(splitColumn, nameof(splitColumn)); for (int fold = 0; fold < numberOfFolds; fold++) { var trainFilter = new RangeFilter(env, new RangeFilter.Options { - Column = tempSamplingKeyColumnName, + Column = splitColumn, Min = (double)fold / numberOfFolds, Max = (double)(fold + 1) / numberOfFolds, Complement=true, @@ -487,7 +487,7 @@ internal static IEnumerable CrossValidationSplit(IHostEnvironment var testFilter = new RangeFilter(env, new RangeFilter.Options { - Column = tempSamplingKeyColumnName, + Column = splitColumn, Min = (double)fold / numberOfFolds, Max = (double)(fold + 1) / numberOfFolds, Complement = false, @@ -495,34 +495,35 @@ internal static IEnumerable CrossValidationSplit(IHostEnvironment IncludeMax = true }, data); - var trainDV = ColumnSelectingTransformer.CreateDrop(env, trainFilter, tempSamplingKeyColumnName); - var testDV = ColumnSelectingTransformer.CreateDrop(env, testFilter, tempSamplingKeyColumnName); + var trainDV = ColumnSelectingTransformer.CreateDrop(env, trainFilter, splitColumn); + var testDV = ColumnSelectingTransformer.CreateDrop(env, testFilter, splitColumn); yield return new TrainTestData(trainDV, testDV); } } /// - /// Based on the input samplingKeyColumn creates a new column that will be used by the callers to apply a RangeFilter that will produce train-test split + /// Based on the input samplingKeyColumn creates a new splitColumn that will be used by the callers to apply a RangeFilter that will produce train-test splits /// or cross-validation splits. /// - /// Notice that the new column might get dropped by the callers of this method after using it. 
+ /// Notice that the new splitColumn might get dropped by the callers of this method after using it, as it wasn't part of + /// the input DataView schema. /// /// IHostEnvironment of the caller - /// DataView that should contain the "samplingKeyColumn". The new column will be added to this DataView. - /// Name of the column that will be used as base of the new GroupPreservationColumn. - /// Notice that in other places in the code this column, and/or the column that this method creates, - /// are refered to as "SamplingKeyColumn", "GroupPreservationColumn" or "StratificationColumn". - /// The seed used by the transformers that will create the new column + /// DataView that should contain the "samplingKeyColumn". The new splitColumn will be added to this DataView. + /// Name of the column that will be used as base of the new splitColumn. + /// Notice that in other places in the code the samplingKeyColumn, and/or the splitColumn this method creates, + /// are referred to as "SamplingKeyColumn", "StratificationColumn", "SplitColumn", "GroupPreservationColumn" or similar names. + /// The seed that might be used by the transformers that will create the new splitColumn /// If seed = null, then should we use the env seed? If seed = null, and this parameter is false, then we won't use a seed. /// The name of the new column [BestFriend] - internal static string CreateGroupPreservationColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false) + internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false) { Contracts.CheckValue(env, nameof(env)); Contracts.CheckValueOrNull(samplingKeyColumn); - var newSamplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn"); + var splitColumnName = data.Schema.GetTempColumnName("SplitColumn"); int? 
seedToUse; if(seed.HasValue) @@ -542,17 +543,17 @@ internal static string CreateGroupPreservationColumn(IHostEnvironment env, ref I // We need to handle two cases: if samplingKeyColumn is not provided, we generate a random number. if (samplingKeyColumn == null) { - data = new GenerateNumberTransform(env, data, newSamplingKeyColumn, (uint?)seedToUse); + data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse); } else { // If samplingKeyColumn was provided we will make a new column based on it, but using a temporary // name, as it might be dropped elsewhere in the code - if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol)) + if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int samplingColIndex)) throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn); - var type = data.Schema[stratCol].Type; + var type = data.Schema[samplingColIndex].Type; if (!RangeFilter.IsValidRangeFilterColumnType(env, type)) { var hashInputColumnName = samplingKeyColumn; @@ -560,30 +561,30 @@ internal static string CreateGroupPreservationColumn(IHostEnvironment env, ref I var itemType = type.GetItemType(); if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType) { - data = new TypeConvertingTransformer(env, newSamplingKeyColumn, DataKind.Int64, samplingKeyColumn).Transform(data); - hashInputColumnName = newSamplingKeyColumn; + data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data); + hashInputColumnName = splitColumnName; } var columnOptions = seedToUse.HasValue ? 
- new HashingEstimator.ColumnOptions(newSamplingKeyColumn, hashInputColumnName, 30, (uint)seedToUse.Value, combine: true) : - new HashingEstimator.ColumnOptions(newSamplingKeyColumn, hashInputColumnName, 30, combine: true); + new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, (uint)seedToUse.Value, combine: true) : + new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, combine: true); data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data); } else { if (type != NumberDataViewType.Single && type != NumberDataViewType.Double) { - data = new ColumnCopyingEstimator(env, (newSamplingKeyColumn, samplingKeyColumn)).Fit(data).Transform(data); + data = new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data); } else { - data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(newSamplingKeyColumn, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data); + data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data); } } } - return newSamplingKeyColumn; + return splitColumnName; } } } diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index 22c8fe86a2..f063c342cc 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -93,13 +93,13 @@ private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEs Environment.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); Environment.CheckValueOrNull(samplingKeyColumn); - var newSamplingKeyColumn = DataOperationsCatalog.CreateGroupPreservationColumn(Environment, ref data, samplingKeyColumn, seed, fallbackInEnvSeed: true); + var splitColumn = DataOperationsCatalog.CreateSplitColumn(Environment, ref data, samplingKeyColumn, seed, fallbackInEnvSeed: true); var result = 
new CrossValidationResult[numFolds]; int fold = 0; // Sequential per-fold training. // REVIEW: we could have a parallel implementation here. We would need to // spawn off a separate host per fold in that case. - foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, newSamplingKeyColumn, numFolds)) + foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, splitColumn, numFolds)) { var model = estimator.Fit(split.TrainSet); var scoredTest = model.Transform(split.TestSet); diff --git a/src/Microsoft.ML.EntryPoints/CVSplit.cs b/src/Microsoft.ML.EntryPoints/CVSplit.cs index 4d925ce262..87d3de3be2 100644 --- a/src/Microsoft.ML.EntryPoints/CVSplit.cs +++ b/src/Microsoft.ML.EntryPoints/CVSplit.cs @@ -54,7 +54,7 @@ public static Output Split(IHostEnvironment env, Input input) var data = input.Data; - var stratCol = DataOperationsCatalog.CreateGroupPreservationColumn(env, ref data, input.StratificationColumn); + var splitCol = DataOperationsCatalog.CreateSplitColumn(env, ref data, input.StratificationColumn); int n = input.NumFolds; var output = new Output @@ -68,12 +68,12 @@ public static Output Split(IHostEnvironment env, Input input) for (int i = 0; i < n; i++) { var trainData = new RangeFilter(host, - new RangeFilter.Options { Column = stratCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = true }, data); - output.TrainData[i] = ColumnSelectingTransformer.CreateDrop(host, trainData, stratCol); + new RangeFilter.Options { Column = splitCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = true }, data); + output.TrainData[i] = ColumnSelectingTransformer.CreateDrop(host, trainData, splitCol); var testData = new RangeFilter(host, - new RangeFilter.Options { Column = stratCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = false }, data); - output.TestData[i] = ColumnSelectingTransformer.CreateDrop(host, testData, stratCol); + new RangeFilter.Options { Column = splitCol, Min = i * 
fraction, Max = (i + 1) * fraction, Complement = false }, data); + output.TestData[i] = ColumnSelectingTransformer.CreateDrop(host, testData, splitCol); } return output; diff --git a/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs index 04f44e42a6..42016fd247 100644 --- a/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs +++ b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs @@ -50,15 +50,15 @@ public static Output Split(IHostEnvironment env, Input input) EntryPointUtils.CheckInputArgs(host, input); var data = input.Data; - var stratCol = DataOperationsCatalog.CreateGroupPreservationColumn(env, ref data, input.StratificationColumn); + var splitCol = DataOperationsCatalog.CreateSplitColumn(env, ref data, input.StratificationColumn); IDataView trainData = new RangeFilter(host, - new RangeFilter.Options { Column = stratCol, Min = 0, Max = input.Fraction, Complement = false }, data); - trainData = ColumnSelectingTransformer.CreateDrop(host, trainData, stratCol); + new RangeFilter.Options { Column = splitCol, Min = 0, Max = input.Fraction, Complement = false }, data); + trainData = ColumnSelectingTransformer.CreateDrop(host, trainData, splitCol); IDataView testData = new RangeFilter(host, - new RangeFilter.Options { Column = stratCol, Min = 0, Max = input.Fraction, Complement = true }, data); - testData = ColumnSelectingTransformer.CreateDrop(host, testData, stratCol); + new RangeFilter.Options { Column = splitCol, Min = 0, Max = input.Fraction, Complement = true }, data); + testData = ColumnSelectingTransformer.CreateDrop(host, testData, splitCol); return new Output() { TrainData = trainData, TestData = testData }; }