Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 12 additions & 104 deletions src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Data.DataLoadSave;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Runtime;

[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel),
"Transform wrapper", TransformWrapper.LoaderSignature)]

namespace Microsoft.ML.Data
{
/// <summary>
Expand All @@ -23,128 +16,43 @@ namespace Microsoft.ML.Data
internal sealed class TransformWrapper : ITransformer
{
internal const string LoaderSignature = "TransformWrapper";
private const string TransformDirTemplate = "Step_{0:000}";

private readonly IHost _host;
private readonly IDataView _xf;
private readonly bool _allowSave;
private readonly bool _isRowToRowMapper;

public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false)
public TransformWrapper(IHostEnvironment env, IDataView xf)
{
Contracts.CheckValue(env, nameof(env));
Contracts.Check(xf is IDataTransform);

_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(xf, nameof(xf));
_xf = xf;
_allowSave = allowSave;
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
}

public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));

var dv = new EmptyDataView(_host, inputSchema);
var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
return output.Schema;
}
var output = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv);

void ICanSaveModel.Save(ModelSaveContext ctx)
{
if (!_allowSave)
throw _host.Except("Saving is not permitted.");
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

var dataPipe = _xf;
var transforms = new List<IDataTransform>();
while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
}
transforms.Reverse();

ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema));

ctx.Writer.Write(transforms.Count);
for (int i = 0; i < transforms.Count; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.SaveModel(transforms[i], dirName);
}
}

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "XF WRPR",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(TransformWrapper).Assembly.FullName);
}

// Factory for SignatureLoadModel.
private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(ctx, nameof(ctx));
_allowSave = true;
ctx.CheckAtModel(GetVersionInfo());
int n = ctx.Reader.ReadInt32();
_host.CheckDecode(n >= 0);

ctx.LoadModel<ILegacyDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

IDataView data = loader;
for (int i = 0; i < n; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
data = xf;
}

_xf = data;
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
return output.Schema;
}

public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
void ICanSaveModel.Save(ModelSaveContext ctx) => throw _host.Except("Saving is not permitted.");

private static bool IsChainRowToRowMapper(IDataView view)
{
for (; view is IDataTransform xf; view = xf.Source)
{
if (!(xf is IRowToRowMapper))
return false;
}
return true;
}
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input);

bool ITransformer.IsRowToRowMapper => _isRowToRowMapper;
bool ITransformer.IsRowToRowMapper => _xf is IRowToRowMapper;

IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
var input = new EmptyDataView(_host, inputSchema);
var revMaps = new List<IRowToRowMapper>();
IDataView chain;
for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); chain is IDataTransform xf; chain = xf.Source)
{
// Everything in the chain ought to be a row mapper.
_host.Assert(xf is IRowToRowMapper);
revMaps.Add((IRowToRowMapper)xf);
}
// The walkback should have ended at the input.
Contracts.Assert(chain == input);
revMaps.Reverse();
return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray());
var transform = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, new EmptyDataView(_host, inputSchema)) as IRowToRowMapper;
_host.Check(transform is IRowToRowMapper);

return new CompositeRowToRowMapper(inputSchema, new[] { transform });
}
}

Expand Down
23 changes: 14 additions & 9 deletions src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ internal sealed class Options : NgramExtractorTransform.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. It does so by building "
+ "a dictionary of n-grams and using the id in the dictionary as the index in the bag.";

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
internal static ITransformer CreateTransfomer(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
Expand Down Expand Up @@ -150,10 +150,16 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
}

IDataView view = input;
view = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns, view);
view = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view).Transform(view);
return NgramExtractorTransform.CreateDataTransform(h, extractorArgs, view);
ITransformer t0 = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns);
view = t0.Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view);
view = t1.Transform(view);
ITransformer t2 = NgramExtractorTransform.Create(h, extractorArgs, view);
return new TransformerChain<ITransformer>(new[] { t0, t1, t2 });
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransfomer(env, options, input).Transform(input);
}

/// <summary>
Expand Down Expand Up @@ -489,13 +495,11 @@ public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColum

internal static class NgramExtractionUtils
{
public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns, IDataView input)
public static ITransformer ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(columns, nameof(columns));
env.CheckValue(input, nameof(input));

IDataView view = input;
var concatColumns = new List<ColumnConcatenatingTransformer.ColumnOptions>();
foreach (var col in columns)
{
Expand All @@ -506,10 +510,11 @@ public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColu
if (col.Source.Length > 1)
concatColumns.Add(new ColumnConcatenatingTransformer.ColumnOptions(col.Name, col.Source));
}

if (concatColumns.Count > 0)
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray()).Transform(view);
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray());

return view;
return new TransformerChain<ITransformer>();
}

/// <summary>
Expand Down
14 changes: 10 additions & 4 deletions src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ internal sealed class Options : NgramHashExtractingTransformer.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. "
+ "It does so by hashing each n-gram and using the hash value as the index in the bag.";

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
internal static ITransformer CreateTransformer(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
Expand Down Expand Up @@ -132,7 +132,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
};
}

view = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view).Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view);

var featurizeArgs =
new NgramHashExtractingTransformer.Options
Expand All @@ -147,11 +147,17 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
MaximumNumberOfInverts = options.MaximumNumberOfInverts
};

view = NgramHashExtractingTransformer.Create(h, featurizeArgs, view).Transform(view);
view = t1.Transform(view);
ITransformer t2 = NgramHashExtractingTransformer.Create(h, featurizeArgs, view);

// Since we added columns with new names, we need to explicitly drop them before we return the IDataTransform.
return ColumnSelectingTransformer.CreateDrop(h, view, tmpColNames.ToArray()) as IDataTransform;
ITransformer t3 = new ColumnSelectingTransformer(env, null, tmpColNames.ToArray());

return new TransformerChain<ITransformer>(new[] { t1, t2, t3 });
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransformer(env, options, input).Transform(input);
}

/// <summary>
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ public ITransformer Fit(IDataView input)
Weighting = _weighting
};

return new TransformWrapper(_host, WordBagBuildingTransformer.Create(_host, options, input), true);
return WordBagBuildingTransformer.CreateTransfomer(_host, options, input);
}

/// <summary>
Expand Down Expand Up @@ -365,7 +365,7 @@ public ITransformer Fit(IDataView input)
MaximumNumberOfInverts = _maximumNumberOfInverts
};

return new TransformWrapper(_host, WordHashBagProducingTransformer.Create(_host, options, input), true);
return WordHashBagProducingTransformer.CreateTransformer(_host, options, input);
}

/// <summary>
Expand Down
6 changes: 5 additions & 1 deletion test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,11 @@ public void EntryPointPipelineEnsembleText()
new WordHashBagProducingTransformer.Options()
{
Columns =
new[] { new WordHashBagProducingTransformer.Column() { Name = "Features", Source = new[] { "Text" } }, }
new[]
{
new WordHashBagProducingTransformer.Column()
{Name = "Features", Source = new[] {"Text"}},
}
},
data);
}
Expand Down
32 changes: 32 additions & 0 deletions test/Microsoft.ML.Functional.Tests/ModelFiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,38 @@ void AssertIsGam(ITransformer trans)
Done();
}

public class ModelInput
{
#pragma warning disable SA1401
public string[] CategoricalFeatures;
public float[] NumericalFeatures;
#pragma warning restore SA1401
}

public class ModelOutput
{
#pragma warning disable SA1401
public float[] Score;
#pragma warning restore SA1401
}


[Fact]
public void LoadModelWithOptionalColumnTransform()
{
SchemaDefinition inputSchemaDefinition = SchemaDefinition.Create(typeof(ModelInput));
inputSchemaDefinition[nameof(ModelInput.CategoricalFeatures)].ColumnType = new VectorDataViewType(TextDataViewType.Instance, 5);
inputSchemaDefinition[nameof(ModelInput.NumericalFeatures)].ColumnType = new VectorDataViewType(NumberDataViewType.Single, 3);
var mlContext = new MLContext();
ITransformer trainedModel;
DataViewSchema dataViewSchema;
trainedModel = mlContext.Model.Load(GetDataPath("backcompat", "modelwithoptionalcolumntransform.zip"), out dataViewSchema);
var model = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel, inputSchemaDefinition: inputSchemaDefinition);
var prediction = model.Predict(new ModelInput() { CategoricalFeatures = new[] { "ABC", "ABC", "ABC", "ABC", "ABC" }, NumericalFeatures = new float [] { 1, 1, 1 } });

Assert.Equal(1, prediction.Score[0]);
}

[Fact]
public void SaveAndLoadModelWithLoader()
{
Expand Down
Loading