Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 9 additions & 80 deletions src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Data.DataLoadSave;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Runtime;

[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel),
Comment thread
codemzs marked this conversation as resolved.
"Transform wrapper", TransformWrapper.LoaderSignature)]

namespace Microsoft.ML.Data
{
/// <summary>
Expand All @@ -23,20 +16,19 @@ namespace Microsoft.ML.Data
internal sealed class TransformWrapper : ITransformer
{
internal const string LoaderSignature = "TransformWrapper";
private const string TransformDirTemplate = "Step_{0:000}";

private readonly IHost _host;
private readonly IDataView _xf;
private readonly bool _allowSave;
private readonly bool _isRowToRowMapper;

public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false)
public TransformWrapper(IHostEnvironment env, IDataView xf)
{
Contracts.CheckValue(env, nameof(env));
Contracts.Check(xf is IDataTransform);

_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(xf, nameof(xf));
_xf = xf;
_allowSave = allowSave;
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
Comment thread
codemzs marked this conversation as resolved.
Outdated
}

Expand All @@ -45,39 +37,12 @@ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
_host.CheckValue(inputSchema, nameof(inputSchema));

var dv = new EmptyDataView(_host, inputSchema);
var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
var output = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv);
Comment thread
codemzs marked this conversation as resolved.

return output.Schema;
}

void ICanSaveModel.Save(ModelSaveContext ctx)
{
if (!_allowSave)
throw _host.Except("Saving is not permitted.");
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

var dataPipe = _xf;
var transforms = new List<IDataTransform>();
while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
}
transforms.Reverse();

ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema));

ctx.Writer.Write(transforms.Count);
for (int i = 0; i < transforms.Count; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.SaveModel(transforms[i], dirName);
}
}
void ICanSaveModel.Save(ModelSaveContext ctx) => throw _host.Except("Saving is not permitted.");
Comment thread
codemzs marked this conversation as resolved.

private static VersionInfo GetVersionInfo()
Comment thread
codemzs marked this conversation as resolved.
Outdated
{
Expand All @@ -90,32 +55,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(TransformWrapper).Assembly.FullName);
}

// Factory for SignatureLoadModel.
private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(ctx, nameof(ctx));
_allowSave = true;
ctx.CheckAtModel(GetVersionInfo());
int n = ctx.Reader.ReadInt32();
_host.CheckDecode(n >= 0);

ctx.LoadModel<ILegacyDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

IDataView data = loader;
for (int i = 0; i < n; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
data = xf;
}

_xf = data;
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
}

public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input);

private static bool IsChainRowToRowMapper(IDataView view)
{
Expand All @@ -132,19 +72,8 @@ private static bool IsChainRowToRowMapper(IDataView view)
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
var input = new EmptyDataView(_host, inputSchema);
var revMaps = new List<IRowToRowMapper>();
IDataView chain;
for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); chain is IDataTransform xf; chain = xf.Source)
{
// Everything in the chain ought to be a row mapper.
_host.Assert(xf is IRowToRowMapper);
revMaps.Add((IRowToRowMapper)xf);
}
// The walkback should have ended at the input.
Contracts.Assert(chain == input);
revMaps.Reverse();
return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray());
return new CompositeRowToRowMapper(inputSchema,
new[] { (IRowToRowMapper)ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, new EmptyDataView(_host, inputSchema)) });
Comment thread
codemzs marked this conversation as resolved.
Outdated
}
}

Expand Down
23 changes: 14 additions & 9 deletions src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ internal sealed class Options : NgramExtractorTransform.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. It does so by building "
+ "a dictionary of n-grams and using the id in the dictionary as the index in the bag.";

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
internal static ITransformer CreateTransfomer(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
Expand Down Expand Up @@ -150,10 +150,16 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
}

IDataView view = input;
view = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns, view);
view = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view).Transform(view);
return NgramExtractorTransform.CreateDataTransform(h, extractorArgs, view);
ITransformer t0 = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns);
view = t0.Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view);
view = t1.Transform(view);
ITransformer t2 = NgramExtractorTransform.Create(h, extractorArgs, t1.Transform(view));
Comment thread
codemzs marked this conversation as resolved.
Outdated
return new TransformerChain<ITransformer>(new[] { t0, t1, t2 });
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransfomer(env, options, input).Transform(input);
}

/// <summary>
Expand Down Expand Up @@ -489,13 +495,11 @@ public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColum

internal static class NgramExtractionUtils
{
public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns, IDataView input)
public static ITransformer ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(columns, nameof(columns));
env.CheckValue(input, nameof(input));

IDataView view = input;
var concatColumns = new List<ColumnConcatenatingTransformer.ColumnOptions>();
foreach (var col in columns)
{
Expand All @@ -506,10 +510,11 @@ public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColu
if (col.Source.Length > 1)
concatColumns.Add(new ColumnConcatenatingTransformer.ColumnOptions(col.Name, col.Source));
}

if (concatColumns.Count > 0)
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray()).Transform(view);
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray());

return view;
return new TransformerChain<ITransformer>();
Comment thread
codemzs marked this conversation as resolved.
}

/// <summary>
Expand Down
14 changes: 10 additions & 4 deletions src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ internal sealed class Options : NgramHashExtractingTransformer.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. "
+ "It does so by hashing each n-gram and using the hash value as the index in the bag.";

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
internal static ITransformer CreateTransformer(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
Expand Down Expand Up @@ -132,7 +132,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
};
}

view = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view).Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view);

var featurizeArgs =
new NgramHashExtractingTransformer.Options
Expand All @@ -147,11 +147,17 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
MaximumNumberOfInverts = options.MaximumNumberOfInverts
};

view = NgramHashExtractingTransformer.Create(h, featurizeArgs, view).Transform(view);
view = t1.Transform(view);
ITransformer t2 = NgramHashExtractingTransformer.Create(h, featurizeArgs, t1.Transform(view));
Comment thread
codemzs marked this conversation as resolved.
Outdated

// Since we added columns with new names, we need to explicitly drop them before we return the IDataTransform.
return ColumnSelectingTransformer.CreateDrop(h, view, tmpColNames.ToArray()) as IDataTransform;
ITransformer t3 = new ColumnSelectingTransformer(env, null, tmpColNames.ToArray());

return new TransformerChain<ITransformer>(new[] { t1, t2, t3 });
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransformer(env, options, input).Transform(input);
}

/// <summary>
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ public ITransformer Fit(IDataView input)
Weighting = _weighting
};

return new TransformWrapper(_host, WordBagBuildingTransformer.Create(_host, options, input), true);
return WordBagBuildingTransformer.CreateTransfomer(_host, options, input);
}

/// <summary>
Expand Down Expand Up @@ -365,7 +365,7 @@ public ITransformer Fit(IDataView input)
MaximumNumberOfInverts = _maximumNumberOfInverts
};

return new TransformWrapper(_host, WordHashBagProducingTransformer.Create(_host, options, input), true);
return WordHashBagProducingTransformer.CreateTransformer(_host, options, input);
}

/// <summary>
Expand Down
6 changes: 5 additions & 1 deletion test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,11 @@ public void EntryPointPipelineEnsembleText()
new WordHashBagProducingTransformer.Options()
{
Columns =
new[] { new WordHashBagProducingTransformer.Column() { Name = "Features", Source = new[] { "Text" } }, }
new[]
{
new WordHashBagProducingTransformer.Column()
{Name = "Features", Source = new[] {"Text"}},
}
},
data);
}
Expand Down
37 changes: 37 additions & 0 deletions test/Microsoft.ML.Functional.Tests/ModelFiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,43 @@ void AssertIsGam(ITransformer trans)
Done();
}

/// <summary>
/// Input row type used by the <c>LoadModelWithOptionalColumnTransform</c> test.
/// The arrays carry no size of their own; that test assigns each column an explicit
/// vector type through a <c>SchemaDefinition</c> created from this class.
/// </summary>
public class ModelInput
{
// SA1401 (StyleCop: fields should be private) is suppressed because the columns are exposed as public fields.
#pragma warning disable SA1401
// Text values; the test declares this column as a text vector of size 5.
public string[] CategoricalFeatures;
// Numeric values; the test declares this column as a Single vector of size 3.
public float[] NumericalFeatures;
#pragma warning restore SA1401
}

/// <summary>
/// Output row type used by the <c>LoadModelWithOptionalColumnTransform</c> test,
/// which reads the first element of <see cref="Score"/> from the prediction.
/// </summary>
public class ModelOutput
{
// SA1401 (StyleCop: fields should be private) is suppressed because the column is exposed as a public field.
#pragma warning disable SA1401
// Scores produced by the loaded model; only index 0 is asserted on by the test.
public float[] Score;
#pragma warning restore SA1401
}


/// <summary>
/// Loads the back-compat model file <c>modelwithoptionalcolumntransform.zip</c>, builds a
/// prediction engine over <see cref="ModelInput"/> / <see cref="ModelOutput"/>, and checks
/// that the first score predicted for a fixed example equals 1.
/// </summary>
[Fact]
public void LoadModelWithOptionalColumnTransform()
{
    // ModelInput's fields are plain arrays, so the vector sizes are declared explicitly here.
    // NOTE(review): presumably these sizes must match what the saved model was trained with - confirm.
    SchemaDefinition inputSchemaDefinition = SchemaDefinition.Create(typeof(ModelInput));
    inputSchemaDefinition[nameof(ModelInput.CategoricalFeatures)].ColumnType = new VectorDataViewType(TextDataViewType.Instance, 5);
    inputSchemaDefinition[nameof(ModelInput.NumericalFeatures)].ColumnType = new VectorDataViewType(NumberDataViewType.Single, 3);
    var mlContext = new MLContext();

    // Pass each path segment separately so Path.Combine inserts the platform's directory
    // separator; a single literal with '\' separators only resolves on Windows.
    string modelPath = Path.Combine(
        Directory.GetCurrentDirectory(),
        "..", "..", "..", "..", "test", "data", "backcompat", "modelwithoptionalcolumntransform.zip");

    ITransformer trainedModel;
    using (var stream = new FileStream(modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
    {
        // The schema returned alongside the model is not used by this test, so it is discarded.
        trainedModel = mlContext.Model.Load(stream, out DataViewSchema _);
    }

    var model = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel, inputSchemaDefinition: inputSchemaDefinition);
    var prediction = model.Predict(new ModelInput
    {
        CategoricalFeatures = new[] { "ABC", "ABC", "ABC", "ABC", "ABC" },
        NumericalFeatures = new float[] { 1, 1, 1 }
    });

    Assert.Equal(1, prediction.Score[0]);
}

[Fact]
public void SaveAndLoadModelWithLoader()
{
Expand Down
Loading