Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions build/BranchInfo.props
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
<MajorVersion>1</MajorVersion>
<MinorVersion>5</MinorVersion>
<PatchVersion>0</PatchVersion>
<PreReleaseLabel>preview2</PreReleaseLabel>
<PreReleaseLabel>preview5</PreReleaseLabel>
Comment thread
kere-nel marked this conversation as resolved.
Outdated
</PropertyGroup>
<PropertyGroup Condition="'$(IsStableProject)' != 'true'">
<MajorVersion>0</MajorVersion>
<MinorVersion>17</MinorVersion>
<PatchVersion>0</PatchVersion>
<PreReleaseLabel>preview2</PreReleaseLabel>
<PreReleaseLabel>preview5</PreReleaseLabel>
</PropertyGroup>
</Project>
2 changes: 1 addition & 1 deletion build/vsts-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ phases:
container: CentosContainer
steps:
# Only build native assets to avoid conflicts.
- script: ./build.sh -buildNative -$(BuildConfig) -skipRIDAgnosticAssets
- script: sudo locale-gen en_US.UTF-8 && sudo update-locale && ./build.sh -buildNative -$(BuildConfig) -skipRIDAgnosticAssets
displayName: Build

- task: PublishBuildArtifacts@1
Expand Down
41 changes: 40 additions & 1 deletion src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;

Expand Down Expand Up @@ -194,7 +195,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat

/// <summary>
/// Builds a <see cref="Mapper"/> bound to this transformer and the supplied input schema.
/// </summary>
/// <param name="schema">The input schema handed to the new mapper.</param>
/// <returns>The newly created mapper.</returns>
private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
{
    return new Mapper(this, schema);
}

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly DataViewType[] _types;
private readonly TextNormalizingTransformer _parent;
Expand All @@ -212,6 +213,44 @@ public Mapper(TextNormalizingTransformer parent, DataViewSchema inputSchema)
}
}

/// <summary>
/// Reports whether this mapper can be exported to ONNX. Export is only allowed when the
/// parent transformer keeps diacritics, numbers and punctuations (i.e. removes nothing).
/// </summary>
/// <param name="ctx">The ONNX export context (not consulted by this check).</param>
/// <returns><c>true</c> when all three "keep" flags of the parent are set.</returns>
public bool CanSaveOnnx(OnnxContext ctx)
{
    bool dropsSomething = !_parent._keepDiacritics
        || !_parent._keepNumbers
        || !_parent._keepPunctuations;
    return !dropsSomething;
}

/// <summary>
/// Exports every column pair of the parent transformer whose input column is known to
/// <paramref name="ctx"/>. Pairs whose input column is absent from the context are skipped.
/// </summary>
/// <param name="ctx">The ONNX export context; validated with <c>Host.CheckValue</c>.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    Host.CheckValue(ctx, nameof(ctx));

    int columnCount = _types.Length;
    for (int col = 0; col < columnCount; col++)
    {
        var pair = _parent.ColumnPairs[col];
        if (ctx.ContainsColumn(pair.inputColumnName))
        {
            // Resolve the source variable first, then register the output variable,
            // and finally emit the nodes that connect the two.
            string sourceVariable = ctx.GetVariableName(pair.inputColumnName);
            string targetVariable = ctx.AddIntermediateVariable(_types[col], pair.outputColumnName, true);
            SaveAsOnnxCore(ctx, sourceVariable, targetVariable);
        }
    }
}

/// <summary>
/// Appends the node chain Squeeze -> StringNormalizer -> Unsqueeze to the ONNX graph,
/// reading from <paramref name="srcVariableName"/> and writing to <paramref name="dstVariableName"/>.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the nodes and intermediate variables.</param>
/// <param name="srcVariableName">Name of the ONNX variable holding the input text.</param>
/// <param name="dstVariableName">Name of the ONNX variable that receives the normalized text.</param>
private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
    // StringNormalizer only takes input of shapes [C] or [1,C],
    // so the input is squeezed to support inferred shapes ( e.g. [-1,C] ).
    const string squeezeOp = "Squeeze";
    var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true);
    var squeezeNode = ctx.CreateNode(squeezeOp, srcVariableName, squeezeOutput, ctx.GetNodeName(squeezeOp), "");
    squeezeNode.AddAttribute("axes", new long[] { 0 });

    const string normalizerOp = "StringNormalizer";
    var normalizerOutput = ctx.AddIntermediateVariable(null, "NormalizerOutput", true);
    var normalizerNode = ctx.CreateNode(normalizerOp, squeezeOutput, normalizerOutput, ctx.GetNodeName(normalizerOp), "");

    // Map the parent's case mode onto the StringNormalizer attribute value.
    // Any mode other than Lower/Upper is exported as "NONE".
    string caseChangeAction;
    switch (_parent._caseMode)
    {
        case TextNormalizingEstimator.CaseMode.Lower:
            caseChangeAction = "LOWER";
            break;
        case TextNormalizingEstimator.CaseMode.Upper:
            caseChangeAction = "UPPER";
            break;
        default:
            caseChangeAction = "NONE";
            break;
    }
    normalizerNode.AddAttribute("case_change_action", caseChangeAction);

    // Restore the axis removed by the Squeeze above and write to the destination variable.
    const string unsqueezeOp = "Unsqueeze";
    var unsqueezeNode = ctx.CreateNode(unsqueezeOp, normalizerOutput, dstVariableName, ctx.GetNodeName(unsqueezeOp), "");
    unsqueezeNode.AddAttribute("axes", new long[] { 0 });
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's 'None', there is no need for the 'StringNormalizer' node at all

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, in the case that we are able to add the other options in the future, it would make sense to keep the option to pass "None". What do you think?

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var result = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length];
Expand Down
36 changes: 36 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,42 @@ public void PlattCalibratorOnnxConversionTest2()
Done();
}

/// <summary>
/// Converts a pipeline of three text normalizers to ONNX and, where the ONNX runtime is
/// available (and not on Linux), checks that the ONNX outputs match the ML.NET outputs.
/// </summary>
[Fact]
public void TextNormalizingOnnxConversionTest()
{
    var mlContext = new MLContext(seed: 1);
    var dataPath = GetDataPath("wikipedia-detox-250-line-test.tsv");
    // Use the same context for loading as for the rest of the test.
    var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] {
        new TextLoader.Column("label", DataKind.Boolean, 0),
        new TextLoader.Column("text", DataKind.String, 1)
    }, hasHeader: true);

    // Three normalizers over the same "text" column: one without an explicit case mode,
    // one upper-casing, and one with case changes disabled.
    var pipeline = new TextNormalizingEstimator(mlContext, keepDiacritics: true, columns: new[] { ("NormText", "text") })
        .Append(new TextNormalizingEstimator(mlContext, keepDiacritics: true, caseMode: TextNormalizingEstimator.CaseMode.Upper, columns: new[] { ("UpperText", "text") }))
        .Append(new TextNormalizingEstimator(mlContext, keepDiacritics: true, caseMode: TextNormalizingEstimator.CaseMode.None, columns: new[] { ("OriginalText", "text") }));
    var model = pipeline.Fit(dataView);
    var transformedData = model.Transform(dataView);
    var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

    // Compare model scores produced by ML.NET and ONNX's runtime.
    // Skipping test in Linux platforms temporarily
    if (IsOnnxRuntimeSupported() && !RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
    {
        var onnxFileName = "TextNormalizing.onnx";
        var onnxModelPath = GetOutputPath(onnxFileName);
        SaveOnnxModel(onnxModel, onnxModelPath, null);
        // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
        string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
        string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
        var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
        var onnxTransformer = onnxEstimator.Fit(dataView);
        var onnxResult = onnxTransformer.Transform(dataView);
        CompareSelectedColumns<ReadOnlyMemory<char>>(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult); //compare NormText
        CompareSelectedColumns<ReadOnlyMemory<char>>(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult); //compare UpperText
        CompareSelectedColumns<ReadOnlyMemory<char>>(transformedData.Schema[4].Name, outputNames[4], transformedData, onnxResult); //compare OriginalText
    }
    Done();
}

private class DataPoint
{
[VectorType(3)]
Expand Down