diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index 37a6064469..4c7cbc324e 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -136,7 +136,7 @@ private protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol) Host.Assert(LabelColumn.IsValid); if (!LabelColumn.IsCompatibleWith(labelCol)) - throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", WeightColumn.Name, + throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", LabelColumn.Name, LabelColumn.GetTypeString(), labelCol.GetTypeString()); } diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index cf5c43ad9e..de808b4843 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -28,6 +28,33 @@ namespace Microsoft.ML.Transforms { + internal sealed class MissingValueDroppingEstimator : TrivialEstimator + { + public MissingValueDroppingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueDroppingEstimator)), new MissingValueDroppingTransformer(env, columns)) + { + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + var resultDic = inputSchema.ToDictionary(x => x.Name); + foreach (var (outputColumnName, inputColumnName) in Transformer.Columns) + { + if (!inputSchema.TryFindColumn(inputColumnName, out var originalColumn)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColumnName); + if (originalColumn.Kind == SchemaShape.Column.VectorKind.Scalar) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", originalColumn.Name, "Vector", 
"Scalar"); + if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(originalColumn.ItemType, out _)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", originalColumn.Name, "Single, Double or Key", originalColumn.ItemType.ToString()); + var col = new SchemaShape.Column(outputColumnName, SchemaShape.Column.VectorKind.VariableVector, originalColumn.ItemType, originalColumn.IsKey, originalColumn.Annotations); + resultDic[outputColumnName] = col; + } + return new SchemaShape(resultDic.Values); + } + } + /// internal sealed class MissingValueDroppingTransformer : OneToOneTransformerBase { @@ -163,23 +190,15 @@ public Mapper(MissingValueDroppingTransformer parent, DataViewSchema inputSchema { inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i]); var srcCol = inputSchema[_srcCols[i]]; + if (!(srcCol.Type is VectorDataViewType)) + throw _parent.Host.Except($"Column '{srcCol.Name}' is not a vector column"); + if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(srcCol.Type.GetItemType(), out _isNAs[i])) + throw _parent.Host.Except($"Column '{srcCol.Name}' is of type {srcCol.Type.GetItemType()}, which does not support missing values"); _srcTypes[i] = srcCol.Type; _types[i] = new VectorDataViewType((PrimitiveDataViewType)srcCol.Type.GetItemType()); - _isNAs[i] = GetIsNADelegate(srcCol.Type); } } - /// - /// Returns the isNA predicate for the respective type. 
- /// - private Delegate GetIsNADelegate(DataViewType type) - { - Func func = GetIsNADelegate; - return Utils.MarshalInvoke(func, type.GetItemType().RawType, type); - } - - private Delegate GetIsNADelegate(DataViewType type) => Data.Conversion.Conversions.Instance.GetIsNAPredicate(type.GetItemType()); - protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() { var result = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length]; diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index fedb164bc8..4ba655844d 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -95,12 +95,11 @@ internal sealed class Options : NgramExtractorTransform.ArgumentsBase internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. It does so by building " + "a dictionary of n-grams and using the id in the dictionary as the index in the bag."; - internal static ITransformer CreateTransfomer(IHostEnvironment env, Options options, IDataView input) + internal static IEstimator CreateEstimator(IHostEnvironment env, Options options, SchemaShape inputSchema) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); h.CheckValue(options, nameof(options)); - h.CheckValue(input, nameof(input)); h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified"); // Compose the WordBagTransform from a tokenize transform, @@ -149,17 +148,14 @@ internal static ITransformer CreateTransfomer(IHostEnvironment env, Options opti }; } - IDataView view = input; - ITransformer t0 = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns); - view = t0.Transform(view); - ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view); - view = t1.Transform(view); - ITransformer t2 = 
NgramExtractorTransform.Create(h, extractorArgs, view); - return new TransformerChain(new[] { t0, t1, t2 }); + IEstimator estimator = NgramExtractionUtils.GetConcatEstimator(h, options.Columns); + estimator = estimator.Append(new WordTokenizingEstimator(env, tokenizeColumns)); + estimator = estimator.Append(NgramExtractorTransform.CreateEstimator(h, extractorArgs, estimator.GetOutputSchema(inputSchema))); + return estimator; } internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) => - (IDataTransform)CreateTransfomer(env, options, input).Transform(input); + (IDataTransform)CreateEstimator(env, options, SchemaShape.Create(input.Schema)).Fit(input).Transform(input); } /// @@ -262,15 +258,14 @@ internal sealed class Options : ArgumentsBase internal const string LoaderSignature = "NgramExtractor"; - internal static ITransformer Create(IHostEnvironment env, Options options, IDataView input, TermLoaderArguments termLoaderArgs = null) + internal static IEstimator CreateEstimator(IHostEnvironment env, Options options, SchemaShape inputSchema, TermLoaderArguments termLoaderArgs = null) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(LoaderSignature); h.CheckValue(options, nameof(options)); - h.CheckValue(input, nameof(input)); h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified"); - var chain = new TransformerChain(); + var chain = new EstimatorChain(); var termCols = new List(); var isTermCol = new bool[options.Columns.Length]; @@ -281,9 +276,8 @@ internal static ITransformer Create(IHostEnvironment env, Options options, IData h.CheckNonWhiteSpace(col.Name, nameof(col.Name)); h.CheckNonWhiteSpace(col.Source, nameof(col.Source)); - int colId; - if (input.Schema.TryGetColumnIndex(col.Source, out colId) && - input.Schema[colId].Type.GetItemType() is TextDataViewType) + if (inputSchema.TryFindColumn(col.Source, out var colShape) && + colShape.ItemType is TextDataViewType) { 
termCols.Add(col); isTermCol[i] = true; @@ -327,9 +321,9 @@ internal static ITransformer Create(IHostEnvironment env, Options options, IData using (var ch = env.Start("Create key data view")) keyData = ValueToKeyMappingTransformer.GetKeyDataViewOrNull(env, ch, termLoaderArgs.DataFile, termLoaderArgs.TermsColumn, termLoaderArgs.Loader, out var autoConvert); } - chain = chain.Append(new ValueToKeyMappingEstimator(h, columnOptions.ToArray(), keyData).Fit(input)); + chain = chain.Append(new ValueToKeyMappingEstimator(h, columnOptions.ToArray(), keyData)); if (missingDropColumns != null) - chain = chain.Append(new MissingValueDroppingTransformer(h, missingDropColumns.Select(x => (x, x)).ToArray())); + chain = chain.Append(new MissingValueDroppingEstimator(h, missingDropColumns.Select(x => (x, x)).ToArray())); } var ngramColumns = new NgramExtractingEstimator.ColumnOptions[options.Columns.Length]; @@ -345,9 +339,7 @@ internal static ITransformer Create(IHostEnvironment env, Options options, IData isTermCol[iinfo] ? 
column.Name : column.Source ); } - - input = chain.Transform(input); - return chain.Append(new NgramExtractingEstimator(env, ngramColumns).Fit(input)); + return chain.Append(new NgramExtractingEstimator(env, ngramColumns)); } internal static IDataTransform CreateDataTransform(IHostEnvironment env, Options options, IDataView input, @@ -355,20 +347,11 @@ internal static IDataTransform CreateDataTransform(IHostEnvironment env, Options { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); - return Create(env, options, input, termLoaderArgs).Transform(input) as IDataTransform; + return CreateEstimator(env, options, SchemaShape.Create(input.Schema), termLoaderArgs).Fit(input).Transform(input)/* Create(env, options, input, termLoaderArgs).Transform(input) */as IDataTransform; } - internal static ITransformer Create(IHostEnvironment env, NgramExtractorArguments extractorArgs, IDataView input, - ExtractorColumn[] cols, TermLoaderArguments termLoaderArgs = null) + internal static Options CreateNgramExtractorOptions(NgramExtractorArguments extractorArgs, ExtractorColumn[] cols) { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(LoaderSignature); - h.CheckValue(extractorArgs, nameof(extractorArgs)); - h.CheckValue(input, nameof(input)); - h.CheckUserArg(extractorArgs.SkipLength < extractorArgs.NgramLength, nameof(extractorArgs.SkipLength), "Should be less than " + nameof(extractorArgs.NgramLength)); - h.CheckUserArg(Utils.Size(cols) > 0, nameof(Options.Columns), "Must be specified"); - h.CheckValueOrNull(termLoaderArgs); - var extractorCols = new Column[cols.Length]; for (int i = 0; i < cols.Length; i++) { @@ -385,8 +368,7 @@ internal static ITransformer Create(IHostEnvironment env, NgramExtractorArgument MaxNumTerms = extractorArgs.MaxNumTerms, Weighting = extractorArgs.Weighting }; - - return Create(h, options, input, termLoaderArgs); + return options; } internal static INgramExtractorFactory Create(IHostEnvironment env, 
NgramExtractorArguments extractorArgs, @@ -468,7 +450,8 @@ public NgramExtractorFactory(NgramExtractorTransform.NgramExtractorArguments ext public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColumn[] cols) { - return NgramExtractorTransform.Create(env, _extractorArgs, input, cols, _termLoaderArgs); + var options = NgramExtractorTransform.CreateNgramExtractorOptions(_extractorArgs, cols); + return NgramExtractorTransform.CreateEstimator(env, options, SchemaShape.Create(input.Schema), _termLoaderArgs).Fit(input); } } @@ -495,12 +478,12 @@ public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColum internal static class NgramExtractionUtils { - public static ITransformer ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns) + public static IEstimator GetConcatEstimator(IHostEnvironment env, ManyToOneColumn[] columns) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(columns, nameof(columns)); - var concatColumns = new List(); + var estimator = new EstimatorChain(); foreach (var col in columns) { env.CheckUserArg(col != null, nameof(WordBagBuildingTransformer.Options.Columns)); @@ -508,13 +491,9 @@ public static ITransformer ApplyConcatOnSources(IHostEnvironment env, ManyToOneC env.CheckUserArg(Utils.Size(col.Source) > 0, nameof(col.Source)); env.CheckUserArg(col.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(col.Source)); if (col.Source.Length > 1) - concatColumns.Add(new ColumnConcatenatingTransformer.ColumnOptions(col.Name, col.Source)); + estimator = estimator.Append(new ColumnConcatenatingEstimator(env, col.Name, col.Source)); } - - if (concatColumns.Count > 0) - return new ColumnConcatenatingTransformer(env, concatColumns.ToArray()); - - return new TransformerChain(); + return estimator; } /// diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs index 2d4a820ccc..5a05674358 100644 --- 
a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs +++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs @@ -174,8 +174,13 @@ internal WordBagEstimator(IHostEnvironment env, /// Trains and returns a . public ITransformer Fit(IDataView input) { - // Create arguments. - var options = new WordBagBuildingTransformer.Options + var estimator = WordBagBuildingTransformer.CreateEstimator(_host, CreateOptions(), SchemaShape.Create(input.Schema)); + return estimator.Fit(input); + } + + private WordBagBuildingTransformer.Options CreateOptions() + { + return new WordBagBuildingTransformer.Options { Columns = _columns.Select(x => new WordBagBuildingTransformer.Column { Name = x.outputColumnName, Source = x.sourceColumnsNames }).ToArray(), NgramLength = _ngramLength, @@ -184,8 +189,6 @@ public ITransformer Fit(IDataView input) MaxNumTerms = new[] { _maxNumTerms }, Weighting = _weighting }; - - return WordBagBuildingTransformer.CreateTransfomer(_host, options, input); } /// @@ -196,9 +199,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); - var fakeSchema = FakeSchemaFactory.Create(inputSchema); - var transformer = Fit(new EmptyDataView(_host, fakeSchema)); - return SchemaShape.Create(transformer.GetOutputSchema(fakeSchema)); + var estimator = WordBagBuildingTransformer.CreateEstimator(_host, CreateOptions(), inputSchema); + return estimator.GetOutputSchema(inputSchema); } } diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index 48306254c6..9a57ad96cd 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -599,9 +599,7 @@ public void WordBagWorkout() var est = new WordBagEstimator(ML, "bag_of_words", "text"). 
Append(new WordHashBagEstimator(ML, "bag_of_wordshash", "text", maximumNumberOfInverts: -1)); - // The following call fails because of the following issue - // https://github.com/dotnet/machinelearning/issues/969 - // TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); + TestEstimatorCore(est, data, invalidInput: invalidData); var outputPath = GetOutputPath("Text", "bag_of_words.tsv"); var savedData = ML.Data.TakeRows(est.Fit(data).Transform(data), 4); @@ -686,10 +684,11 @@ public void LdaWorkout() Append(new LatentDirichletAllocationEstimator(env, "topics", "bag_of_words", 10, maximumNumberOfIterations: 10, resetRandomGenerator: true)); - // The following call fails because of the following issue - // https://github.com/dotnet/machinelearning/issues/969 - // In this test it manifests because of the WordBagEstimator in the estimator chain - // TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); + // Disabling this check due to the following issue with consistency of output. + // `seed` specified in ConsoleEnvironment has no effect. 
+ // https://github.com/dotnet/machinelearning/issues/1004 + // On a single box, setting `s.ResetRandomGenerator = true` works, but it fails on the build server + // TestEstimatorCore(est, data, invalidInput: invalidData); var outputPath = GetOutputPath("Text", "ldatopics.tsv"); using (var ch = env.Start("save")) @@ -764,5 +763,21 @@ public void TestTextFeaturizerBackCompat() Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); } } + + [Fact] + public void TestWordBagInPipeline() + { + string dataPath = GetDataPath("breast-cancer.txt"); + var dataView = ML.Data.LoadFromTextFile(dataPath, new[] { + new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("Features", DataKind.String, 1, 9) + }); + + var pipeline = ML.Transforms.Text.ProduceWordBags("Features") + .Append(ML.BinaryClassification.Trainers.FastTree()); + + TestEstimatorCore(pipeline, dataView); + Done(); + } } }