diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
index 6df8d5a604..2a944842ad 100644
--- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
@@ -11,6 +11,7 @@
 using Microsoft.ML.CommandLine;
 using Microsoft.ML.Data;
 using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Model.OnnxConverter;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Transforms.Text;
 
@@ -194,7 +195,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat
 
         private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
 
-        private sealed class Mapper : OneToOneMapperBase
+        private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
         {
             private readonly DataViewType[] _types;
             private readonly TextNormalizingTransformer _parent;
@@ -212,6 +213,61 @@ public Mapper(TextNormalizingTransformer parent, DataViewSchema inputSchema)
                 }
             }
 
+            /// <summary>
+            /// Reports whether this mapper can be exported to ONNX. Export is allowed only when the
+            /// transformer keeps diacritics, numbers and punctuation, because the graph written by
+            /// <see cref="SaveAsOnnxCore"/> configures nothing but a case-change action.
+            /// </summary>
+            public bool CanSaveOnnx(OnnxContext ctx) =>
+                _parent._keepDiacritics && _parent._keepNumbers && _parent._keepPunctuations;
+
+            /// <summary>
+            /// Writes the ONNX nodes for every column pair whose input column is known to <paramref name="ctx"/>.
+            /// </summary>
+            /// <param name="ctx">The ONNX export context; must not be null.</param>
+            public void SaveAsOnnx(OnnxContext ctx)
+            {
+                Host.CheckValue(ctx, nameof(ctx));
+                for (int iinfo = 0; iinfo < _types.Length; ++iinfo)
+                {
+                    string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
+                    // Columns the context does not contain are skipped rather than exported.
+                    if (!ctx.ContainsColumn(inputColumnName))
+                        continue;
+
+                    string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;
+                    string srcVariableName = ctx.GetVariableName(inputColumnName);
+                    string dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], outputColumnName, true);
+                    SaveAsOnnxCore(ctx, srcVariableName, dstVariableName);
+                }
+            }
+
+            /// <summary>
+            /// Emits the Squeeze -> StringNormalizer -> Unsqueeze node chain that maps
+            /// <paramref name="srcVariableName"/> to <paramref name="dstVariableName"/>.
+            /// </summary>
+            private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
+            {
+                // StringNormalizer only takes input of shapes [C] or [1,C],
+                // so the input is squeezed to support inferred shapes ( e.g. [-1,C] ).
+                var opType = "Squeeze";
+                var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true);
+                var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), "");
+                node.AddAttribute("axes", new long[] { 0 });
+
+                opType = "StringNormalizer";
+                var normalizerOutput = ctx.AddIntermediateVariable(null, "NormalizerOutput", true);
+                node = ctx.CreateNode(opType, squeezeOutput, normalizerOutput, ctx.GetNodeName(opType), "");
+                var caseChangeAction = (_parent._caseMode == TextNormalizingEstimator.CaseMode.Lower) ? "LOWER" :
+                    (_parent._caseMode == TextNormalizingEstimator.CaseMode.Upper) ? "UPPER" : "NONE";
+                node.AddAttribute("case_change_action", caseChangeAction);
+
+                // Put back the leading axis that the Squeeze node above removed.
+                opType = "Unsqueeze";
+                node = ctx.CreateNode(opType, normalizerOutput, dstVariableName, ctx.GetNodeName(opType), "");
+                node.AddAttribute("axes", new long[] { 0 });
+            }
+
             protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
             {
                 var result = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length];
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index 40b8edac2b..bf0f03bf20 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -448,6 +448,46 @@ public void PlattCalibratorOnnxConversionTest2()
             Done();
         }
 
+        /// <summary>
+        /// Checks that three TextNormalizingEstimator stages (default case mode, upper case, unchanged case)
+        /// yield the same text through ML.NET and through the exported ONNX model.
+        /// </summary>
+        [Fact]
+        public void TextNormalizingOnnxConversionTest()
+        {
+            var mlContext = new MLContext(seed: 1);
+            var dataPath = GetDataPath("wikipedia-detox-250-line-test.tsv");
+            var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] {
+                new TextLoader.Column("label", DataKind.Boolean, 0),
+                new TextLoader.Column("text", DataKind.String, 1)
+            }, hasHeader: true);
+            var pipeline = new TextNormalizingEstimator(mlContext, keepDiacritics: true, columns: new[] { ("NormText", "text") }).Append(
+                new TextNormalizingEstimator(mlContext, keepDiacritics: true, caseMode: TextNormalizingEstimator.CaseMode.Upper, columns: new[] { ("UpperText", "text") })).Append(
+                new TextNormalizingEstimator(mlContext, keepDiacritics: true, caseMode: TextNormalizingEstimator.CaseMode.None, columns: new[] { ("OriginalText", "text") }));
+            var model = pipeline.Fit(dataView);
+            var transformedData = model.Transform(dataView);
+            var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
+
+            // Compare the normalized text produced by ML.NET and ONNX's runtime.
+            // Skipping test in Linux platforms temporarily
+            if (IsOnnxRuntimeSupported() && !RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
+            {
+                var onnxFileName = "TextNormalizing.onnx";
+                var onnxModelPath = GetOutputPath(onnxFileName);
+                SaveOnnxModel(onnxModel, onnxModelPath, null);
+                // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
+                string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+                string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+                var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
+                var onnxTransformer = onnxEstimator.Fit(dataView);
+                var onnxResult = onnxTransformer.Transform(dataView);
+                CompareSelectedColumns<ReadOnlyMemory<char>>(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult); //compare NormText
+                CompareSelectedColumns<ReadOnlyMemory<char>>(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult); //compare UpperText
+                CompareSelectedColumns<ReadOnlyMemory<char>>(transformedData.Schema[4].Name, outputNames[4], transformedData, onnxResult); //compare OriginalText
+            }
+            Done();
+        }
+
         private class DataPoint
         {
             [VectorType(3)]