diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
index 20c2d8be37..8bb0371089 100644
--- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
@@ -43,7 +43,7 @@ namespace Microsoft.ML.Trainers
     /// | Is normalization required? | Yes |
     /// | Is caching required? | No |
     /// | Required NuGet in addition to Microsoft.ML | None |
-    /// | Exportable to ONNX | No |
+    /// | Exportable to ONNX | Yes |
     ///
     /// ### Training Algorithm Details
     /// [Naive Bayes](https://en.wikipedia.org/wiki/Naive_Bayes_classifier)
diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
index 19f26c0ef0..6df8d5a604 100644
--- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
@@ -442,7 +442,7 @@ private bool IsCombiningDiacritic(char ch)
     /// | Does this estimator need to look at the data to train its parameters? | No |
     /// | Input column data type | Scalar or Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType)|
     /// | Output column data type | Scalar or variable-sized Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType)|
-    /// | Exportable to ONNX | No |
+    /// | Exportable to ONNX | Yes |
     ///
     /// The resulting <xref:Microsoft.ML.Transforms.Text.TextNormalizingTransformer> creates a new column, named as specified
     /// in the output column name parameters, and normalizes the textual input data by changing case, removing diacritical marks,
diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
index bb403d5577..9aa1772de3 100644
--- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
+++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
@@ -11,6 +11,7 @@
 using Microsoft.ML.CommandLine;
 using Microsoft.ML.Data;
 using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Model.OnnxConverter;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Transforms.Text;
@@ -184,7 +185,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat
 
         private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
 
-        private sealed class Mapper : OneToOneMapperBase
+        private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
         {
             private readonly DataViewType _type;
             private readonly TokenizingByCharactersTransformer _parent;
@@ -204,6 +205,54 @@ public Mapper(TokenizingByCharactersTransformer parent, DataViewSchema inputSche
                     _isSourceVector[i] = inputSchema[_parent.ColumnPairs[i].inputColumnName].Type is VectorDataViewType;
             }
 
+            public bool CanSaveOnnx(OnnxContext ctx) => true;
+
+            public void SaveAsOnnx(OnnxContext ctx)
+            {
+                Host.CheckValue(ctx, nameof(ctx));
+                for (int iinfo = 0; iinfo < _isSourceVector.Length; ++iinfo)
+                {
+                    string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
+                    if (!ctx.ContainsColumn(inputColumnName))
+                        continue;
+
+                    string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;
+                    string srcVariableName = ctx.GetVariableName(inputColumnName);
+                    string dstVariableName = ctx.AddIntermediateVariable(_type, outputColumnName, true);
+                    SaveAsOnnxCore(ctx, srcVariableName, dstVariableName);
+                }
+            }
+
+            private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
+            {
+                string opType = "Tokenizer";
+                string tokenizerOutput = ctx.AddIntermediateVariable(null, "TokenizerOutput", true);
+                var node = ctx.CreateNode(opType, srcVariableName, tokenizerOutput, ctx.GetNodeName(opType), "com.microsoft");
+                node.AddAttribute("mark", _parent._useMarkerChars);
+                node.AddAttribute("mincharnum", 1);
+                node.AddAttribute("pad_value", "");
+                node.AddAttribute("separators", new string[] { "" });
+
+                opType = "Squeeze";
+                var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true);
+                node = ctx.CreateNode(opType, tokenizerOutput, squeezeOutput, ctx.GetNodeName(opType), "");
+                node.AddAttribute("axes", new long[] { 0 });
+
+                opType = "LabelEncoder";
+                var labelEncoderOutput = ctx.AddIntermediateVariable(null, "LabelEncoderOutput", true);
+                node = ctx.CreateNode(opType, squeezeOutput, labelEncoderOutput, ctx.GetNodeName(opType));
+
+                IEnumerable<string> charStrings = Enumerable.Range(0, 65535).Select(x => ((char)x).ToString());
+                IEnumerable<long> charValues = Enumerable.Range(0, 65535).Select(x => Convert.ToInt64(x));
+                node.AddAttribute("keys_strings", charStrings);
+                node.AddAttribute("values_int64s", charValues);
+
+                opType = "Cast";
+                var castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
+                var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt16).ToType();
+                castNode.AddAttribute("to", t);
+            }
+
             protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
             {
                 var result = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length];
@@ -558,7 +607,7 @@ private ValueGetter<VBuffer<ushort>> MakeGetterVec(DataViewRow input, int iinfo)
     /// | Does this estimator need to look at the data to train its parameters? | Yes |
     /// | Input column data type | Scalar or Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
     /// | Output column data type | Variable-sized vector of [key](xref:Microsoft.ML.Data.KeyDataViewType) type. |
-    /// | Exportable to ONNX | No |
+    /// | Exportable to ONNX | Yes |
     ///
     /// The estimator tokenizes characters by splitting text into sequences of characters using a sliding window.
     /// During training, the estimator builds a key-value pair dictionary with the encountered sequences of characters.
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index 7fd42545fb..b7da9dab3f 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -921,6 +921,37 @@ public void WordEmbeddingsTest()
             Done();
         }
 
+        [Theory]
+        [CombinatorialData]
+        public void TokenizingByCharactersOnnxConversionTest(bool useMarkerCharacters)
+        {
+            var mlContext = new MLContext(seed: 1);
+            var dataPath = GetDataPath("wikipedia-detox-250-line-test.tsv");
+            var dataView = ML.Data.LoadFromTextFile(dataPath, new[] {
+                new TextLoader.Column("label", DataKind.Boolean, 0),
+                new TextLoader.Column("text", DataKind.String, 1)
+            }, hasHeader: true);
+            var pipeline = new TokenizingByCharactersEstimator(mlContext, useMarkerCharacters: useMarkerCharacters, columns: new[] { ("TokenizedText", "text") });
+            var model = pipeline.Fit(dataView);
+            var transformedData = model.Transform(dataView);
+            var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
+            // Compare model scores produced by ML.NET and ONNX's runtime.
+            if (IsOnnxRuntimeSupported())
+            {
+                var onnxFileName = $"TokenizingByCharacters.onnx";
+                var onnxModelPath = GetOutputPath(onnxFileName);
+                SaveOnnxModel(onnxModel, onnxModelPath, null);
+                // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
+                string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+                string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+                var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
+                var onnxTransformer = onnxEstimator.Fit(dataView);
+                var onnxResult = onnxTransformer.Transform(dataView);
+                CompareSelectedVectorColumns<ushort>(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult); //compare scores
+            }
+            Done();
+        }
+
         [Theory]
         // These are the supported conversions
         // ML.NET does not allow any conversions between signed and unsigned numeric types
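
For context, a minimal sketch (not part of the patch) of how the export path enabled by this change can be exercised from user code. It assumes the Microsoft.ML.OnnxConverter package is referenced; the input class, column names, sample data, and output file name are illustrative.

```csharp
using System.IO;
using Microsoft.ML;

// Hypothetical input row type for the sketch.
public class TextInput
{
    public string Text { get; set; }
}

public static class ExportCharTokenizerToOnnx
{
    public static void Run()
    {
        var mlContext = new MLContext(seed: 1);

        // Illustrative in-memory data instead of a file-based loader.
        var samples = new[] { new TextInput { Text = "Hello ML.NET" } };
        var data = mlContext.Data.LoadFromEnumerable(samples);

        // TokenizeIntoCharactersAsKeys is backed by TokenizingByCharactersEstimator,
        // whose mapper now implements ISaveAsOnnx.
        var pipeline = mlContext.Transforms.Text.TokenizeIntoCharactersAsKeys(
            "TokenizedText", "Text", useMarkerCharacters: false);
        var model = pipeline.Fit(data);

        // Export the fitted transformer to an ONNX model file.
        using (var stream = File.Create("CharTokenizer.onnx"))
            mlContext.Model.ConvertToOnnx(model, data, stream);
    }
}
```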