diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
index ceca5fea04..bd1e7c94dd 100644
--- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
@@ -362,6 +362,31 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
             => new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
                 outputColumnName, inputColumnName, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting);
 
+        /// <summary>
+        /// Create a <see cref="WordBagEstimator"/>, which maps the column specified in <paramref name="inputColumnName"/>
+        /// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
+        /// </summary>
+        /// <remarks>
+        /// <see cref="WordBagEstimator"/> is different from <see cref="NgramExtractingEstimator"/> in that the former
+        /// tokenizes text internally and the latter takes tokenized text as input.
+        /// </remarks>
+        /// <param name="catalog">The transform's catalog.</param>
+        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.
+        /// This column's data type will be known-size vector of <see cref="System.Single"/>.</param>
+        /// <param name="inputColumnName">Name of the column to take the data from.
+        /// This estimator operates over vector of text.</param>
+        /// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
+        /// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
+        /// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
+        public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransforms catalog,
+            string outputColumnName,
+            char termSeparator,
+            char freqSeparator,
+            string inputColumnName = null,
+            int maximumNgramsCount = NgramExtractingEstimator.Defaults.MaximumNgramsCount)
+            => new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
+                outputColumnName, inputColumnName, 1, 0, true, maximumNgramsCount, NgramExtractingEstimator.WeightingCriteria.Tf, termSeparator: termSeparator, freqSeparator: freqSeparator);
+
         /// <summary>
         /// Create a <see cref="WordBagEstimator"/>, which maps the multiple columns specified in <paramref name="inputColumnNames"/>
         /// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
index d45c5edc86..658445874d 100644
--- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
@@ -12,6 +13,7 @@
 using Microsoft.ML.Internal.Utilities;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Transforms.Text;
+using static Microsoft.ML.Transforms.Text.WordBagBuildingTransformer;
 
 [assembly: LoadableClass(WordBagBuildingTransformer.Summary, typeof(IDataTransform), typeof(WordBagBuildingTransformer), typeof(WordBagBuildingTransformer.Options), typeof(SignatureDataTransform),
     "Word Bag Transform", "WordBagTransform", "WordBag")]
@@ -21,6 +23,16 @@
 
 [assembly: EntryPointModule(typeof(NgramExtractorTransform.NgramExtractorArguments))]
 
+// These are for the internal only TextExpandingTransformer. Not exposed publicly.
+[assembly: LoadableClass(TextExpandingTransformer.Summary, typeof(IDataTransform), typeof(TextExpandingTransformer), null, typeof(SignatureLoadDataTransform),
+    TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(TextExpandingTransformer), null, typeof(SignatureLoadModel),
+    TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(TextExpandingTransformer), null, typeof(SignatureLoadRowMapper),
+    TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]
+
 namespace Microsoft.ML.Transforms.Text
 {
     /// <summary>
@@ -144,18 +156,195 @@ internal static IEstimator<ITransformer> CreateEstimator(IHostEnvironment env, O
                     NgramLength = column.NgramLength,
                     SkipLength = column.SkipLength,
                     Weighting = column.Weighting,
-                    UseAllLengths = column.UseAllLengths
+                    UseAllLengths = column.UseAllLengths,
                 };
             }
 
             IEstimator<ITransformer> estimator = NgramExtractionUtils.GetConcatEstimator(h, options.Columns);
-            estimator = estimator.Append(new WordTokenizingEstimator(env, tokenizeColumns));
+            if (options.FreqSeparator != default)
+            {
+                estimator = estimator.Append(new TextExpandingEstimator(h, tokenizeColumns[0].InputColumnName, options.FreqSeparator, options.TermSeparator));
+            }
+            estimator = estimator.Append(new WordTokenizingEstimator(h, tokenizeColumns));
             estimator = estimator.Append(NgramExtractorTransform.CreateEstimator(h, extractorArgs, estimator.GetOutputSchema(inputSchema)));
             return estimator;
         }
 
         internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
             (IDataTransform)CreateEstimator(env, options, SchemaShape.Create(input.Schema)).Fit(input).Transform(input);
+
+        #region TextExpander
+
+        // Internal only estimator used to facilitate the expansion of ngrams with pre-defined weights
+        internal sealed class TextExpandingEstimator : TrivialEstimator<TextExpandingTransformer>
+        {
+            private readonly string _columnName;
+
+            public TextExpandingEstimator(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
+                : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingEstimator)), new TextExpandingTransformer(env, columnName, freqSeparator, termSeparator))
+            {
+                _columnName = columnName;
+            }
+
+            public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+            {
+                Host.CheckValue(inputSchema, nameof(inputSchema));
+                if (!inputSchema.TryFindColumn(_columnName, out SchemaShape.Column outCol) || outCol.ItemType != TextDataViewType.Instance)
+                {
+                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columnName);
+                }
+
+                return inputSchema;
+            }
+        }
+
+        // Internal only transformer used to facilitate the expansion of ngrams with pre-defined weights
+        internal sealed class TextExpandingTransformer : RowToRowTransformerBase
+        {
+            internal const string Summary = "Expands text in the format of term:freq; to have the correct number of terms";
+            internal const string UserName = "Text Expanding Transform";
+            internal const string LoadName = "TextExpand";
+
+            internal const string LoaderSignature = "TextExpandTransform";
+
+            private readonly string _columnName;
+            private readonly char _freqSeparator;
+            private readonly char _termSeparator;
+
+            public TextExpandingTransformer(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
+                : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
+            {
+                _columnName = columnName;
+                _freqSeparator = freqSeparator;
+                _termSeparator = termSeparator;
+            }
+
+            private static VersionInfo GetVersionInfo()
+            {
+                return new VersionInfo(
+                    modelSignature: "TEXT EXP",
+                    verWrittenCur: 0x00010001, // Initial
+                    verReadableCur: 0x00010001,
+                    verWeCanReadBack: 0x00010001,
+                    loaderSignature: LoaderSignature,
+                    loaderAssemblyName: typeof(TextExpandingTransformer).Assembly.FullName);
+            }
+
+            /// <summary>
+            /// Factory method for SignatureLoadModel.
+            /// </summary>
+            private TextExpandingTransformer(IHostEnvironment env, ModelLoadContext ctx) :
+                base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
+            {
+                Host.CheckValue(ctx, nameof(ctx));
+                ctx.CheckAtModel(GetVersionInfo());
+
+                // *** Binary format ***
+                // string: column name
+                // char: frequency separator
+                // char: term separator
+
+                _columnName = ctx.Reader.ReadString();
+                _freqSeparator = ctx.Reader.ReadChar();
+                _termSeparator = ctx.Reader.ReadChar();
+            }
+
+            /// <summary>
+            /// Factory method for SignatureLoadRowMapper.
+            /// </summary>
+            private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
+                => new TextExpandingTransformer(env, ctx).MakeRowMapper(inputSchema);
+
+            /// <summary>
+            /// Factory method for SignatureLoadDataTransform.
+            /// </summary>
+            private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+                => new TextExpandingTransformer(env, ctx).MakeDataTransform(input);
+
+            private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
+            {
+                return new Mapper(Host, schema, this);
+            }
+
+            private protected override void SaveModel(ModelSaveContext ctx)
+            {
+                Host.CheckValue(ctx, nameof(ctx));
+                ctx.CheckAtModel();
+                ctx.SetVersionInfo(GetVersionInfo());
+
+                // *** Binary format ***
+                // string: column name
+                // char: frequency separator
+                // char: term separator
+
+                ctx.Writer.Write(_columnName);
+                ctx.Writer.Write(_freqSeparator);
+                ctx.Writer.Write(_termSeparator);
+            }
+
+            private sealed class Mapper : MapperBase
+            {
+                private readonly TextExpandingTransformer _parent;
+
+                public Mapper(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent)
+                    : base(host, inputSchema, parent)
+                {
+                    _parent = (TextExpandingTransformer)parent;
+                }
+
+                protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
+                {
+                    return new DataViewSchema.DetachedColumn[]
+                    {
+                        new DataViewSchema.DetachedColumn(_parent._columnName, TextDataViewType.Instance)
+                    };
+                }
+
+                protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
+                {
+                    disposer = null;
+                    ValueGetter<ReadOnlyMemory<char>> srcGetter = input.GetGetter<ReadOnlyMemory<char>>(input.Schema.GetColumnOrNull(_parent._columnName).Value);
+                    ReadOnlyMemory<char> inputMem = default;
+                    var sb = new StringBuilder();
+
+                    ValueGetter<ReadOnlyMemory<char>> result = (ref ReadOnlyMemory<char> dst) =>
+                    {
+                        sb.Clear();
+                        srcGetter(ref inputMem);
+                        var inputText = inputMem.ToString();
+                        foreach (var termFreq in inputText.Split(_parent._termSeparator))
+                        {
+                            var tf = termFreq.Split(_parent._freqSeparator);
+                            if (tf.Length != 2)
+                                sb.Append(tf[0] + " ");
+                            else
+                            {
+                                for (int i = 0; i < int.Parse(tf[1]); i++)
+                                    sb.Append(tf[0] + " ");
+                            }
+                        }
+
+                        dst = sb.ToString().AsMemory();
+                    };
+
+                    return result;
+                }
+
+                private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
+                {
+                    var active = new bool[InputSchema.Count];
+                    if (activeOutput(0))
+                    {
+                        active[InputSchema.GetColumnOrNull(_parent._columnName).Value.Index] = true;
+                    }
+                    return col => active[col];
+                }
+
+                private protected override void SaveModel(ModelSaveContext ctx)
+                {
+                    _parent.SaveModel(ctx);
+                }
+            }
+        }
+
+        #endregion TextExpander
     }
 
     /// <summary>
@@ -235,6 +424,13 @@ internal abstract class ArgumentsBase
             [Argument(ArgumentType.AtMostOnce, HelpText = "The weighting criteria")]
             public NgramExtractingEstimator.WeightingCriteria Weighting = NgramExtractingEstimator.Defaults.Weighting;
+
+            [Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms/frequency pairs.")]
+            public char TermSeparator = default;
+
+            [Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms from their frequency.")]
+            public char FreqSeparator = default;
+
         }
 
         [TlcModule.Component(Name = "NGram", FriendlyName = "NGram Extractor Transform", Alias = "NGramExtractorTransform,NGramExtractor",
diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
index b9aa7d9649..640a65fc05 100644
--- a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
+++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
@@ -45,6 +45,8 @@ public sealed class WordBagEstimator : IEstimator<ITransformer>
         private readonly bool _useAllLengths;
         private readonly int _maxNumTerms;
         private readonly NgramExtractingEstimator.WeightingCriteria _weighting;
+        private readonly char _termSeparator;
+        private readonly char _freqSeparator;
 
         /// <summary>
         /// Options for how the n-grams are extracted.
@@ -99,6 +101,8 @@ public Options()
         /// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
         /// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
         /// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
+        /// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
+        /// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
         internal WordBagEstimator(IHostEnvironment env,
             string outputColumnName,
             string inputColumnName = null,
@@ -106,8 +110,10 @@ internal WordBagEstimator(IHostEnvironment env,
             int skipLength = 0,
             bool useAllLengths = true,
             int maximumNgramsCount = 10000000,
-            NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
-            : this(env, outputColumnName, new[] { inputColumnName ?? outputColumnName }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting)
+            NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
+            char termSeparator = default,
+            char freqSeparator = default)
+            : this(env, outputColumnName, new[] { inputColumnName ?? outputColumnName }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting, termSeparator, freqSeparator)
         {
         }
 
@@ -123,6 +129,8 @@ internal WordBagEstimator(IHostEnvironment env,
         /// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
         /// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
         /// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
+        /// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
+        /// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
         internal WordBagEstimator(IHostEnvironment env,
             string outputColumnName,
             string[] inputColumnNames,
@@ -130,8 +138,10 @@ internal WordBagEstimator(IHostEnvironment env,
             int skipLength = 0,
             bool useAllLengths = true,
             int maximumNgramsCount = 10000000,
-            NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
-            : this(env, new[] { (outputColumnName, inputColumnNames) }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting)
+            NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
+            char termSeparator = default,
+            char freqSeparator = default)
+            : this(env, new[] { (outputColumnName, inputColumnNames) }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting, termSeparator, freqSeparator)
         {
         }
 
@@ -146,13 +156,17 @@ internal WordBagEstimator(IHostEnvironment env,
         /// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
         /// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
         /// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
+        /// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
+        /// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
         internal WordBagEstimator(IHostEnvironment env,
             (string outputColumnName, string[] inputColumnNames)[] columns,
             int ngramLength = 1,
             int skipLength = 0,
             bool useAllLengths = true,
             int maximumNgramsCount = 10000000,
-            NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
+            NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
+            char termSeparator = default,
+            char freqSeparator = default)
         {
             Contracts.CheckValue(env, nameof(env));
             _host = env.Register(nameof(WordBagEstimator));
@@ -169,6 +183,8 @@ internal WordBagEstimator(IHostEnvironment env,
             _useAllLengths = useAllLengths;
             _maxNumTerms = maximumNgramsCount;
             _weighting = weighting;
+            _termSeparator = termSeparator;
+            _freqSeparator = freqSeparator;
         }
 
         /// <summary>Trains and returns a <see cref="ITransformer"/>.</summary>
@@ -187,7 +203,9 @@ private WordBagBuildingTransformer.Options CreateOptions()
                 SkipLength = _skipLength,
                 UseAllLengths = _useAllLengths,
                 MaxNumTerms = new[] { _maxNumTerms },
-                Weighting = _weighting
+                Weighting = _weighting,
+                TermSeparator = _termSeparator,
+                FreqSeparator = _freqSeparator,
             };
         }
 
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index 5f30327262..ba235d72c3 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -30099,6 +30099,24 @@
           "SortOrder": 150.0,
           "IsNullable": false,
           "Default": "Tf"
+        },
+        {
+          "Name": "TermSeparator",
+          "Type": "Char",
+          "Desc": "Separator used to separate terms/frequency pairs.",
+          "Required": false,
+          "SortOrder": 150.0,
+          "IsNullable": false,
+          "Default": "\u0000"
+        },
+        {
+          "Name": "FreqSeparator",
+          "Type": "Char",
+          "Desc": "Separator used to separate terms from their frequency.",
+          "Required": false,
+          "SortOrder": 150.0,
+          "IsNullable": false,
+          "Default": "\u0000"
         }
       ]
     },
diff --git a/test/Microsoft.ML.Tests/Transformers/WordBagTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/WordBagTransformerTests.cs
new file mode 100644
index 0000000000..767a319faa
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Transformers/WordBagTransformerTests.cs
@@ -0,0 +1,132 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.IO;
+using System.Linq;
+using Microsoft.ML.Data;
+using Microsoft.ML.RunTests;
+using Microsoft.ML.TestFrameworkCommon;
+using Microsoft.ML.Transforms.Text;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests.Transformers
+{
+    public sealed class WordBagTransformerTests : TestDataPipeBase
+    {
+        public WordBagTransformerTests(ITestOutputHelper helper) : base(helper)
+        {
+        }
+
+        [Fact]
+        public void WordBagsPreDefined()
+        {
+            var mlContext = new MLContext(1);
+            var samples = new List<TextData>()
+            {
+                new TextData(){ Text = "div:12;strong:9;span:13;br:1;a:2" },
+                new TextData(){ Text = "p:5;strong:1;code:2;br:2;a:2;img:1;span:6;script:1" },
+                new TextData(){ Text = "p:5" },
+                new TextData(){ Text = "p" },
+                new TextData(){ Text = "-1" },
+            };
+
+            var dataview = mlContext.Data.LoadFromEnumerable(samples);
+            var textPipeline =
+                mlContext.Transforms.Text.ProduceWordBags("Text", termSeparator: ';', freqSeparator: ':');
+
+            var textTransformer = textPipeline.Fit(dataview);
+            var pred = textTransformer.Preview(dataview);
+
+            var expected = new float[] { 12, 9, 13, 1, 2, 0, 0, 0, 0, 0 };
+
+            Assert.Equal(expected, ((VBuffer<float>)pred.ColumnView[4].Values[0]).DenseValues().ToArray());
+
+            TestEstimatorCore(textPipeline, dataview);
+            Done();
+        }
+
+        [Fact]
+        public void WordBagsPreDefinedNonDefault()
+        {
+            var mlContext = new MLContext(1);
+            var samples = new List<TextData>()
+            {
+                new TextData(){ Text = "div;12:strong;9:span;13:br;1:a;2" },
+                new TextData(){ Text = "p;5:strong;1:code;2:br;2:a;2:img;1:span;6:script;1" },
+                new TextData(){ Text = "p;5" },
+                new TextData(){ Text = "p" },
+                new TextData(){ Text = "-1" },
+            };
+
+            var dataview = mlContext.Data.LoadFromEnumerable(samples);
+            var textPipeline =
+                mlContext.Transforms.Text.ProduceWordBags("Text", termSeparator: ':', freqSeparator: ';');
+
+            var textTransformer = textPipeline.Fit(dataview);
+            var pred = textTransformer.Preview(dataview);
+            var expected = new float[] { 12, 9, 13, 1, 2, 0, 0, 0, 0, 0 };
+
+            Assert.Equal(expected, ((VBuffer<float>)pred.ColumnView[4].Values[0]).DenseValues().ToArray());
+
+            TestEstimatorCore(textPipeline, dataview);
+            Done();
+        }
+
+        [Fact]
+        public void WordBagsPreDefinedMakeSureDefaultAndNonDefaultHaveSameOutput()
+        {
+            var mlContext = new MLContext(1);
+
+            var samplesDefault = new List<TextData>()
+            {
+                new TextData(){ Text = "div:12;strong:9;span:13;br:1;a:2" },
+                new TextData(){ Text = "p:5;strong:1;code:2;br:2;a:2;img:1;span:6;script:1" },
+                new TextData(){ Text = "p:5" },
+                new TextData(){ Text = "p" },
+                new TextData(){ Text = "-1" },
+            };
+
+            var samplesNonDefault = new List<TextData>()
+            {
+                new TextData(){ Text = "div;12:strong;9:span;13:br;1:a;2" },
+                new TextData(){ Text = "p;5:strong;1:code;2:br;2:a;2:img;1:span;6:script;1" },
+                new TextData(){ Text = "p;5" },
+                new TextData(){ Text = "p" },
+                new TextData(){ Text = "-1" },
+            };
+
+            var dataviewDefault = mlContext.Data.LoadFromEnumerable(samplesDefault);
+            var dataviewNonDefault = mlContext.Data.LoadFromEnumerable(samplesNonDefault);
+            var textPipelineDefault = mlContext.Transforms.Text.ProduceWordBags("Text", termSeparator: ';', freqSeparator: ':');
+            var textPipelineNonDefault = mlContext.Transforms.Text.ProduceWordBags("Text", termSeparator: ':', freqSeparator: ';');
+
+            var textTransformerDefault = textPipelineDefault.Fit(dataviewDefault);
+            var textTransformerNonDefault = textPipelineNonDefault.Fit(dataviewNonDefault);
+            var predDefault = textTransformerDefault.Preview(dataviewDefault);
+            var predNonDefault = textTransformerNonDefault.Preview(dataviewNonDefault);
+
+            Assert.Equal(((VBuffer<float>)predDefault.ColumnView[4].Values[0]).DenseValues().ToArray(), ((VBuffer<float>)predNonDefault.ColumnView[4].Values[0]).DenseValues().ToArray());
+
+            Done();
+        }
+
+        private class TextData
+        {
+            public string Text { get; set; }
+        }
+
+        private class TransformedTextData
+        {
+            public float[] Text { get; set; }
+        }
+    }
+
+}
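
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new ProduceWordBags(outputColumnName, termSeparator, freqSeparator, ...) overload added in TextCatalog.cs is meant to be called on "term:count" formatted text. The TextData class and the "Text" column name mirror the WordBagsPreDefined test above; the sample namespace, class, and method names are made up for illustration.

using System.Collections.Generic;
using Microsoft.ML;

namespace Samples
{
    // Row type mirroring the TextData class used in WordBagTransformerTests.
    public class TextData
    {
        public string Text { get; set; }
    }

    public static class ProduceWordBagsWithPredefinedCountsSample
    {
        public static void Run()
        {
            var mlContext = new MLContext(seed: 1);

            // Each term carries a pre-computed count: "term:count" pairs separated by ';'.
            var samples = new List<TextData>
            {
                new TextData { Text = "div:12;strong:9;span:13;br:1;a:2" },
                new TextData { Text = "p:5;strong:1;code:2;br:2;a:2;img:1;span:6;script:1" },
            };
            var data = mlContext.Data.LoadFromEnumerable(samples);

            // With termSeparator/freqSeparator set, each pair is expanded internally
            // (e.g. "div:12" contributes 12 to the "div" slot) before the bag of words is built.
            var pipeline = mlContext.Transforms.Text.ProduceWordBags(
                "Text", termSeparator: ';', freqSeparator: ':');

            // After fitting, "Text" becomes a known-size vector of Single holding per-term counts.
            var transformed = pipeline.Fit(data).Transform(data);
        }
    }
}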