diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs index a07dbeea08..e354877578 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -1009,12 +1009,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureCol var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType(); castNode.AddAttribute("to", t); - // The predictedLabel is a scalar. But the onnx output of ML.NET output expects a [1x1] tensor for output. So reshape it here - opType = "Reshape"; - long[] shape = { 1, 1 }; - long[] shapeDim = { 2 }; - var shapeVar = ctx.AddInitializer(shape, shapeDim, "ShapeVar"); - var reshapeNode = ctx.CreateNode(opType, new[] { castNodeOutput, shapeVar }, new[] { predictedLabelUint32 }, ctx.GetNodeName(opType), ""); + opType = "Unsqueeze"; + var unsqueezeNode = ctx.CreateNode(opType, castNodeOutput, predictedLabelUint32, ctx.GetNodeName(opType), ""); + unsqueezeNode.AddAttribute("axes", new long[] { 0 }); return true; } diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs index 2a944842ad..0185d79951 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs @@ -238,7 +238,7 @@ private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV var opType = "Squeeze"; var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true); var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), ""); - node.AddAttribute("axes", new long[] { 0 }); + node.AddAttribute("axes", new long[] { 1 }); opType = "StringNormalizer"; var normalizerOutput = ctx.AddIntermediateVariable(null, "NormalizerOutput", true); @@ -249,7 +249,7 @@ private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV opType = "Unsqueeze"; node = ctx.CreateNode(opType, normalizerOutput, dstVariableName, ctx.GetNodeName(opType), ""); - node.AddAttribute("axes", new long[] { 0 }); + node.AddAttribute("axes", new long[] { 1 }); } protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() { diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs index 9aa1772de3..027cfb7f0b 100644 --- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs +++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs @@ -190,6 +190,7 @@ private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx private readonly DataViewType _type; private readonly TokenizingByCharactersTransformer _parent; private readonly bool[] _isSourceVector; + private readonly int[] _sourceVectorLength; // Constructed and cached the first time it is needed. private volatile string _keyValuesStr; private volatile int[] _keyValuesBoundaries; @@ -201,8 +202,13 @@ public Mapper(TokenizingByCharactersTransformer parent, DataViewSchema inputSche var keyType = new KeyDataViewType(typeof(ushort), CharsCount); _type = new VectorDataViewType(keyType); _isSourceVector = new bool[_parent.ColumnPairs.Length]; + _sourceVectorLength = new int[_parent.ColumnPairs.Length]; for (int i = 0; i < _isSourceVector.Length; i++) - _isSourceVector[i] = inputSchema[_parent.ColumnPairs[i].inputColumnName].Type is VectorDataViewType; + { + var type = inputSchema[_parent.ColumnPairs[i].inputColumnName].Type; + _isSourceVector[i] = type is VectorDataViewType; + _sourceVectorLength[i] = type.GetValueCount(); + } } public bool CanSaveOnnx(OnnxContext ctx) => true; @@ -219,14 +225,20 @@ public void SaveAsOnnx(OnnxContext ctx) string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName; string srcVariableName = ctx.GetVariableName(inputColumnName); string dstVariableName = ctx.AddIntermediateVariable(_type, outputColumnName, true); - SaveAsOnnxCore(ctx, srcVariableName, dstVariableName); + SaveAsOnnxCore(ctx, iinfo, srcVariableName, dstVariableName); } } - private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName) + private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { string opType = "Tokenizer"; - string tokenizerOutput = ctx.AddIntermediateVariable(null, "TokenizerOutput", true); + DataViewType dataViewType; + if (_isSourceVector[iinfo]) + dataViewType = new VectorDataViewType(TextDataViewType.Instance, _sourceVectorLength[iinfo]); + else + dataViewType = TextDataViewType.Instance; + + string tokenizerOutput = ctx.AddIntermediateVariable(dataViewType, "TokenizerOutput", true); var node = ctx.CreateNode(opType, srcVariableName, tokenizerOutput, ctx.GetNodeName(opType), "com.microsoft"); node.AddAttribute("mark", _parent._useMarkerChars); node.AddAttribute("mincharnum", 1); @@ -234,12 +246,12 @@ private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV node.AddAttribute("separators", new string[] { "" }); opType = "Squeeze"; - var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true); + var squeezeOutput = ctx.AddIntermediateVariable(dataViewType, "SqueezeOutput"); node = ctx.CreateNode(opType, tokenizerOutput, squeezeOutput, ctx.GetNodeName(opType), ""); - node.AddAttribute("axes", new long[] { 0 }); + node.AddAttribute("axes", new long[] { 1 }); opType = "LabelEncoder"; - var labelEncoderOutput = ctx.AddIntermediateVariable(null, "LabelEncoderOutput", true); + var labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "LabelEncoderOutput"); node = ctx.CreateNode(opType, squeezeOutput, labelEncoderOutput, ctx.GetNodeName(opType)); IEnumerable charStrings = Enumerable.Range(0, 65535).Select(x => ((char)x).ToString()); diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs index 59a4a2abcb..1eac17ccaa 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs @@ -416,9 +416,9 @@ public void SaveAsOnnx(OnnxContext ctx) tokenizerNode.AddAttribute("separators", separators); opType = "Squeeze"; - var squeezeOutput = ctx.AddIntermediateVariable(_type, column.Name, true); + var squeezeOutput = ctx.AddIntermediateVariable(_type, column.Name); var squeezeNode = ctx.CreateNode(opType, intermediateVar, squeezeOutput, ctx.GetNodeName(opType), ""); - squeezeNode.AddAttribute("axes", new long[] { 0 }); + squeezeNode.AddAttribute("axes", new long[] { 1 }); } } } diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt index 672cdb5bb8..e60d2b5b67 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt @@ -274,14 +274,22 @@ }, { "input": [ - "CastNodeOutput", - "ShapeVar" + "CastNodeOutput" ], "output": [ "PredictedLabel" ], - "name": "Reshape", - "opType": "Reshape" + "name": "Unsqueeze", + "opType": "Unsqueeze", + "attribute": [ + { + "name": "axes", + "ints": [ + "0" + ], + "type": "INTS" + } + ] }, { "input": [ @@ -371,17 +379,6 @@ ], "name": "model", "initializer": [ - { - "dims": [ - "2" - ], - "dataType": 7, - "int64Data": [ - "1", - "1" - ], - "name": "ShapeVar" - }, { "dims": [ "1",