Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ private protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol)
Host.Assert(LabelColumn.IsValid);

if (!LabelColumn.IsCompatibleWith(labelCol))
throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", WeightColumn.Name,
throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", LabelColumn.Name,
LabelColumn.GetTypeString(), labelCol.GetTypeString());
}

Expand Down
43 changes: 31 additions & 12 deletions src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,33 @@

namespace Microsoft.ML.Transforms
{
internal sealed class MissingValueDroppingEstimator : TrivialEstimator<MissingValueDroppingTransformer>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this useful in other cases? Do we need to add this to the catalog and make it public?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is - seems to be something very specific to Ngram extraction.


In reply to: 370393234 [](ancestors = 370393234)

{
public MissingValueDroppingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueDroppingEstimator)), new MissingValueDroppingTransformer(env, columns))
{
}

public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));

var resultDic = inputSchema.ToDictionary(x => x.Name);
foreach (var (outputColumnName, inputColumnName) in Transformer.Columns)
{
if (!inputSchema.TryFindColumn(inputColumnName, out var originalColumn))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColumnName);
if (originalColumn.Kind == SchemaShape.Column.VectorKind.Scalar)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", originalColumn.Name, "Vector", "Scalar");
if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(originalColumn.ItemType, out _))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", originalColumn.Name, "Single, Double or Key", originalColumn.ItemType.ToString());
var col = new SchemaShape.Column(outputColumnName, SchemaShape.Column.VectorKind.VariableVector, originalColumn.ItemType, originalColumn.IsKey, originalColumn.Annotations);
resultDic[outputColumnName] = col;
}
return new SchemaShape(resultDic.Values);
}
}

/// <include file='doc.xml' path='doc/members/member[@name="NADrop"]'/>
internal sealed class MissingValueDroppingTransformer : OneToOneTransformerBase
{
Expand Down Expand Up @@ -163,23 +190,15 @@ public Mapper(MissingValueDroppingTransformer parent, DataViewSchema inputSchema
{
inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i]);
var srcCol = inputSchema[_srcCols[i]];
if (!(srcCol.Type is VectorDataViewType))
throw _parent.Host.Except($"Column '{srcCol.Name}' is not a vector column");
if (!Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(srcCol.Type.GetItemType(), out _isNAs[i]))
throw _parent.Host.Except($"Column '{srcCol.Name}' is of type {srcCol.Type.GetItemType()}, which does not support missing values");
_srcTypes[i] = srcCol.Type;
_types[i] = new VectorDataViewType((PrimitiveDataViewType)srcCol.Type.GetItemType());
_isNAs[i] = GetIsNADelegate(srcCol.Type);
}
}

/// <summary>
/// Returns the isNA predicate for the respective type.
/// </summary>
private Delegate GetIsNADelegate(DataViewType type)
{
Func<DataViewType, Delegate> func = GetIsNADelegate<int>;
return Utils.MarshalInvoke(func, type.GetItemType().RawType, type);
}

private Delegate GetIsNADelegate<T>(DataViewType type) => Data.Conversion.Conversions.Instance.GetIsNAPredicate<T>(type.GetItemType());

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var result = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length];
Expand Down
65 changes: 22 additions & 43 deletions src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,11 @@ internal sealed class Options : NgramExtractorTransform.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. It does so by building "
+ "a dictionary of n-grams and using the id in the dictionary as the index in the bag.";

internal static ITransformer CreateTransfomer(IHostEnvironment env, Options options, IDataView input)
internal static IEstimator<ITransformer> CreateEstimator(IHostEnvironment env, Options options, SchemaShape inputSchema)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
h.CheckValue(options, nameof(options));
h.CheckValue(input, nameof(input));
h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");

// Compose the WordBagTransform from a tokenize transform,
Expand Down Expand Up @@ -149,17 +148,14 @@ internal static ITransformer CreateTransfomer(IHostEnvironment env, Options opti
};
}

IDataView view = input;
ITransformer t0 = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns);
view = t0.Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view);
view = t1.Transform(view);
ITransformer t2 = NgramExtractorTransform.Create(h, extractorArgs, view);
return new TransformerChain<ITransformer>(new[] { t0, t1, t2 });
IEstimator<ITransformer> estimator = NgramExtractionUtils.GetConcatEstimator(h, options.Columns);
estimator = estimator.Append(new WordTokenizingEstimator(env, tokenizeColumns));
estimator = estimator.Append(NgramExtractorTransform.CreateEstimator(h, extractorArgs, estimator.GetOutputSchema(inputSchema)));
return estimator;
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransfomer(env, options, input).Transform(input);
(IDataTransform)CreateEstimator(env, options, SchemaShape.Create(input.Schema)).Fit(input).Transform(input);
}

/// <summary>
Expand Down Expand Up @@ -262,15 +258,14 @@ internal sealed class Options : ArgumentsBase

internal const string LoaderSignature = "NgramExtractor";

internal static ITransformer Create(IHostEnvironment env, Options options, IDataView input, TermLoaderArguments termLoaderArgs = null)
internal static IEstimator<ITransformer> CreateEstimator(IHostEnvironment env, Options options, SchemaShape inputSchema, TermLoaderArguments termLoaderArgs = null)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(LoaderSignature);
h.CheckValue(options, nameof(options));
h.CheckValue(input, nameof(input));
h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");

var chain = new TransformerChain<ITransformer>();
var chain = new EstimatorChain<ITransformer>();

var termCols = new List<Column>();
var isTermCol = new bool[options.Columns.Length];
Expand All @@ -281,9 +276,8 @@ internal static ITransformer Create(IHostEnvironment env, Options options, IData

h.CheckNonWhiteSpace(col.Name, nameof(col.Name));
h.CheckNonWhiteSpace(col.Source, nameof(col.Source));
int colId;
if (input.Schema.TryGetColumnIndex(col.Source, out colId) &&
input.Schema[colId].Type.GetItemType() is TextDataViewType)
if (inputSchema.TryFindColumn(col.Source, out var colShape) &&
colShape.ItemType is TextDataViewType)
{
termCols.Add(col);
isTermCol[i] = true;
Expand Down Expand Up @@ -327,9 +321,9 @@ internal static ITransformer Create(IHostEnvironment env, Options options, IData
using (var ch = env.Start("Create key data view"))
keyData = ValueToKeyMappingTransformer.GetKeyDataViewOrNull(env, ch, termLoaderArgs.DataFile, termLoaderArgs.TermsColumn, termLoaderArgs.Loader, out var autoConvert);
}
chain = chain.Append<ITransformer>(new ValueToKeyMappingEstimator(h, columnOptions.ToArray(), keyData).Fit(input));
chain = chain.Append<ITransformer>(new ValueToKeyMappingEstimator(h, columnOptions.ToArray(), keyData));
if (missingDropColumns != null)
chain = chain.Append<ITransformer>(new MissingValueDroppingTransformer(h, missingDropColumns.Select(x => (x, x)).ToArray()));
chain = chain.Append<ITransformer>(new MissingValueDroppingEstimator(h, missingDropColumns.Select(x => (x, x)).ToArray()));
}

var ngramColumns = new NgramExtractingEstimator.ColumnOptions[options.Columns.Length];
Expand All @@ -345,30 +339,19 @@ internal static ITransformer Create(IHostEnvironment env, Options options, IData
isTermCol[iinfo] ? column.Name : column.Source
);
}

input = chain.Transform(input);
return chain.Append<ITransformer>(new NgramExtractingEstimator(env, ngramColumns).Fit(input));
return chain.Append<ITransformer>(new NgramExtractingEstimator(env, ngramColumns));
}

internal static IDataTransform CreateDataTransform(IHostEnvironment env, Options options, IDataView input,
TermLoaderArguments termLoaderArgs = null)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(input, nameof(input));
return Create(env, options, input, termLoaderArgs).Transform(input) as IDataTransform;
return CreateEstimator(env, options, SchemaShape.Create(input.Schema), termLoaderArgs).Fit(input).Transform(input)/* Create(env, options, input, termLoaderArgs).Transform(input) */as IDataTransform;
}

internal static ITransformer Create(IHostEnvironment env, NgramExtractorArguments extractorArgs, IDataView input,
ExtractorColumn[] cols, TermLoaderArguments termLoaderArgs = null)
internal static Options CreateNgramExtractorOptions(NgramExtractorArguments extractorArgs, ExtractorColumn[] cols)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(LoaderSignature);
h.CheckValue(extractorArgs, nameof(extractorArgs));
h.CheckValue(input, nameof(input));
h.CheckUserArg(extractorArgs.SkipLength < extractorArgs.NgramLength, nameof(extractorArgs.SkipLength), "Should be less than " + nameof(extractorArgs.NgramLength));
h.CheckUserArg(Utils.Size(cols) > 0, nameof(Options.Columns), "Must be specified");
h.CheckValueOrNull(termLoaderArgs);

var extractorCols = new Column[cols.Length];
for (int i = 0; i < cols.Length; i++)
{
Expand All @@ -385,8 +368,7 @@ internal static ITransformer Create(IHostEnvironment env, NgramExtractorArgument
MaxNumTerms = extractorArgs.MaxNumTerms,
Weighting = extractorArgs.Weighting
};

return Create(h, options, input, termLoaderArgs);
return options;
}

internal static INgramExtractorFactory Create(IHostEnvironment env, NgramExtractorArguments extractorArgs,
Expand Down Expand Up @@ -468,7 +450,8 @@ public NgramExtractorFactory(NgramExtractorTransform.NgramExtractorArguments ext

public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColumn[] cols)
{
return NgramExtractorTransform.Create(env, _extractorArgs, input, cols, _termLoaderArgs);
var options = NgramExtractorTransform.CreateNgramExtractorOptions(_extractorArgs, cols);
return NgramExtractorTransform.CreateEstimator(env, options, SchemaShape.Create(input.Schema), _termLoaderArgs).Fit(input);
}
}

Expand All @@ -495,26 +478,22 @@ public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColum

internal static class NgramExtractionUtils
{
public static ITransformer ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns)
public static IEstimator<ITransformer> GetConcatEstimator(IHostEnvironment env, ManyToOneColumn[] columns)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(columns, nameof(columns));

var concatColumns = new List<ColumnConcatenatingTransformer.ColumnOptions>();
var estimator = new EstimatorChain<ITransformer>();
foreach (var col in columns)
{
env.CheckUserArg(col != null, nameof(WordBagBuildingTransformer.Options.Columns));
env.CheckUserArg(!string.IsNullOrWhiteSpace(col.Name), nameof(col.Name));
env.CheckUserArg(Utils.Size(col.Source) > 0, nameof(col.Source));
env.CheckUserArg(col.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(col.Source));
if (col.Source.Length > 1)
concatColumns.Add(new ColumnConcatenatingTransformer.ColumnOptions(col.Name, col.Source));
estimator = estimator.Append<ITransformer>(new ColumnConcatenatingEstimator(env, col.Name, col.Source));
}

if (concatColumns.Count > 0)
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray());

return new TransformerChain<ITransformer>();
return estimator;
}

/// <summary>
Expand Down
16 changes: 9 additions & 7 deletions src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,13 @@ internal WordBagEstimator(IHostEnvironment env,
/// <summary> Trains and returns a <see cref="ITransformer"/>.</summary>
public ITransformer Fit(IDataView input)
{
// Create arguments.
var options = new WordBagBuildingTransformer.Options
var estimator = WordBagBuildingTransformer.CreateEstimator(_host, CreateOptions(), SchemaShape.Create(input.Schema));
return estimator.Fit(input);
}

private WordBagBuildingTransformer.Options CreateOptions()
{
return new WordBagBuildingTransformer.Options
{
Columns = _columns.Select(x => new WordBagBuildingTransformer.Column { Name = x.outputColumnName, Source = x.sourceColumnsNames }).ToArray(),
NgramLength = _ngramLength,
Expand All @@ -184,8 +189,6 @@ public ITransformer Fit(IDataView input)
MaxNumTerms = new[] { _maxNumTerms },
Weighting = _weighting
};

return WordBagBuildingTransformer.CreateTransfomer(_host, options, input);
}

/// <summary>
Expand All @@ -196,9 +199,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));

var fakeSchema = FakeSchemaFactory.Create(inputSchema);
var transformer = Fit(new EmptyDataView(_host, fakeSchema));
return SchemaShape.Create(transformer.GetOutputSchema(fakeSchema));
var estimator = WordBagBuildingTransformer.CreateEstimator(_host, CreateOptions(), inputSchema);
return estimator.GetOutputSchema(inputSchema);
}
}

Expand Down
29 changes: 22 additions & 7 deletions test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,7 @@ public void WordBagWorkout()
var est = new WordBagEstimator(ML, "bag_of_words", "text").
Append(new WordHashBagEstimator(ML, "bag_of_wordshash", "text", maximumNumberOfInverts: -1));

// The following call fails because of the following issue
// https://github.com/dotnet/machinelearning/issues/969
// TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic);
TestEstimatorCore(est, data, invalidInput: invalidData);

var outputPath = GetOutputPath("Text", "bag_of_words.tsv");
var savedData = ML.Data.TakeRows(est.Fit(data).Transform(data), 4);
Expand Down Expand Up @@ -686,10 +684,11 @@ public void LdaWorkout()
Append(new LatentDirichletAllocationEstimator(env, "topics", "bag_of_words", 10, maximumNumberOfIterations: 10,
resetRandomGenerator: true));

// The following call fails because of the following issue
// https://github.com/dotnet/machinelearning/issues/969
// In this test it manifests because of the WordBagEstimator in the estimator chain
// TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic);
// Diabling this check due to the following issue with consitency of output.
// `seed` specified in ConsoleEnvironment has no effect.
// https://github.com/dotnet/machinelearning/issues/1004
// On single box, setting `s.ResetRandomGenerator = true` works but fails on build server
// TestEstimatorCore(est, data, invalidInput: invalidData);

var outputPath = GetOutputPath("Text", "ldatopics.tsv");
using (var ch = env.Start("save"))
Expand Down Expand Up @@ -764,5 +763,21 @@ public void TestTextFeaturizerBackCompat()
Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
}
}

[Fact]
public void TestWordBagInPipeline()
{
string dataPath = GetDataPath("breast-cancer.txt");
var dataView = ML.Data.LoadFromTextFile(dataPath, new[] {
new TextLoader.Column("Label", DataKind.Boolean, 0),
new TextLoader.Column("Features", DataKind.String, 1, 9)
});

var pipeline = ML.Transforms.Text.ProduceWordBags("Features")
.Append(ML.BinaryClassification.Trainers.FastTree());

TestEstimatorCore(pipeline, dataView);
Done();
}
}
}