diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
index ceca5fea04..bd1e7c94dd 100644
--- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs
@@ -362,6 +362,31 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
=> new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
outputColumnName, inputColumnName, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting);
+ /// <summary>
+ /// Create a <see cref="WordBagEstimator"/>, which maps the column specified in <paramref name="inputColumnName"/>
+ /// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
+ /// </summary>
+ /// <remarks>
+ /// <see cref="WordBagEstimator"/> is different from <see cref="NgramExtractingEstimator"/> in that the former
+ /// tokenizes text internally and the latter takes tokenized text as input.
+ /// </remarks>
+ /// <param name="catalog">The transform's catalog.</param>
+ /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.
+ /// This column's data type will be known-size vector of <see cref="System.Single"/>.</param>
+ /// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
+ /// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
+ /// <param name="inputColumnName">Name of the column to take the data from.
+ /// This estimator operates over vector of text.</param>
+ /// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
+ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransforms catalog,
+ string outputColumnName,
+ char termSeparator,
+ char freqSeparator,
+ string inputColumnName = null,
+ int maximumNgramsCount = NgramExtractingEstimator.Defaults.MaximumNgramsCount)
+ => new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
+ outputColumnName, inputColumnName, 1, 0, true, maximumNgramsCount, NgramExtractingEstimator.WeightingCriteria.Tf, termSeparator: termSeparator, freqSeparator: freqSeparator);
+
/// <summary>
/// Create a <see cref="WordBagEstimator"/>, which maps the multiple columns specified in <paramref name="inputColumnNames"/>
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
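Illustrative usage of the new overload (a sketch, not part of the diff; the sample rows and the TextData class are assumptions mirroring the tests added below):

    using Microsoft.ML;

    var mlContext = new MLContext(seed: 1);
    var samples = new[]
    {
        new TextData { Text = "div:12;strong:9;span:13;br:1;a:2" },
        new TextData { Text = "p:5;strong:1;code:2;br:2;a:2;img:1;span:6;script:1" },
    };
    var data = mlContext.Data.LoadFromEnumerable(samples);

    // ';' separates term:frequency pairs, ':' separates a term from its count.
    var pipeline = mlContext.Transforms.Text.ProduceWordBags(
        "Text", termSeparator: ';', freqSeparator: ':');

    // "Text" is replaced by a known-size vector of float term counts.
    var transformed = pipeline.Fit(data).Transform(data);

    public class TextData { public string Text { get; set; } }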
diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
index d45c5edc86..658445874d 100644
--- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -12,6 +13,7 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;
+using static Microsoft.ML.Transforms.Text.WordBagBuildingTransformer;
[assembly: LoadableClass(WordBagBuildingTransformer.Summary, typeof(IDataTransform), typeof(WordBagBuildingTransformer), typeof(WordBagBuildingTransformer.Options), typeof(SignatureDataTransform),
"Word Bag Transform", "WordBagTransform", "WordBag")]
@@ -21,6 +23,16 @@
[assembly: EntryPointModule(typeof(NgramExtractorTransform.NgramExtractorArguments))]
+// These are for the internal-only TextExpandingTransformer. Not exposed publicly.
+[assembly: LoadableClass(TextExpandingTransformer.Summary, typeof(IDataTransform), typeof(TextExpandingTransformer), null, typeof(SignatureLoadDataTransform),
+ TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(TextExpandingTransformer), null, typeof(SignatureLoadModel),
+ TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IRowMapper), typeof(TextExpandingTransformer), null, typeof(SignatureLoadRowMapper),
+ TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]
+
namespace Microsoft.ML.Transforms.Text
{
///
@@ -144,18 +156,195 @@ internal static IEstimator<ITransformer> CreateEstimator(IHostEnvironment env, O
NgramLength = column.NgramLength,
SkipLength = column.SkipLength,
Weighting = column.Weighting,
- UseAllLengths = column.UseAllLengths
+ UseAllLengths = column.UseAllLengths,
};
}
IEstimator<ITransformer> estimator = NgramExtractionUtils.GetConcatEstimator(h, options.Columns);
- estimator = estimator.Append(new WordTokenizingEstimator(env, tokenizeColumns));
+ if (options.FreqSeparator != default)
+ {
+ estimator = estimator.Append(new TextExpandingEstimator(h, tokenizeColumns[0].InputColumnName, options.FreqSeparator, options.TermSeparator));
+ }
+ estimator = estimator.Append(new WordTokenizingEstimator(h, tokenizeColumns));
estimator = estimator.Append(NgramExtractorTransform.CreateEstimator(h, extractorArgs, estimator.GetOutputSchema(inputSchema)));
return estimator;
}
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateEstimator(env, options, SchemaShape.Create(input.Schema)).Fit(input).Transform(input);
+
+ #region TextExpander
+
+ // Internal only estimator used to facilitate the expansion of ngrams with pre-defined weights
+ internal sealed class TextExpandingEstimator : TrivialEstimator<TextExpandingTransformer>
+ {
+ private readonly string _columnName;
+ public TextExpandingEstimator(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingEstimator)), new TextExpandingTransformer(env, columnName, freqSeparator, termSeparator))
+ {
+ _columnName = columnName;
+ }
+
+ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ if (!inputSchema.TryFindColumn(_columnName, out SchemaShape.Column outCol) || outCol.ItemType != TextDataViewType.Instance)
+ {
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columnName);
+ }
+
+ return inputSchema;
+ }
+ }
+
+ // Internal only transformer used to facilitate the expansion of ngrams with pre-defined weights
+ internal sealed class TextExpandingTransformer : RowToRowTransformerBase
+ {
+ internal const string Summary = "Expands text in the format of term:freq; to have the correct number of terms";
+ internal const string UserName = "Text Expanding Transform";
+ internal const string LoadName = "TextExpand";
+
+ internal const string LoaderSignature = "TextExpandTransform";
+
+ private readonly string _columnName;
+ private readonly char _freqSeparator;
+ private readonly char _termSeparator;
+
+ public TextExpandingTransformer(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
+ {
+ _columnName = columnName;
+ _freqSeparator = freqSeparator;
+ _termSeparator = termSeparator;
+ }
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "TEXT EXP",
+ verWrittenCur: 0x00010001, // Initial
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TextExpandingTransformer).Assembly.FullName);
+ }
+
+ /// <summary>
+ /// Factory method for SignatureLoadModel.
+ /// </summary>
+ private TextExpandingTransformer(IHostEnvironment env, ModelLoadContext ctx) :
+ base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+ // *** Binary format ***
+ // string: column name
+ // char: frequency separator
+ // char: term separator
+
+ _columnName = ctx.Reader.ReadString();
+ _freqSeparator = ctx.Reader.ReadChar();
+ _termSeparator = ctx.Reader.ReadChar();
+ }
+
+ /// <summary>
+ /// Factory method for SignatureLoadRowMapper.
+ /// </summary>
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
+ => new TextExpandingTransformer(env, ctx).MakeRowMapper(inputSchema);
+
+ /// <summary>
+ /// Factory method for SignatureLoadDataTransform.
+ /// </summary>
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => new TextExpandingTransformer(env, ctx).MakeDataTransform(input);
+
+ private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
+ {
+ return new Mapper(Host, schema, this);
+ }
+
+ private protected override void SaveModel(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ // string: column name
+ // char: frequency separator
+ // char: term separator
+
+ ctx.Writer.Write(_columnName);
+ ctx.Writer.Write(_freqSeparator);
+ ctx.Writer.Write(_termSeparator);
+ }
+
+ private sealed class Mapper : MapperBase
+ {
+ private readonly TextExpandingTransformer _parent;
+ public Mapper(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent)
+ : base(host, inputSchema, parent)
+ {
+ _parent = (TextExpandingTransformer)parent;
+ }
+
+ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
+ {
+ return new DataViewSchema.DetachedColumn[]
+ {
+ new DataViewSchema.DetachedColumn(_parent._columnName, TextDataViewType.Instance)
+ };
+ }
+
+ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
+ {
+ disposer = null;
+ ValueGetter<ReadOnlyMemory<char>> srcGetter = input.GetGetter<ReadOnlyMemory<char>>(input.Schema.GetColumnOrNull(_parent._columnName).Value);
+ ReadOnlyMemory<char> inputMem = default;
+ var sb = new StringBuilder();
+
+ ValueGetter<ReadOnlyMemory<char>> result = (ref ReadOnlyMemory<char> dst) =>
+ {
+ sb.Clear();
+ srcGetter(ref inputMem);
+ var inputText = inputMem.ToString();
+ foreach (var termFreq in inputText.Split(_parent._termSeparator))
+ {
+ var tf = termFreq.Split(_parent._freqSeparator);
+ if (tf.Length != 2)
+ sb.Append(tf[0] + " ");
+ else
+ {
+ for (int i = 0; i < int.Parse(tf[1]); i++)
+ sb.Append(tf[0] + " ");
+ }
+ }
+
+ dst = sb.ToString().AsMemory();
+ };
+
+ return result;
+ }
+
+ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
+ {
+ var active = new bool[InputSchema.Count];
+ if (activeOutput(0))
+ {
+ active[InputSchema.GetColumnOrNull(_parent._columnName).Value.Index] = true;
+ }
+ return col => active[col];
+ }
+
+ private protected override void SaveModel(ModelSaveContext ctx)
+ {
+ _parent.SaveModel(ctx);
+ }
+ }
+ }
+
+ #endregion TextExpander
}
///
@@ -235,6 +424,13 @@ internal abstract class ArgumentsBase
[Argument(ArgumentType.AtMostOnce, HelpText = "The weighting criteria")]
public NgramExtractingEstimator.WeightingCriteria Weighting = NgramExtractingEstimator.Defaults.Weighting;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms/frequency pairs.")]
+ public char TermSeparator = default;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms from their frequency.")]
+ public char FreqSeparator = default;
+
}
[TlcModule.Component(Name = "NGram", FriendlyName = "NGram Extractor Transform", Alias = "NGramExtractorTransform,NGramExtractor",
diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
index b9aa7d9649..640a65fc05 100644
--- a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
+++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
@@ -45,6 +45,8 @@ public sealed class WordBagEstimator : IEstimator<ITransformer>
private readonly bool _useAllLengths;
private readonly int _maxNumTerms;
private readonly NgramExtractingEstimator.WeightingCriteria _weighting;
+ private readonly char _termSeparator;
+ private readonly char _freqSeparator;
/// <summary>
/// Options for how the n-grams are extracted.
@@ -99,6 +101,8 @@ public Options()
/// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
+ /// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
+ /// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
internal WordBagEstimator(IHostEnvironment env,
string outputColumnName,
string inputColumnName = null,
@@ -106,8 +110,10 @@ internal WordBagEstimator(IHostEnvironment env,
int skipLength = 0,
bool useAllLengths = true,
int maximumNgramsCount = 10000000,
- NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
- : this(env, outputColumnName, new[] { inputColumnName ?? outputColumnName }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting)
+ NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
+ char termSeparator = default,
+ char freqSeparator = default)
+ : this(env, outputColumnName, new[] { inputColumnName ?? outputColumnName }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting, termSeparator, freqSeparator)
{
}
@@ -123,6 +129,8 @@ internal WordBagEstimator(IHostEnvironment env,
/// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
+ /// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
+ /// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
internal WordBagEstimator(IHostEnvironment env,
string outputColumnName,
string[] inputColumnNames,
@@ -130,8 +138,10 @@ internal WordBagEstimator(IHostEnvironment env,
int skipLength = 0,
bool useAllLengths = true,
int maximumNgramsCount = 10000000,
- NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
- : this(env, new[] { (outputColumnName, inputColumnNames) }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting)
+ NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
+ char termSeparator = default,
+ char freqSeparator = default)
+ : this(env, new[] { (outputColumnName, inputColumnNames) }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting, termSeparator, freqSeparator)
{
}
@@ -146,13 +156,17 @@ internal WordBagEstimator(IHostEnvironment env,
/// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
+ /// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
+ /// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
internal WordBagEstimator(IHostEnvironment env,
(string outputColumnName, string[] inputColumnNames)[] columns,
int ngramLength = 1,
int skipLength = 0,
bool useAllLengths = true,
int maximumNgramsCount = 10000000,
- NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
+ NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
+ char termSeparator = default,
+ char freqSeparator = default)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(WordBagEstimator));
@@ -169,6 +183,8 @@ internal WordBagEstimator(IHostEnvironment env,
_useAllLengths = useAllLengths;
_maxNumTerms = maximumNgramsCount;
_weighting = weighting;
+ _termSeparator = termSeparator;
+ _freqSeparator = freqSeparator;
}
/// Trains and returns a <see cref="ITransformer"/>.
@@ -187,7 +203,9 @@ private WordBagBuildingTransformer.Options CreateOptions()
SkipLength = _skipLength,
UseAllLengths = _useAllLengths,
MaxNumTerms = new[] { _maxNumTerms },
- Weighting = _weighting
+ Weighting = _weighting,
+ TermSeparator = _termSeparator,
+ FreqSeparator = _freqSeparator,
};
}
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index 5f30327262..ba235d72c3 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -30099,6 +30099,24 @@
"SortOrder": 150.0,
"IsNullable": false,
"Default": "Tf"
+ },
+ {
+ "Name": "TermSeparator",
+ "Type": "Char",
+ "Desc": "Separator used to separate terms/frequency pairs.",
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": "\u0000"
+ },
+ {
+ "Name": "FreqSeparator",
+ "Type": "Char",
+ "Desc": "Separator used to separate terms from their frequency.",
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": "\u0000"
}
]
},
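The "\u0000" defaults shown above are simply default(char); a quick sketch (illustrative only) of the sentinel check that decides whether the expansion stage is appended in CreateEstimator:

    using System;

    // default(char) is '\u0000', the manifest's Default value above. The TextExpandingEstimator
    // is only appended when a non-default frequency separator was supplied.
    char freqSeparator = default;
    Console.WriteLine(freqSeparator == '\u0000');       // True
    Console.WriteLine(freqSeparator != default(char));  // False -> plain word-bag path, no expansion stage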
diff --git a/test/Microsoft.ML.Tests/Transformers/WordBagTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/WordBagTransformerTests.cs
new file mode 100644
index 0000000000..767a319faa
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Transformers/WordBagTransformerTests.cs
@@ -0,0 +1,132 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.IO;
+using System.Linq;
+using Microsoft.ML.Data;
+using Microsoft.ML.RunTests;
+using Microsoft.ML.TestFrameworkCommon;
+using Microsoft.ML.Transforms.Text;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests.Transformers
+{
+ public sealed class WordBagTransformerTests : TestDataPipeBase
+ {
+ public WordBagTransformerTests(ITestOutputHelper helper) : base(helper)
+ {
+ }
+
+ [Fact]
+ public void WordBagsPreDefined()
+ {
+ var mlContext = new MLContext(1);
+ var samples = new List<TextData>()
+ {
+ new TextData(){ Text = "div:12;strong:9;span:13;br:1;a:2" },
+ new TextData(){ Text = "p:5;strong:1;code:2;br:2;a:2;img:1;span:6;script:1" },
+ new TextData(){ Text = "p:5" },
+ new TextData(){ Text = "p" },
+ new TextData(){ Text = "-1" },
+ };
+
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+ var textPipeline =
+ mlContext.Transforms.Text.ProduceWordBags("Text", termSeparator: ';', freqSeparator: ':');
+
+
+ var textTransformer = textPipeline.Fit(dataview);
+ var pred = textTransformer.Preview(dataview);
+
+ var expected = new float[] { 12, 9, 13, 1, 2, 0, 0, 0, 0, 0 };
+
+ Assert.Equal(expected, ((VBuffer<float>)pred.ColumnView[4].Values[0]).DenseValues().ToArray());
+
+ TestEstimatorCore(textPipeline, dataview);
+ Done();
+ }
+
+ [Fact]
+ public void WordBagsPreDefinedNonDefault()
+ {
+ var mlContext = new MLContext(1);
+ var samples = new List<TextData>()
+ {
+ new TextData(){ Text = "div;12:strong;9:span;13:br;1:a;2" },
+ new TextData(){ Text = "p;5:strong;1:code;2:br;2:a;2:img;1:span;6:script;1" },
+ new TextData(){ Text = "p;5" },
+ new TextData(){ Text = "p" },
+ new TextData(){ Text = "-1" },
+ };
+
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+ var textPipeline =
+ mlContext.Transforms.Text.ProduceWordBags("Text", termSeparator: ':', freqSeparator: ';');
+
+
+ var textTransformer = textPipeline.Fit(dataview);
+ var pred = textTransformer.Preview(dataview);
+ var expected = new float[] { 12, 9, 13, 1, 2, 0, 0, 0, 0, 0 };
+
+ Assert.Equal(expected, ((VBuffer<float>)pred.ColumnView[4].Values[0]).DenseValues().ToArray());
+
+ TestEstimatorCore(textPipeline, dataview);
+ Done();
+ }
+
+ [Fact]
+ public void WordBagsPreDefinedMakeSureDefaultAndNonDefaultHaveSameOutput()
+ {
+ var mlContext = new MLContext(1);
+
+ var samplesDefault = new List<TextData>()
+ {
+ new TextData(){ Text = "div:12;strong:9;span:13;br:1;a:2" },
+ new TextData(){ Text = "p:5;strong:1;code:2;br:2;a:2;img:1;span:6;script:1" },
+ new TextData(){ Text = "p:5" },
+ new TextData(){ Text = "p" },
+ new TextData(){ Text = "-1" },
+ };
+
+ var samplesNonDefault = new List<TextData>()
+ {
+ new TextData(){ Text = "div;12:strong;9:span;13:br;1:a;2" },
+ new TextData(){ Text = "p;5:strong;1:code;2:br;2:a;2:img;1:span;6:script;1" },
+ new TextData(){ Text = "p;5" },
+ new TextData(){ Text = "p" },
+ new TextData(){ Text = "-1" },
+ };
+
+ var dataviewDefault = mlContext.Data.LoadFromEnumerable(samplesDefault);
+ var dataviewNonDefault = mlContext.Data.LoadFromEnumerable(samplesNonDefault);
+ var textPipelineDefault = mlContext.Transforms.Text.ProduceWordBags("Text", termSeparator: ';', freqSeparator: ':');
+ var textPipelineNonDefault = mlContext.Transforms.Text.ProduceWordBags("Text", termSeparator: ':', freqSeparator: ';');
+
+
+ var textTransformerDefault = textPipelineDefault.Fit(dataviewDefault);
+ var textTransformerNonDefault = textPipelineNonDefault.Fit(dataviewNonDefault);
+ var predDefault = textTransformerDefault.Preview(dataviewDefault);
+ var predNonDefault = textTransformerNonDefault.Preview(dataviewNonDefault);
+
+ Assert.Equal(((VBuffer<float>)predDefault.ColumnView[4].Values[0]).DenseValues().ToArray(), ((VBuffer<float>)predNonDefault.ColumnView[4].Values[0]).DenseValues().ToArray());
+
+ Done();
+ }
+
+ private class TextData
+ {
+ public string Text { get; set; }
+ }
+
+ private class TransformedTextData
+ {
+ public float[] Text { get; set; }
+ }
+ }
+
+}
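A follow-up sketch (not part of the test file) that could be appended at the end of WordBagsPreDefined to print each slot name next to its count and make the expected vector { 12, 9, 13, 1, 2, ... } concrete; it assumes `using System;`, the GetSlotNames extension from Microsoft.ML.Data, and the textTransformer/dataview/mlContext variables already defined in that test:

    var transformed = textTransformer.Transform(dataview);

    // Slot names of the bag column identify which term each count belongs to.
    VBuffer<ReadOnlyMemory<char>> slotNames = default;
    transformed.Schema["Text"].GetSlotNames(ref slotNames);
    var names = slotNames.DenseValues().ToArray();

    var firstRow = mlContext.Data
        .CreateEnumerable<TransformedTextData>(transformed, reuseRowObject: false)
        .First();
    for (int i = 0; i < firstRow.Text.Length; i++)
        Console.WriteLine($"{names[i]}: {firstRow.Text[i]}");   // e.g. "div: 12", "strong: 9", ...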