Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/Microsoft.ML.Transforms/Text/TextCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,31 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
=> new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
outputColumnName, inputColumnName, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting);

/// <summary>
/// Create a <see cref="WordBagEstimator"/>, which maps the column specified in <paramref name="inputColumnName"/>
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
/// </summary>
/// <remarks>
/// <see cref="WordBagEstimator"/> is different from <see cref="NgramExtractingEstimator"/> in that the former
/// tokenizes text internally and the latter takes tokenized text as input.
/// </remarks>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.
/// This column's data type will be known-size vector of <see cref="System.Single"/>.</param>
/// <param name="inputColumnName">Name of the column to take the data from.
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
/// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
/// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
/// This estimator operates over vector of text.</param>
public static WordBagEstimator ProduceWordBagsPreDefinedWeight(this TransformsCatalog.TextTransforms catalog,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ericstj @luisquintanilla @tannergooding thoughts on this api name? The normal one is just called ProduceWordBags, but when naming this one that it can potentially cause ambiguity issues due to the default parameters not needing to be specified.

I had 2 options, either name it something different, or not have default parameters and make the users specify the term separator and frequency separator manually each time. I went with the first approach, but want to hear your opinions on it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tarekgh your thoughts as well if you have time.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thoughts:

In main scenarios is it expected to be called once or twice in the code? If yes, then I prefer using ProduceWordBags without any optional parameter. Will be just overload. If the answer is no, then using a different name would be better, I guess. We may try to suggest a better name than ProduceWordBagsPreDefinedWeight. May be ``ProduceWordBagsEstimator` is simpler?

Copy link
Member

@bartonjs bartonjs Nov 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the current name, I think a "With" should be in there (ProduceWordBagsWithPreDefinedWeight). Though since the caller doesn't get to specify it "PreDefined" seems more like "Default".

I assume the real reason a caller wants this method is to specify the term and/or freq separators. Can this not just be added as a true extra optional parameter overload using the established patterns?

+       [EditorBrowsable(EditorBrowsableState.Never)]
        public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransforms catalog,
            string outputColumnName,
            string inputColumnName = null,
-           int ngramLength = NgramExtractingEstimator.Defaults.NgramLength,
+           int ngramLength,
-           int skipLength = NgramExtractingEstimator.Defaults.SkipLength,
+           int skipLength,
-           bool useAllLengths = NgramExtractingEstimator.Defaults.UseAllLengths,
+           bool useAllLengths,
-           int maximumNgramsCount = NgramExtractingEstimator.Defaults.MaximumNgramsCount,
+           int maximumNgramsCount,
-           NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
+           NgramExtractingEstimator.WeightingCriteria weighting)
            => new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
                outputColumnName, inputColumnName, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting);

+       public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransforms catalog,
+           string outputColumnName,
+           string inputColumnName = null,
+           int ngramLength = NgramExtractingEstimator.Defaults.NgramLength,
+           int skipLength = NgramExtractingEstimator.Defaults.SkipLength,
+           bool useAllLengths = NgramExtractingEstimator.Defaults.UseAllLengths,
+           int maximumNgramsCount = NgramExtractingEstimator.Defaults.MaximumNgramsCount,
+           NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
+           char termSeparator = ';',
+           char freqSeparator = ':')
+           => new WordBagEstimator(parameters go here);

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which overload will be chosen if the caller is passing only the first parameter?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which overload will be chosen if the caller is passing only the first parameter?

If the compiled in the past, the old one. If they compile after, the new one. (The old method, now marked as [EB(Never)] is only called by a) legacy callers or b) someone who happened to specify all of those parameters and none of the new ones... the compiler will still (and only) call it if the callsite has an exact match to the signature.

Since both methods are creating a WordBagEstimator I'm assuming that "specifying the weighting" and "specifying the freq/term-separator" aren't actually conflicting options. If they are conflicting options, then some sort of differentiated name is warranted.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would changing the existing method though by adding new default parameters count as a breaking change @ericstj?

Adding the parameters to the existing method is a breaking change (a really, really bad one). Adding it the way I suggested is the way to make it look like all you did was add new default parameters.

From Framework Design Guidelines, 3rd edition, sec 5.1 (General Member Design Guidelines):

DO move all default parameters to the new, longer overload when adding optional parameters to an existing method.

[prose that says to do what I said to do above]

and down in Appendix D (breaking changes) (emphasis mine):

D.11.2 Adding or Removing a Method Parameter

Methods in the .NET CLR are identified by their signature, which is composed
of the name, return type, and ordered list of parameters. Removing
a parameter, adding a required parameter, or adding an optional parameter
all count as changing the signature and, therefore, are logically the
same as deleting the original method. See section D.9.3 for information on
the runtime impact of removing a member.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bartonjs the separators and the weight aren't conflicting options. The default should be just to use the frequency of the words, but there isn't a problem if they are both specified together (though I would bet that most people don't want to specify both together and probably would give them unexpected, though not incorrect, behavior if they do that).

I think making the name the same with a new overload where the user has to specify the separators is probably the best based on this convo. We don't want the separators specified unless the user really needs them, since thats the flag that tells word bag that it needs to handle the input data differently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bartonjs I have made the changes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That new catalog overload looks a lot nicer to me 😄

string outputColumnName,
string inputColumnName = null,
int maximumNgramsCount = NgramExtractingEstimator.Defaults.MaximumNgramsCount,
char termSeparator = ';',
char freqSeparator = ':')
=> new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
outputColumnName, inputColumnName, 1, 0, true, maximumNgramsCount, NgramExtractingEstimator.WeightingCriteria.Tf, termSeparator: termSeparator, freqSeparator: freqSeparator);

/// <summary>
/// Create a <see cref="WordBagEstimator"/>, which maps the multiple columns specified in <paramref name="inputColumnNames"/>
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
Expand Down
200 changes: 198 additions & 2 deletions src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
Expand All @@ -12,6 +13,7 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;
using static Microsoft.ML.Transforms.Text.WordBagBuildingTransformer;

[assembly: LoadableClass(WordBagBuildingTransformer.Summary, typeof(IDataTransform), typeof(WordBagBuildingTransformer), typeof(WordBagBuildingTransformer.Options), typeof(SignatureDataTransform),
"Word Bag Transform", "WordBagTransform", "WordBag")]
Expand All @@ -21,6 +23,16 @@

[assembly: EntryPointModule(typeof(NgramExtractorTransform.NgramExtractorArguments))]

// These are for the internal only TextExpandingTransformer. Not exposed publically
[assembly: LoadableClass(TextExpandingTransformer.Summary, typeof(IDataTransform), typeof(TextExpandingTransformer), null, typeof(SignatureLoadDataTransform),
TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]

[assembly: LoadableClass(typeof(TextExpandingTransformer), null, typeof(SignatureLoadModel),
TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]

[assembly: LoadableClass(typeof(IRowMapper), typeof(TextExpandingTransformer), null, typeof(SignatureLoadRowMapper),
TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]

namespace Microsoft.ML.Transforms.Text
{
/// <summary>
Expand Down Expand Up @@ -144,18 +156,195 @@ internal static IEstimator<ITransformer> CreateEstimator(IHostEnvironment env, O
NgramLength = column.NgramLength,
SkipLength = column.SkipLength,
Weighting = column.Weighting,
UseAllLengths = column.UseAllLengths
UseAllLengths = column.UseAllLengths,
};
}

IEstimator<ITransformer> estimator = NgramExtractionUtils.GetConcatEstimator(h, options.Columns);
estimator = estimator.Append(new WordTokenizingEstimator(env, tokenizeColumns));
if (options.FreqSeparator != default)
{
estimator = estimator.Append(new TextExpandingEstimator(h, tokenizeColumns[0].InputColumnName, options.FreqSeparator, options.TermSeparator));
}
estimator = estimator.Append(new WordTokenizingEstimator(h, tokenizeColumns));
estimator = estimator.Append(NgramExtractorTransform.CreateEstimator(h, extractorArgs, estimator.GetOutputSchema(inputSchema)));
return estimator;
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateEstimator(env, options, SchemaShape.Create(input.Schema)).Fit(input).Transform(input);

#region TextExpander

// Internal only estimator used to facilitate the expansion of ngrams with pre-defined weights
internal sealed class TextExpandingEstimator : TrivialEstimator<TextExpandingTransformer>
{
private readonly string _columnName;
public TextExpandingEstimator(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingEstimator)), new TextExpandingTransformer(env, columnName, freqSeparator, termSeparator))
{
_columnName = columnName;
}

public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
if (!inputSchema.TryFindColumn(_columnName, out SchemaShape.Column outCol) && outCol.ItemType != TextDataViewType.Instance)
{
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columnName);
}

return inputSchema;
}
}

// Internal only transformer used to facilitate the expansion of ngrams with pre-defined weights
internal sealed class TextExpandingTransformer : RowToRowTransformerBase
{
internal const string Summary = "Expands text in the format of term:freq; to have the correct number of terms";
internal const string UserName = "Text Expanding Transform";
internal const string LoadName = "TextExpand";

internal const string LoaderSignature = "TextExpandTransform";

private readonly string _columnName;
private readonly char _freqSeparator;
private readonly char _termSeparator;

public TextExpandingTransformer(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
{
_columnName = columnName;
_freqSeparator = freqSeparator;
_termSeparator = termSeparator;
}

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "TEXT EXP",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(TextExpandingTransformer).Assembly.FullName);
}

/// <summary>
/// Factory method for SignatureLoadModel.
/// </summary>
private TextExpandingTransformer(IHostEnvironment env, ModelLoadContext ctx) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnConcatenatingTransformer)))
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
// *** Binary format ***
// string: column n ame
// char: frequency separator
// char: term separator

_columnName = ctx.Reader.ReadString();
_freqSeparator = ctx.Reader.ReadChar();
_termSeparator = ctx.Reader.ReadChar();
}

/// <summary>
/// Factory method for SignatureLoadRowMapper.
/// </summary>
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
=> new TextExpandingTransformer(env, ctx).MakeRowMapper(inputSchema);

/// <summary>
/// Factory method for SignatureLoadDataTransform.
/// </summary>
private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
=> new TextExpandingTransformer(env, ctx).MakeDataTransform(input);

private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
{
return new Mapper(Host, schema, this);
}

private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

// *** Binary format ***
// string: column n ame
// char: frequency separator
// char: term separator

ctx.Writer.Write(_columnName);
ctx.Writer.Write(_freqSeparator);
ctx.Writer.Write(_termSeparator);
}

private sealed class Mapper : MapperBase
{
private readonly TextExpandingTransformer _parent;
public Mapper(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent)
: base(host, inputSchema, parent)
{
_parent = (TextExpandingTransformer)parent;
}

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
return new DataViewSchema.DetachedColumn[]
{
new DataViewSchema.DetachedColumn(_parent._columnName, TextDataViewType.Instance)
};
}

protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
{
disposer = null;
ValueGetter<ReadOnlyMemory<char>> srcGetter = input.GetGetter<ReadOnlyMemory<char>>(input.Schema.GetColumnOrNull(_parent._columnName).Value);
ReadOnlyMemory<char> inputMem = default;
var sb = new StringBuilder();

ValueGetter<ReadOnlyMemory<char>> result = (ref ReadOnlyMemory<char> dst) =>
{
sb.Clear();
srcGetter(ref inputMem);
var inputText = inputMem.ToString();
foreach (var termFreq in inputText.Split(_parent._termSeparator))
{
var tf = termFreq.Split(_parent._freqSeparator);
if (tf.Length != 2)
sb.Append(tf[0] + " ");
else
{
for (int i = 0; i < int.Parse(tf[1]); i++)
sb.Append(tf[0] + " ");
}
}

dst = sb.ToString().AsMemory();
};

return result;
}

private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
{
var active = new bool[InputSchema.Count];
if (activeOutput(0))
{
active[InputSchema.GetColumnOrNull(_parent._columnName).Value.Index] = true;
}
return col => active[col];
}

private protected override void SaveModel(ModelSaveContext ctx)
{
_parent.SaveModel(ctx);
}
}
}

#endregion TextExpander
}

/// <summary>
Expand Down Expand Up @@ -235,6 +424,13 @@ internal abstract class ArgumentsBase

[Argument(ArgumentType.AtMostOnce, HelpText = "The weighting criteria")]
public NgramExtractingEstimator.WeightingCriteria Weighting = NgramExtractingEstimator.Defaults.Weighting;

[Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms/frequency pairs.")]
public char TermSeparator = default;

[Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms from their frequency.")]
public char FreqSeparator = default;

}

[TlcModule.Component(Name = "NGram", FriendlyName = "NGram Extractor Transform", Alias = "NGramExtractorTransform,NGramExtractor",
Expand Down
30 changes: 24 additions & 6 deletions src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public sealed class WordBagEstimator : IEstimator<ITransformer>
private readonly bool _useAllLengths;
private readonly int _maxNumTerms;
private readonly NgramExtractingEstimator.WeightingCriteria _weighting;
private readonly char _termSeparator;
private readonly char _freqSeparator;

/// <summary>
/// Options for how the n-grams are extracted.
Expand Down Expand Up @@ -99,15 +101,19 @@ public Options()
/// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
/// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
/// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
internal WordBagEstimator(IHostEnvironment env,
string outputColumnName,
string inputColumnName = null,
int ngramLength = 1,
int skipLength = 0,
bool useAllLengths = true,
int maximumNgramsCount = 10000000,
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
: this(env, outputColumnName, new[] { inputColumnName ?? outputColumnName }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting)
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
char termSeparator = default,
char freqSeparator = default)
: this(env, outputColumnName, new[] { inputColumnName ?? outputColumnName }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting, termSeparator, freqSeparator)
{
}

Expand All @@ -123,15 +129,19 @@ internal WordBagEstimator(IHostEnvironment env,
/// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
/// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
/// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
internal WordBagEstimator(IHostEnvironment env,
string outputColumnName,
string[] inputColumnNames,
int ngramLength = 1,
int skipLength = 0,
bool useAllLengths = true,
int maximumNgramsCount = 10000000,
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
: this(env, new[] { (outputColumnName, inputColumnNames) }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting)
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
char termSeparator = default,
char freqSeparator = default)
: this(env, new[] { (outputColumnName, inputColumnNames) }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting, termSeparator, freqSeparator)
{
}

Expand All @@ -146,13 +156,17 @@ internal WordBagEstimator(IHostEnvironment env,
/// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
/// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
/// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
internal WordBagEstimator(IHostEnvironment env,
(string outputColumnName, string[] inputColumnNames)[] columns,
int ngramLength = 1,
int skipLength = 0,
bool useAllLengths = true,
int maximumNgramsCount = 10000000,
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
char termSeparator = default,
char freqSeparator = default)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(WordBagEstimator));
Expand All @@ -169,6 +183,8 @@ internal WordBagEstimator(IHostEnvironment env,
_useAllLengths = useAllLengths;
_maxNumTerms = maximumNgramsCount;
_weighting = weighting;
_termSeparator = termSeparator;
_freqSeparator = freqSeparator;
}

/// <summary> Trains and returns a <see cref="ITransformer"/>.</summary>
Expand All @@ -187,7 +203,9 @@ private WordBagBuildingTransformer.Options CreateOptions()
SkipLength = _skipLength,
UseAllLengths = _useAllLengths,
MaxNumTerms = new[] { _maxNumTerms },
Weighting = _weighting
Weighting = _weighting,
TermSeparator = _termSeparator,
FreqSeparator = _freqSeparator,
};
}

Expand Down
Loading