Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 13 additions & 94 deletions src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Data.DataLoadSave;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Runtime;

[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel),
Comment thread
codemzs marked this conversation as resolved.
"Transform wrapper", TransformWrapper.LoaderSignature)]

namespace Microsoft.ML.Data
{
/// <summary>
Expand All @@ -23,99 +16,34 @@ namespace Microsoft.ML.Data
internal sealed class TransformWrapper : ITransformer
{
internal const string LoaderSignature = "TransformWrapper";
private const string TransformDirTemplate = "Step_{0:000}";

private readonly IHost _host;
private readonly IDataView _xf;
private readonly bool _allowSave;
private readonly bool _isRowToRowMapper;

public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false)
public TransformWrapper(IHostEnvironment env, IDataView xf)
{
Contracts.CheckValue(env, nameof(env));
Contracts.Check(xf is IDataTransform);

_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(xf, nameof(xf));
_host.Check(IsChainRowToRowMapper(_xf));
Comment thread
codemzs marked this conversation as resolved.
Outdated
_xf = xf;
_allowSave = allowSave;
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
}

public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));

var dv = new EmptyDataView(_host, inputSchema);
var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
return output.Schema;
}

void ICanSaveModel.Save(ModelSaveContext ctx)
{
if (!_allowSave)
throw _host.Except("Saving is not permitted.");
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

var dataPipe = _xf;
var transforms = new List<IDataTransform>();
while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
}
transforms.Reverse();

ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema));
var output = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv);
Comment thread
codemzs marked this conversation as resolved.

ctx.Writer.Write(transforms.Count);
for (int i = 0; i < transforms.Count; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.SaveModel(transforms[i], dirName);
}
}

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "XF WRPR",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(TransformWrapper).Assembly.FullName);
return output.Schema;
}

// Factory for SignatureLoadModel.
private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(ctx, nameof(ctx));
_allowSave = true;
ctx.CheckAtModel(GetVersionInfo());
int n = ctx.Reader.ReadInt32();
_host.CheckDecode(n >= 0);
void ICanSaveModel.Save(ModelSaveContext ctx) => throw _host.Except("Saving is not permitted.");
Comment thread
codemzs marked this conversation as resolved.

ctx.LoadModel<ILegacyDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

IDataView data = loader;
for (int i = 0; i < n; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
data = xf;
}

_xf = data;
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
}

public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input);

private static bool IsChainRowToRowMapper(IDataView view)
{
Expand All @@ -127,24 +55,15 @@ private static bool IsChainRowToRowMapper(IDataView view)
return true;
}

bool ITransformer.IsRowToRowMapper => _isRowToRowMapper;
bool ITransformer.IsRowToRowMapper => true;

IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
var input = new EmptyDataView(_host, inputSchema);
var revMaps = new List<IRowToRowMapper>();
IDataView chain;
for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); chain is IDataTransform xf; chain = xf.Source)
{
// Everything in the chain ought to be a row mapper.
_host.Assert(xf is IRowToRowMapper);
revMaps.Add((IRowToRowMapper)xf);
}
// The walkback should have ended at the input.
Contracts.Assert(chain == input);
revMaps.Reverse();
return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray());
var transform = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, new EmptyDataView(_host, inputSchema)) as IRowToRowMapper;
_host.Check(transform is IRowToRowMapper);

return new CompositeRowToRowMapper(inputSchema, new[] { transform});
Comment thread
codemzs marked this conversation as resolved.
Outdated
}
}

Expand Down
23 changes: 14 additions & 9 deletions src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ internal sealed class Options : NgramExtractorTransform.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. It does so by building "
+ "a dictionary of n-grams and using the id in the dictionary as the index in the bag.";

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
internal static ITransformer CreateTransfomer(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
Expand Down Expand Up @@ -150,10 +150,16 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
}

IDataView view = input;
view = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns, view);
view = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view).Transform(view);
return NgramExtractorTransform.CreateDataTransform(h, extractorArgs, view);
ITransformer t0 = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns);
view = t0.Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view);
view = t1.Transform(view);
ITransformer t2 = NgramExtractorTransform.Create(h, extractorArgs, view);
return new TransformerChain<ITransformer>(new[] { t0, t1, t2 });
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransfomer(env, options, input).Transform(input);
}

/// <summary>
Expand Down Expand Up @@ -489,13 +495,11 @@ public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColum

internal static class NgramExtractionUtils
{
public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns, IDataView input)
public static ITransformer ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(columns, nameof(columns));
env.CheckValue(input, nameof(input));

IDataView view = input;
var concatColumns = new List<ColumnConcatenatingTransformer.ColumnOptions>();
foreach (var col in columns)
{
Expand All @@ -506,10 +510,11 @@ public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColu
if (col.Source.Length > 1)
concatColumns.Add(new ColumnConcatenatingTransformer.ColumnOptions(col.Name, col.Source));
}

if (concatColumns.Count > 0)
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray()).Transform(view);
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray());

return view;
return new TransformerChain<ITransformer>();
Comment thread
codemzs marked this conversation as resolved.
}

/// <summary>
Expand Down
14 changes: 10 additions & 4 deletions src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ internal sealed class Options : NgramHashExtractingTransformer.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. "
+ "It does so by hashing each n-gram and using the hash value as the index in the bag.";

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
internal static ITransformer CreateTransformer(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
Expand Down Expand Up @@ -132,7 +132,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
};
}

view = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view).Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view);

var featurizeArgs =
new NgramHashExtractingTransformer.Options
Expand All @@ -147,11 +147,17 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
MaximumNumberOfInverts = options.MaximumNumberOfInverts
};

view = NgramHashExtractingTransformer.Create(h, featurizeArgs, view).Transform(view);
view = t1.Transform(view);
ITransformer t2 = NgramHashExtractingTransformer.Create(h, featurizeArgs, view);

// Since we added columns with new names, we need to explicitly drop them before we return the IDataTransform.
return ColumnSelectingTransformer.CreateDrop(h, view, tmpColNames.ToArray()) as IDataTransform;
ITransformer t3 = new ColumnSelectingTransformer(env, null, tmpColNames.ToArray());

return new TransformerChain<ITransformer>(new[] { t1, t2, t3 });
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransformer(env, options, input).Transform(input);
}

/// <summary>
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ public ITransformer Fit(IDataView input)
Weighting = _weighting
};

return new TransformWrapper(_host, WordBagBuildingTransformer.Create(_host, options, input), true);
return WordBagBuildingTransformer.CreateTransfomer(_host, options, input);
}

/// <summary>
Expand Down Expand Up @@ -365,7 +365,7 @@ public ITransformer Fit(IDataView input)
MaximumNumberOfInverts = _maximumNumberOfInverts
};

return new TransformWrapper(_host, WordHashBagProducingTransformer.Create(_host, options, input), true);
return WordHashBagProducingTransformer.CreateTransformer(_host, options, input);
}

/// <summary>
Expand Down
6 changes: 5 additions & 1 deletion test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,11 @@ public void EntryPointPipelineEnsembleText()
new WordHashBagProducingTransformer.Options()
{
Columns =
new[] { new WordHashBagProducingTransformer.Column() { Name = "Features", Source = new[] { "Text" } }, }
new[]
{
new WordHashBagProducingTransformer.Column()
{Name = "Features", Source = new[] {"Text"}},
}
},
data);
}
Expand Down
32 changes: 32 additions & 0 deletions test/Microsoft.ML.Functional.Tests/ModelFiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,38 @@ void AssertIsGam(ITransformer trans)
Done();
}

/// <summary>
/// Input row type used by the LoadModelWithOptionalColumnTransform test:
/// one string-array column and one float-array column.
/// </summary>
public class ModelInput
{
// SA1401 (StyleCop: fields should be private) is suppressed so these can remain public fields.
#pragma warning disable SA1401
/// <summary>Categorical feature values; the test pins this column to a text vector of size 5.</summary>
public string[] CategoricalFeatures;
/// <summary>Numerical feature values; the test pins this column to a Single vector of size 3.</summary>
public float[] NumericalFeatures;
#pragma warning restore SA1401
}

/// <summary>
/// Prediction row type used by the LoadModelWithOptionalColumnTransform test.
/// </summary>
public class ModelOutput
{
// SA1401 (StyleCop: fields should be private) is suppressed so this can remain a public field.
#pragma warning disable SA1401
/// <summary>Score values produced by the model; the test asserts on element 0.</summary>
public float[] Score;
#pragma warning restore SA1401
}


/// <summary>
/// Back-compat test: loads the saved model "modelwithoptionalcolumntransform.zip" from the
/// test data folder, builds a prediction engine over <see cref="ModelInput"/> /
/// <see cref="ModelOutput"/>, and checks the first score of a single prediction.
/// </summary>
[Fact]
public void LoadModelWithOptionalColumnTransform()
{
    // Pin both input columns to fixed-size vectors (5 text values, 3 floats) instead of
    // whatever SchemaDefinition infers from the bare array fields.
    SchemaDefinition inputSchemaDefinition = SchemaDefinition.Create(typeof(ModelInput));
    inputSchemaDefinition[nameof(ModelInput.CategoricalFeatures)].ColumnType = new VectorDataViewType(TextDataViewType.Instance, 5);
    inputSchemaDefinition[nameof(ModelInput.NumericalFeatures)].ColumnType = new VectorDataViewType(NumberDataViewType.Single, 3);

    var mlContext = new MLContext();

    // Build the path from individual segments: a single literal with '\' separators only
    // resolves on Windows, because '\' is an ordinary file-name character on Linux/macOS.
    string modelPath = Path.Combine(
        Directory.GetCurrentDirectory(),
        "..", "..", "..", "..",
        "test", "data", "backcompat", "modelwithoptionalcolumntransform.zip");

    // The schema stored alongside the model is not needed here, so it is discarded.
    ITransformer trainedModel = mlContext.Model.Load(modelPath, out _);

    var model = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel, inputSchemaDefinition: inputSchemaDefinition);
    var prediction = model.Predict(new ModelInput()
    {
        CategoricalFeatures = new[] { "ABC", "ABC", "ABC", "ABC", "ABC" },
        NumericalFeatures = new float[] { 1, 1, 1 }
    });

    Assert.Equal(1, prediction.Score[0]);
}

[Fact]
public void SaveAndLoadModelWithLoader()
{
Expand Down
Loading