diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 7dadaf2b36..9b40b3e150 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -305,41 +305,8 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output } } - if (string.IsNullOrEmpty(stratificationColumn)) - { - stratificationColumn = "StratificationColumn"; - int tmp; - int inc = 0; - while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp)) - stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc); - var keyGenArgs = new GenerateNumberTransform.Options(); - var col = new GenerateNumberTransform.Column(); - col.Name = stratificationColumn; - keyGenArgs.Columns = new[] { col }; - output = new GenerateNumberTransform(Host, keyGenArgs, input); - } - else - { - int col; - if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col)) - throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn); - var type = input.Schema[col].Type; - if (!RangeFilter.IsValidRangeFilterColumnType(ch, type)) - { - ch.Info("Hashing the stratification column"); - var origStratCol = stratificationColumn; - stratificationColumn = input.Schema.GetTempColumnName("strat"); - - // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan. - var itemType = type.GetItemType(); - if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType) - input = new TypeConvertingTransformer(Host, origStratCol, DataKind.Int64, origStratCol).Transform(input); - - output = new HashingEstimator(Host, stratificationColumn, origStratCol, 30).Fit(input).Transform(input); - } - } - - return stratificationColumn; + var splitColumn = DataOperationsCatalog.CreateSplitColumn(Host, ref output, stratificationColumn); + return splitColumn; } private bool TryGetOverallMetrics(Dictionary[] metrics, out List overallList) diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs index 0bd4075670..1cac67e05a 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs @@ -413,24 +413,27 @@ public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, s _env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); _env.CheckValueOrNull(samplingKeyColumnName); - EnsureGroupPreservationColumn(_env, ref data, ref samplingKeyColumnName, seed); + var splitColumn = CreateSplitColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); var trainFilter = new RangeFilter(_env, new RangeFilter.Options() { - Column = samplingKeyColumnName, + Column = splitColumn, Min = 0, Max = testFraction, Complement = true }, data); var testFilter = new RangeFilter(_env, new RangeFilter.Options() { - Column = samplingKeyColumnName, + Column = splitColumn, Min = 0, Max = testFraction, Complement = false }, data); - return new TrainTestData(trainFilter, testFilter); + var trainDV = ColumnSelectingTransformer.CreateDrop(_env, trainFilter, splitColumn); + var testDV = ColumnSelectingTransformer.CreateDrop(_env, testFilter, splitColumn); + + return new TrainTestData(trainDV, testDV); } /// @@ -455,20 +458,26 @@ public IReadOnlyList CrossValidationSplit(IDataView data, int num _env.CheckValue(data, nameof(data)); _env.CheckParam(numberOfFolds > 1, nameof(numberOfFolds), "Must be more than 1"); _env.CheckValueOrNull(samplingKeyColumnName); - EnsureGroupPreservationColumn(_env, ref data, ref samplingKeyColumnName, seed); + var splitColumn = CreateSplitColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true); var result = new List(); - foreach (var split in CrossValidationSplit(_env, data, numberOfFolds, samplingKeyColumnName)) + foreach (var split in CrossValidationSplit(_env, data, splitColumn, numberOfFolds)) result.Add(split); return result; } - internal static IEnumerable CrossValidationSplit(IHostEnvironment env, IDataView data, int numberOfFolds = 5, string samplingKeyColumnName = null) + /// + /// Splits the data based on the splitColumn, and drops that column as it is only + /// intended to be used for splitting the data, and shouldn't be part of the output schema. + /// + internal static IEnumerable CrossValidationSplit(IHostEnvironment env, IDataView data, string splitColumn, int numberOfFolds = 5) { + env.CheckValue(splitColumn, nameof(splitColumn)); + for (int fold = 0; fold < numberOfFolds; fold++) { var trainFilter = new RangeFilter(env, new RangeFilter.Options { - Column = samplingKeyColumnName, + Column = splitColumn, Min = (double)fold / numberOfFolds, Max = (double)(fold + 1) / numberOfFolds, Complement=true, @@ -478,7 +487,7 @@ internal static IEnumerable CrossValidationSplit(IHostEnvironment var testFilter = new RangeFilter(env, new RangeFilter.Options { - Column = samplingKeyColumnName, + Column = splitColumn, Min = (double)fold / numberOfFolds, Max = (double)(fold + 1) / numberOfFolds, Complement = false, @@ -486,55 +495,96 @@ internal static IEnumerable CrossValidationSplit(IHostEnvironment IncludeMax = true }, data); - yield return new TrainTestData(trainFilter, testFilter); + var trainDV = ColumnSelectingTransformer.CreateDrop(env, trainFilter, splitColumn); + var testDV = ColumnSelectingTransformer.CreateDrop(env, testFilter, splitColumn); + + yield return new TrainTestData(trainDV, testDV); } } /// - /// Ensures the provided is valid for , hashing it if necessary, or creates a new column is null. + /// Based on the input samplingKeyColumn creates a new splitColumn that will be used by the callers to apply a RangeFilter that will produce train-test splits + /// or cross-validation splits. + /// + /// Notice that the new splitColumn might get dropped by the callers of this method after using it, as it wasn't part of + /// the input DataView schema. /// - internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null) + /// IHostEnvironment of the caller + /// DataView that should contain the "samplingKeyColumn". The new splitColumn will be added to this DataView. + /// Name of the column that will be used as base of the new splitColumn. + /// Notice that in other places in the code the samplingKeyColumn, and/or the splitColumn this method creates, + /// are refered to as "SamplingKeyColumn", "StratificationColumn", "SplitColumn", "GroupPreservationColumn" or similar names. + /// The seed that might be used by the transformers that will create the new splitColumn + /// If seed = null, then should we use the env seed? If seed = null, and this parameter is false, then we won't use a seed. + /// The name of the new column + [BestFriend] + internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false) { Contracts.CheckValue(env, nameof(env)); - // We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to - // build a single hash of it. If it is not, we generate a random number. + Contracts.CheckValueOrNull(samplingKeyColumn); + + var splitColumnName = data.Schema.GetTempColumnName("SplitColumn"); + int? seedToUse; + + if(seed.HasValue) + { + seedToUse = seed.Value; + } + else if(fallbackInEnvSeed) + { + ISeededEnvironment seededEnv = (ISeededEnvironment)env; + seedToUse = seededEnv.Seed; + } + else + { + seedToUse = null; + } + + // We need to handle two cases: if samplingKeyColumn is not provided, we generate a random number. if (samplingKeyColumn == null) { - samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn"); - data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed)); + data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse); } else { - if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol)) + // If samplingKeyColumn was provided we will make a new column based on it, but using a temporary + // name, as it might be dropped elsewhere in the code + + if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int samplingColIndex)) throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn); - var type = data.Schema[stratCol].Type; + var type = data.Schema[samplingColIndex].Type; if (!RangeFilter.IsValidRangeFilterColumnType(env, type)) { - var origStratCol = samplingKeyColumn; - samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn); + var hashInputColumnName = samplingKeyColumn; // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan. var itemType = type.GetItemType(); if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType) - data = new TypeConvertingTransformer(env, origStratCol, DataKind.Int64, origStratCol).Transform(data); + { + data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data); + hashInputColumnName = splitColumnName; + } - var localSeed = seed.HasValue ? seed : ((ISeededEnvironment)env).Seed.HasValue ? ((ISeededEnvironment)env).Seed : null; var columnOptions = - localSeed.HasValue ? - new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, (uint)localSeed.Value, combine: true) : - new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, combine: true); + seedToUse.HasValue ? + new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, (uint)seedToUse.Value, combine: true) : + new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, combine: true); data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data); } else { - if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double)) + if (type != NumberDataViewType.Single && type != NumberDataViewType.Double) { - var origStratCol = samplingKeyColumn; - samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn); - data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data); + data = new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data); + } + else + { + data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data); } } } + + return splitColumnName; } } } diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index dd9ad3ac11..f063c342cc 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -93,13 +93,13 @@ private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEs Environment.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); Environment.CheckValueOrNull(samplingKeyColumn); - DataOperationsCatalog.EnsureGroupPreservationColumn(Environment, ref data, ref samplingKeyColumn, seed); + var splitColumn = DataOperationsCatalog.CreateSplitColumn(Environment, ref data, samplingKeyColumn, seed, fallbackInEnvSeed: true); var result = new CrossValidationResult[numFolds]; int fold = 0; // Sequential per-fold training. // REVIEW: we could have a parallel implementation here. We would need to // spawn off a separate host per fold in that case. - foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, numFolds, samplingKeyColumn)) + foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, splitColumn, numFolds)) { var model = estimator.Fit(split.TrainSet); var scoredTest = model.Transform(split.TestSet); diff --git a/src/Microsoft.ML.EntryPoints/CVSplit.cs b/src/Microsoft.ML.EntryPoints/CVSplit.cs index 05c259625b..87d3de3be2 100644 --- a/src/Microsoft.ML.EntryPoints/CVSplit.cs +++ b/src/Microsoft.ML.EntryPoints/CVSplit.cs @@ -4,6 +4,7 @@ using Microsoft.ML; using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -53,7 +54,7 @@ public static Output Split(IHostEnvironment env, Input input) var data = input.Data; - var stratCol = SplitUtils.CreateStratificationColumn(host, ref data, input.StratificationColumn); + var splitCol = DataOperationsCatalog.CreateSplitColumn(env, ref data, input.StratificationColumn); int n = input.NumFolds; var output = new Output @@ -67,12 +68,12 @@ public static Output Split(IHostEnvironment env, Input input) for (int i = 0; i < n; i++) { var trainData = new RangeFilter(host, - new RangeFilter.Options { Column = stratCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = true }, data); - output.TrainData[i] = ColumnSelectingTransformer.CreateDrop(host, trainData, stratCol); + new RangeFilter.Options { Column = splitCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = true }, data); + output.TrainData[i] = ColumnSelectingTransformer.CreateDrop(host, trainData, splitCol); var testData = new RangeFilter(host, - new RangeFilter.Options { Column = stratCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = false }, data); - output.TestData[i] = ColumnSelectingTransformer.CreateDrop(host, testData, stratCol); + new RangeFilter.Options { Column = splitCol, Min = i * fraction, Max = (i + 1) * fraction, Complement = false }, data); + output.TestData[i] = ColumnSelectingTransformer.CreateDrop(host, testData, splitCol); } return output; diff --git a/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs index 5d953668e9..42016fd247 100644 --- a/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs +++ b/src/Microsoft.ML.EntryPoints/TrainTestSplit.cs @@ -50,71 +50,18 @@ public static Output Split(IHostEnvironment env, Input input) EntryPointUtils.CheckInputArgs(host, input); var data = input.Data; - var stratCol = SplitUtils.CreateStratificationColumn(host, ref data, input.StratificationColumn); + var splitCol = DataOperationsCatalog.CreateSplitColumn(env, ref data, input.StratificationColumn); IDataView trainData = new RangeFilter(host, - new RangeFilter.Options { Column = stratCol, Min = 0, Max = input.Fraction, Complement = false }, data); - trainData = ColumnSelectingTransformer.CreateDrop(host, trainData, stratCol); + new RangeFilter.Options { Column = splitCol, Min = 0, Max = input.Fraction, Complement = false }, data); + trainData = ColumnSelectingTransformer.CreateDrop(host, trainData, splitCol); IDataView testData = new RangeFilter(host, - new RangeFilter.Options { Column = stratCol, Min = 0, Max = input.Fraction, Complement = true }, data); - testData = ColumnSelectingTransformer.CreateDrop(host, testData, stratCol); + new RangeFilter.Options { Column = splitCol, Min = 0, Max = input.Fraction, Complement = true }, data); + testData = ColumnSelectingTransformer.CreateDrop(host, testData, splitCol); return new Output() { TrainData = trainData, TestData = testData }; } } - - internal static class SplitUtils - { - public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null) - { - host.CheckValue(data, nameof(data)); - host.CheckValueOrNull(stratificationColumn); - - // Pick a unique name for the stratificationColumn. - const string stratColName = "StratificationKey"; - string stratCol = data.Schema.GetTempColumnName(stratColName); - - // Construct the stratification column. If user-provided stratification column exists, use HashJoin - // of it to construct the strat column, otherwise generate a random number and use it. - if (stratificationColumn == null) - { - data = new GenerateNumberTransform(host, - new GenerateNumberTransform.Options - { - Columns = new[] { new GenerateNumberTransform.Column { Name = stratCol } } - }, data); - } - else - { - var col = data.Schema.GetColumnOrNull(stratificationColumn); - if (col == null) - throw host.ExceptSchemaMismatch(nameof(stratificationColumn), "Stratification", stratificationColumn); - - var type = col.Value.Type; - if (!RangeFilter.IsValidRangeFilterColumnType(host, type)) - { - // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan. - var itemType = type.GetItemType(); - if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType) - data = new TypeConvertingTransformer(host, stratificationColumn, DataKind.Int64, stratificationColumn).Transform(data); - - var columnOptions = new HashingEstimator.ColumnOptions(stratCol, stratificationColumn, 30, combine: true); - data = new HashingEstimator(host, columnOptions).Fit(data).Transform(data); - } - else - { - if (data.Schema[stratificationColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double)) - return stratificationColumn; - - data = new NormalizingEstimator(host, - new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true)) - .Fit(data).Transform(data); - } - } - - return stratCol; - } - } } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Command/CommandCrossValidationWithTextStratificationColumn-out.txt b/test/BaselineOutput/Common/Command/CommandCrossValidationWithTextStratificationColumn-out.txt index 7412d1e4a2..453f27b324 100644 --- a/test/BaselineOutput/Common/Command/CommandCrossValidationWithTextStratificationColumn-out.txt +++ b/test/BaselineOutput/Common/Command/CommandCrossValidationWithTextStratificationColumn-out.txt @@ -1,5 +1,4 @@ maml.exe CV tr=lr{l1=1.0 l2=0.1 ot=1e-3 nt=1} strat=Strat threads=- norm=Warn loader=text{col=Features:R4:9-14 col=Label:R4:0 col=Strat:TX:1 header+} data=%Data% out=%Output% -Hashing the stratification column Warning: A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options. Beginning optimization num vars: 7 diff --git a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs index b43acf7af1..988149cf93 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs @@ -202,6 +202,63 @@ public void AutoFitRecommendationTest() Assert.NotEqual(0, metrices.MeanSquaredError); } + [Fact] + public void AutoFitWithPresplittedData() + { + // Models created in AutoML should work over the same data, + // no matter how that data is splitted before passing it to the experiment execution + // or to the model for prediction + + var context = new MLContext(1); + var dataPath = DatasetUtil.GetUciAdultDataset(); + var columnInference = context.Auto().InferColumns(dataPath, DatasetUtil.UciAdultLabel); + var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions); + var dataFull = textLoader.Load(dataPath); + var dataTrainTest = context.Data.TrainTestSplit(dataFull); + var dataCV = context.Data.CrossValidationSplit(dataFull, numberOfFolds: 2); + + var modelFull = context.Auto() + .CreateBinaryClassificationExperiment(0) + .Execute(dataFull, + new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel }) + .BestRun + .Model; + + var modelTrainTest = context.Auto() + .CreateBinaryClassificationExperiment(0) + .Execute(dataTrainTest.TrainSet, + new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel }) + .BestRun + .Model; + + var modelCV = context.Auto() + .CreateBinaryClassificationExperiment(0) + .Execute(dataCV.First().TrainSet, + new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel }) + .BestRun + .Model; + + var models = new[] { modelFull, modelTrainTest, modelCV }; + + foreach(var model in models) + { + var resFull = model.Transform(dataFull); + var resTrainTest = model.Transform(dataTrainTest.TrainSet); + var resCV = model.Transform(dataCV.First().TrainSet); + + Assert.Equal(30, resFull.Schema.Count); + Assert.Equal(30, resTrainTest.Schema.Count); + Assert.Equal(30, resCV.Schema.Count); + + foreach (var col in resFull.Schema) + { + Assert.Equal(col.Name, resTrainTest.Schema[col.Index].Name); + Assert.Equal(col.Name, resCV.Schema[col.Index].Name); + } + } + + } + private TextLoader.Options GetLoaderArgs(string labelColumnName, string userIdColumnName, string itemIdColumnName) { return new TextLoader.Options() diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 22da23dd4e..b0307a9df3 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -5734,7 +5734,7 @@ public void TestCrossValidationMacroWithStratification() getter(ref stdev); foldGetter(ref fold); Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold)); - Assert.Equal(0.0087, stdev, 5); + Assert.Equal(0.02582, stdev, 5); double sum = 0; double val = 0; @@ -6613,5 +6613,231 @@ public void SummarizeEntryPointTest() Done(); } + + [LightGBMFact] + void RankingWithColumnIdEntryPoint() + { + Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryModelParameters).Assembly); + var dataPath = GetDataPath(@"adult.tiny.with-schema.txt"); + + string inputGraph = $@" + {{ + 'Nodes': [ + {{ + 'Inputs': {{ + 'CustomSchema': 'col=Label:R4:0 col=GroupId:TX:1 col=Features_1:R4:9-14 header+ sep=tab', + 'InputFile': '$file' + }}, + 'Name': 'Data.CustomTextLoader', + 'Outputs': {{ + 'Data': '$input_data' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'GroupId', + 'Source': 'GroupId' + }} + ], + 'Data': '$input_data', + 'MaxNumTerms': 1000000, + 'Sort': 'ByOccurrence', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.TextToKeyConverter', + 'Outputs': {{ + 'Model': '$output_model1', + 'OutputData': '$output_data1' + }} + }}, + {{ + 'Inputs': {{ + 'Column': [ + {{ + 'Name': 'Features', + 'Source': [ + 'Features_1' + ] + }} + ], + 'Data': '$output_data1' + }}, + 'Name': 'Transforms.ColumnConcatenator', + 'Outputs': {{ + 'Model': '$output_model2', + 'OutputData': '$output_data2' + }} + }}, + {{ + 'Inputs': {{ + 'Models': [ + '$output_model1', + '$output_model2' + ] + }}, + 'Name': 'Transforms.ModelCombiner', + 'Outputs': {{ + 'OutputModel': '$output_model_combined_pre_split' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$output_data2', + 'GroupColumn': 'GroupId', + 'Inputs': {{ + 'Data': '$cv_subgraph_input_data' + }}, + 'Kind': 'SignatureRankerTrainer', + 'LabelColumn': 'Label', + 'NameColumn': 'Name', + 'Nodes': [ + {{ + 'Inputs': {{ + 'Column': [ + 'Label' + ], + 'Data': '$cv_subgraph_input_data' + }}, + 'Name': 'Transforms.OptionalColumnCreator', + 'Outputs': {{ + 'Model': '$output_model3', + 'OutputData': '$optional_data' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$optional_data', + 'LabelColumn': 'Label', + 'TextKeyValues': false + }}, + 'Name': 'Transforms.LabelColumnKeyBooleanConverter', + 'Outputs': {{ + 'Model': '$output_model4', + 'OutputData': '$label_data' + }} + }}, + {{ + 'Inputs': {{ + 'Data': '$label_data', + 'Features': [ + 'Features' + ] + }}, + 'Name': 'Transforms.FeatureCombiner', + 'Outputs': {{ + 'Model': '$output_model5', + 'OutputData': '$output_data' + }} + }}, + {{ + 'Inputs': {{ + 'BatchSize': 1048576, + 'Caching': 'Auto', + 'CategoricalSmoothing': 10.0, + 'CustomGains': [ + 0, + 3, + 7, + 15, + 31, + 63, + 127, + 255, + 511, + 1023, + 2047, + 4095 + ], + 'EarlyStoppingRound': 0, + 'EvaluationMetric': 'NormalizedDiscountedCumulativeGain', + 'FeatureColumnName': 'Features', + 'HandleMissingValue': true, + 'L2CategoricalRegularization': 10.0, + 'LabelColumnName': 'Label', + 'MaximumBinCountPerFeature': 255, + 'MaximumCategoricalSplitPointCount': 32, + 'MinimumExampleCountPerGroup': 100, + 'MinimumExampleCountPerLeaf': 1, + 'NormalizeFeatures': 'Auto', + 'NumberOfIterations': 100, + 'RowGroupColumnName': 'GroupId', + 'Sigmoid': 0.5, + 'Silent': true, + 'TrainingData': '$output_data', + 'UseZeroAsMissingValue': false, + 'Verbose': false + }}, + 'Name': 'Trainers.LightGbmRanker', + 'Outputs': {{ + 'PredictorModel': '$output_model_learner' + }} + }}, + {{ + 'Inputs': {{ + 'PredictorModel': '$output_model_learner', + 'TransformModels': [ + '$output_model3', + '$output_model4', + '$output_model5' + ] + }}, + 'Name': 'Transforms.ManyHeterogeneousModelCombiner', + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }} + }} + ], + 'NumFolds': 2, + 'Outputs': {{ + 'PredictorModel': '$predictor_model' + }}, + 'StratificationColumn': 'GroupId', + 'TransformModel': '$output_model_combined_pre_split' + }}, + 'Name': 'Models.CrossValidator', + 'Outputs': {{ + 'OverallMetrics': '$overall_metrics', + 'PerInstanceMetrics': '$per_instance_metrics', + 'PredictorModel': '$predictor_model', + 'Warnings': '$warnings' + }} + }} + ], + 'Outputs': {{ + 'overall_metrics': '$outmetrics', + 'per_instance_metrics': '', + 'predictor_model': '$outModel', + 'warnings': '$outwarnings' + }} + }} + + "; + + JObject graph = JObject.Parse(inputGraph); + var runner = new GraphRunner(Env, graph[FieldNames.Nodes] as JArray); + var inputFile = new SimpleFileHandle(Env, dataPath, false, false); + runner.SetInput("file", inputFile); + runner.RunAll(); + + var data = runner.GetOutput("overall_metrics"); + using(var cursor = data.GetRowCursorForAllColumns()) + { + var ndcgGetter = cursor.GetGetter>(data.Schema["NDCG"]); + VBuffer ndcgBuffer = default; + + cursor.MoveNext(); + ndcgGetter(ref ndcgBuffer); + var ndcgArray = ndcgBuffer.DenseValues().ToArray(); + + // Since we used a toy dataset, we won't worry much about comparing actual + // Double values of the result. Simply check that we get results. + Assert.Equal(3, ndcgArray.Length); + Assert.True(ndcgArray[0] > 0); + Assert.True(ndcgArray[1] > 0); + Assert.True(ndcgArray[2] > 0); + } + } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index 0c8f8a7be0..5b036639e2 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -290,6 +290,55 @@ private List ReadBreastCancerExamples() return data; } + [Fact] + public void TestSplitsSchema() + { + + var mlContext = new MLContext(0); + var dataPath = GetDataPath("adult.tiny.with-schema.txt"); + + var fullInput = mlContext.Data.LoadFromTextFile(dataPath, new[] { + new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("Workclass", DataKind.String, 1), + new TextLoader.Column("Education", DataKind.String,2), + new TextLoader.Column("Age", DataKind.Single,9) + }, hasHeader: true); + + var ttSplit = mlContext.Data.TrainTestSplit(fullInput); + var ttSplitWithSeed = mlContext.Data.TrainTestSplit(fullInput, seed: 10); + var ttSplitWithSeedAndSamplingKey = mlContext.Data.TrainTestSplit(fullInput, seed: 10, samplingKeyColumnName: "Workclass"); + + var cvSplit = mlContext.Data.CrossValidationSplit(fullInput); + var cvSplitWithSeed = mlContext.Data.CrossValidationSplit(fullInput, seed: 10); + var cvSplitWithSeedAndSamplingKey = mlContext.Data.CrossValidationSplit(fullInput, seed: 10, samplingKeyColumnName: "Workclass"); + + var splits = new[] + { + ttSplit.TrainSet, + ttSplit.TestSet, + ttSplitWithSeed.TrainSet, + ttSplitWithSeed.TestSet, + ttSplitWithSeedAndSamplingKey.TrainSet, + ttSplitWithSeedAndSamplingKey.TestSet, + cvSplit.First().TrainSet, + cvSplit.First().TestSet, + cvSplitWithSeed.First().TrainSet, + cvSplitWithSeed.First().TestSet, + cvSplitWithSeedAndSamplingKey.First().TrainSet, + cvSplitWithSeedAndSamplingKey.First().TestSet + }; + + // Splitting a dataset shouldn't affect its schema + foreach(var split in splits) + { + Assert.Equal(fullInput.Schema.Count, split.Schema.Count); + foreach(var col in fullInput.Schema) + { + Assert.Equal(col.Name, split.Schema[col.Index].Name); + } + } + } + [Fact] public void TestTrainTestSplit() { @@ -363,7 +412,7 @@ private sealed class Input } [Fact] - public void TestTrainTestSplitWithStratification() + public void TestSplitsWithSamplingKeyColumn() { var mlContext = new MLContext(0); var input = mlContext.Data.LoadFromEnumerable(new[] @@ -402,12 +451,13 @@ public void TestTrainTestSplitWithStratification() }, }); + // TEST TRAINTESTSPLIT var split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.TextStrat)); var ids = split.TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); Assert.Contains(1, ids); Assert.Contains(5, ids); split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.FloatStrat)); - ids = split.TrainSet.GetColumn(split.TrainSet.Schema[nameof(Input.Id)]); + ids = split.TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); Assert.Contains(4, ids); Assert.Contains(5, ids); split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.VectorStrat)); @@ -426,6 +476,36 @@ public void TestTrainTestSplitWithStratification() ids = split.TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); Assert.Contains(1, ids); Assert.Contains(2, ids); + + var inputWithKey = mlContext.Transforms.Conversion.MapValueToKey("KeyStrat", "TextStrat").Fit(input).Transform(input); + split = mlContext.Data.TrainTestSplit(inputWithKey, 0.5, "KeyStrat"); + ids = split.TestSet.GetColumn(split.TestSet.Schema[nameof(Input.Id)]); + Assert.Contains(1, ids); + Assert.Contains(5, ids); + Assert.NotNull(split.TrainSet.Schema.GetColumnOrNull("KeyStrat")); // Check that the key column used as SamplingKeyColumn wasn't deleted by the split + + // TEST CROSSVALIDATIONSPLIT + var colnames = new[] { + nameof(Input.TextStrat), + nameof(Input.FloatStrat), + nameof(Input.VectorStrat), + nameof(Input.DateTimeStrat), + nameof(Input.DateTimeOffsetStrat), + nameof(Input.TimeSpanStrat), + "KeyStrat" }; + + foreach(var colname in colnames) + { + var cvSplits = mlContext.Data.CrossValidationSplit(inputWithKey, numberOfFolds: 2, samplingKeyColumnName: colname); + var idsTest1 = cvSplits[0].TestSet.GetColumn(cvSplits[0].TestSet.Schema[nameof(Input.Id)]); + var idsTest2 = cvSplits[1].TestSet.GetColumn(cvSplits[1].TestSet.Schema[nameof(Input.Id)]); + Assert.True(Enumerable.Intersect(idsTest1, idsTest2).Count() == 0); + Assert.True(idsTest1.Count() > 0, $"CV Split 0 for Column {colname} was empty"); + Assert.True(idsTest2.Count() > 0, $"CV Split 1 for Column {colname} was empty"); + + // Check that using CV didn't remove the SamplingKeyColumn + Assert.NotNull(split.TrainSet.Schema.GetColumnOrNull(colname)); + } } } }