diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs index fa6f2ff048..8736484e9d 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs @@ -29,5 +29,6 @@ internal abstract class OnnxNode public abstract void AddAttribute(string argName, string[] value); public abstract void AddAttribute(string argName, IEnumerable value); public abstract void AddAttribute(string argName, IEnumerable value); + public abstract void AddAttribute(string argName, Type t); } } diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index d86f1219db..ba4107496a 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -381,13 +381,13 @@ internal static bool GetNewType(IExceptionContext ectx, DataViewType srcType, In return true; } - private sealed class Mapper : OneToOneMapperBase, ICanSaveOnnx + private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { private readonly TypeConvertingTransformer _parent; private readonly DataViewType[] _types; private readonly int[] _srcCols; - public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental; + public bool CanSaveOnnx(OnnxContext ctx) => true; public Mapper(TypeConvertingTransformer parent, DataViewSchema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) @@ -497,22 +497,17 @@ public void SaveAsOnnx(OnnxContext ctx) private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { - var opType = "CSharp"; - var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - node.AddAttribute("type", LoaderSignature); - node.AddAttribute("to", (byte)_parent._columns[iinfo].OutputKind); - if (_parent._columns[iinfo].OutputKeyCount != null) - { - var key = (KeyDataViewType)_types[iinfo].GetItemType(); - node.AddAttribute("max", key.Count); - } + var opType = "Cast"; + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), ""); + var t = _parent._columns[iinfo].OutputKind.ToInternalDataKind().ToType(); + node.AddAttribute("to", t); return true; } } } /// - /// Estimator for . Converts the underlying input column type to a new type. + /// Estimator for . Converts the underlying input column type to a new type. /// The input and output column types need to be compatible. /// /// diff --git a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs index 12cf0fd333..db8876ab4a 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs @@ -351,11 +351,17 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string var addNodeY = ctx.CreateNode("Add", new[] { nameZ, nameC2 }, new[] { nameY }, ctx.GetNodeName("Add"), ""); // Compute the most-matched cluster index, L - var nameL = outputNames[0]; + var nameL = "ArgMinInt64"; var predictNodeL = ctx.CreateNode("ArgMin", nameY, nameL, ctx.GetNodeName("ArgMin"), ""); predictNodeL.AddAttribute("axis", 1); predictNodeL.AddAttribute("keepdims", 1); + // ArgMin outputs an Int64. But ML.NET's KMeans trainer outputs a UINT32. + // Cast the output here to UInt32 to make them compatible + var predictedNode = ctx.CreateNode("Cast", nameL, outputNames[0], ctx.GetNodeName("Cast"), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType(); + predictedNode.AddAttribute("to", t); + return true; } } diff --git a/src/Microsoft.ML.OnnxConverter/OnnxNodeImpl.cs b/src/Microsoft.ML.OnnxConverter/OnnxNodeImpl.cs index b3c67566ea..0dac8f5029 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxNodeImpl.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxNodeImpl.cs @@ -43,5 +43,7 @@ public override void AddAttribute(string argName, string value) => OnnxUtils.NodeAddAttributes(_node, argName, value); public override void AddAttribute(string argName, bool value) => OnnxUtils.NodeAddAttributes(_node, argName, value); + public override void AddAttribute(string argName, Type value) + => OnnxUtils.NodeAddAttributes(_node, argName, value); } } diff --git a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs index 82ad63508a..c2c4b1eaea 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs @@ -70,6 +70,14 @@ private static AttributeProto MakeAttribute(string key) return attribute; } + private static AttributeProto MakeAttribute(string key, TensorProto.Types.DataType value) + { + AttributeProto attribute = MakeAttribute(key); + attribute.Type = AttributeProto.Types.AttributeType.Int; + attribute.I = (int)value; + return attribute; + } + private static AttributeProto MakeAttribute(string key, double value) { AttributeProto attribute = MakeAttribute(key); @@ -211,6 +219,45 @@ public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable public static void NodeAddAttributes(NodeProto node, string argName, bool value) => node.Attribute.Add(MakeAttribute(argName, value)); + public static void NodeAddAttributes(NodeProto node, string argName, Type value) + => node.Attribute.Add(MakeAttribute(argName, ConvertToTensorProtoType(value))); + + private static TensorProto.Types.DataType ConvertToTensorProtoType(Type rawType) + { + var dataType = TensorProto.Types.DataType.Undefined; + + if (rawType == typeof(bool)) + dataType = TensorProto.Types.DataType.Float; + else if (rawType == typeof(ReadOnlyMemory)) + dataType = TensorProto.Types.DataType.String; + else if (rawType == typeof(sbyte)) + dataType = TensorProto.Types.DataType.Int8; + else if (rawType == typeof(byte)) + dataType = TensorProto.Types.DataType.Uint8; + else if (rawType == typeof(short)) + dataType = TensorProto.Types.DataType.Int16; + else if (rawType == typeof(ushort)) + dataType = TensorProto.Types.DataType.Uint16; + else if (rawType == typeof(int)) + dataType = TensorProto.Types.DataType.Int32; + else if (rawType == typeof(uint)) + dataType = TensorProto.Types.DataType.Uint32; + else if (rawType == typeof(long)) + dataType = TensorProto.Types.DataType.Int64; + else if (rawType == typeof(ulong)) + dataType = TensorProto.Types.DataType.Uint64; + else if (rawType == typeof(float)) + dataType = TensorProto.Types.DataType.Float; + else if (rawType == typeof(double)) + dataType = TensorProto.Types.DataType.Double; + else + { + string msg = "Unsupported type: " + rawType.ToString(); + Contracts.Check(false, msg); + } + + return dataType; + } private static ByteString StringToByteString(ReadOnlyMemory str) => ByteString.CopyFrom(Encoding.UTF8.GetBytes(str.ToString())); private static IEnumerable StringToByteString(IEnumerable> str) @@ -295,42 +342,12 @@ public static ModelArgs GetModelArgs(DataViewType type, string colName, Contracts.CheckValue(type, nameof(type)); Contracts.CheckNonEmpty(colName, nameof(colName)); - TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined; Type rawType; if (type is VectorDataViewType vectorType) rawType = vectorType.ItemType.RawType; else rawType = type.RawType; - - if (rawType == typeof(bool)) - dataType = TensorProto.Types.DataType.Float; - else if (rawType == typeof(ReadOnlyMemory)) - dataType = TensorProto.Types.DataType.String; - else if (rawType == typeof(sbyte)) - dataType = TensorProto.Types.DataType.Int8; - else if (rawType == typeof(byte)) - dataType = TensorProto.Types.DataType.Uint8; - else if (rawType == typeof(short)) - dataType = TensorProto.Types.DataType.Int16; - else if (rawType == typeof(ushort)) - dataType = TensorProto.Types.DataType.Uint16; - else if (rawType == typeof(int)) - dataType = TensorProto.Types.DataType.Int32; - else if (rawType == typeof(uint)) - dataType = TensorProto.Types.DataType.Int64; - else if (rawType == typeof(long)) - dataType = TensorProto.Types.DataType.Int64; - else if (rawType == typeof(ulong)) - dataType = TensorProto.Types.DataType.Uint64; - else if (rawType == typeof(float)) - dataType = TensorProto.Types.DataType.Float; - else if (rawType == typeof(double)) - dataType = TensorProto.Types.DataType.Double; - else - { - string msg = "Unsupported type: " + type.ToString(); - Contracts.Check(false, msg); - } + var dataType = ConvertToTensorProtoType(rawType); string name = colName; List dimsLocal = null; diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt index 0e2a01ba0d..f0795a1f13 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt @@ -515,7 +515,7 @@ "name": "F20", "type": { "tensorType": { - "elemType": "INT64", + "elemType": "UINT32", "shape": { "dim": [ { diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt index b52cce8e97..22aee806af 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt @@ -842,7 +842,7 @@ "name": "F20", "type": { "tensorType": { - "elemType": "INT64", + "elemType": "UINT32", "shape": { "dim": [ { diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt index d66e9291fe..68335b20ad 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt @@ -598,7 +598,7 @@ "name": "F20", "type": { "tensorType": { - "elemType": "INT64", + "elemType": "UINT32", "shape": { "dim": [ { diff --git a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt index 7f2a6d82e4..44f74a7022 100644 --- a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt +++ b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt @@ -109,7 +109,7 @@ "Score" ], "output": [ - "PredictedLabel" + "ArgMinInt64" ], "name": "ArgMin", "opType": "ArgMin", @@ -126,6 +126,23 @@ } ] }, + { + "input": [ + "ArgMinInt64" + ], + "output": [ + "PredictedLabel" + ], + "name": "Cast", + "opType": "Cast", + "attribute": [ + { + "name": "to", + "i": "12", + "type": "INT" + } + ] + }, { "input": [ "Features0" @@ -272,7 +289,7 @@ "name": "PredictedLabel0", "type": { "tensorType": { - "elemType": "INT64", + "elemType": "UINT32", "shape": { "dim": [ { diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt index 107f833949..a83f522dc5 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt @@ -312,7 +312,7 @@ "name": "Label1", "type": { "tensorType": { - "elemType": "INT64", + "elemType": "UINT32", "shape": { "dim": [ { @@ -348,7 +348,7 @@ "name": "PredictedLabel0", "type": { "tensorType": { - "elemType": "INT64", + "elemType": "UINT32", "shape": { "dim": [ { @@ -404,7 +404,7 @@ "name": "Label0", "type": { "tensorType": { - "elemType": "INT64", + "elemType": "UINT32", "shape": { "dim": [ { diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index af87577042..02ecb52add 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -601,6 +601,83 @@ public void WordEmbeddingsTest() Done(); } + [Fact] + public void OnnxTypeConversionTest() + { + var mlContext = new MLContext(seed: 1); + string filePath = GetDataPath("type-conversion.txt"); + + // These are the supported conversions + // ML.NET does not allow any conversions between signed and unsigned numeric types + // Onnx does not seem to support casting a string to any type + // Though the onnx docs claim support for byte and sbyte, + // CreateNamedOnnxValue in OnnxUtils.cs throws a NotImplementedException for those two + DataKind[,] supportedConversions = new[,] + { + { DataKind.Int16, DataKind.Int16}, + { DataKind.Int16, DataKind.Int32}, + { DataKind.Int16, DataKind.Int64}, + { DataKind.Int16, DataKind.Single}, + { DataKind.Int16, DataKind.Double}, + { DataKind.UInt16, DataKind.UInt16}, + { DataKind.UInt16, DataKind.UInt32}, + { DataKind.UInt16, DataKind.UInt64}, + { DataKind.UInt16, DataKind.Single}, + { DataKind.UInt16, DataKind.Double}, + { DataKind.Int32, DataKind.Int16}, + { DataKind.Int32, DataKind.Int32}, + { DataKind.Int32, DataKind.Int64}, + { DataKind.Int32, DataKind.Single}, + { DataKind.Int32, DataKind.Double}, + { DataKind.Int64, DataKind.Int16}, + { DataKind.Int64, DataKind.Int32}, + { DataKind.Int64, DataKind.Int64}, + { DataKind.Int64, DataKind.Single}, + { DataKind.Int64, DataKind.Double}, + { DataKind.UInt64, DataKind.UInt16}, + { DataKind.UInt64, DataKind.UInt32}, + { DataKind.UInt64, DataKind.UInt64}, + { DataKind.UInt64, DataKind.Single}, + { DataKind.UInt64, DataKind.Double}, + { DataKind.Single, DataKind.Single}, + { DataKind.Single, DataKind.Double}, + { DataKind.Double, DataKind.Single}, + { DataKind.Double, DataKind.Double} + }; + + for (int i = 0; i < supportedConversions.GetLength(0); i++) + { + var fromKind = supportedConversions[i, 0]; + var toKind = supportedConversions[i, 1]; + + TextLoader.Column[] columns = new [] + { + new TextLoader.Column("Value", fromKind, 0, 0) + }; + var dataView = mlContext.Data.LoadFromTextFile(filePath, columns); + + var pipeline = mlContext.Transforms.Conversion.ConvertType("ValueConverted", "Value", outputKind: toKind); + var model = pipeline.Fit(dataView); + var mlnetResult = model.Transform(dataView); + + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + var onnxFileName = "typeconversion.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + SaveOnnxModel(onnxModel, onnxModelPath, null); + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess) + { + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + + CompareResults(model.ColumnPairs[0].outputColumnName, outputNames[1], mlnetResult, onnxResult); + } + } + } + private void CreateDummyExamplesToMakeComplierHappy() { var dummyExample = new BreastCancerFeatureVector() { Features = null }; @@ -609,6 +686,84 @@ private void CreateDummyExamplesToMakeComplierHappy() var dummyExample3 = new SmallSentimentExample() { Tokens = null }; } + private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right) + { + var leftColumn = left.Schema[leftColumnName]; + var rightColumn = right.Schema[rightColumnName]; + var leftType = leftColumn.Type.GetItemType(); + var rightType = rightColumn.Type.GetItemType(); + Assert.Equal(leftType, rightType); + + if (leftType == NumberDataViewType.SByte) + CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + else if (leftType == NumberDataViewType.Byte) + CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + else if (leftType == NumberDataViewType.Int16) + CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + else if (leftType == NumberDataViewType.UInt16) + CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + else if (leftType == NumberDataViewType.Int32) + CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + else if (leftType == NumberDataViewType.UInt32) + CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + else if (leftType == NumberDataViewType.Int64) + CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + else if (leftType == NumberDataViewType.UInt64) + CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + else if (leftType == NumberDataViewType.Single) + CompareSelectedR4VectorColumns(leftColumnName, rightColumnName, left, right); + else if (leftType == NumberDataViewType.Double) + CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + } + + private void CompareSelectedVectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right) + { + var leftColumn = left.Schema[leftColumnName]; + var rightColumn = right.Schema[rightColumnName]; + + using (var expectedCursor = left.GetRowCursor(leftColumn)) + using (var actualCursor = right.GetRowCursor(rightColumn)) + { + VBuffer expected = default; + VBuffer actual = default; + var expectedGetter = expectedCursor.GetGetter>(leftColumn); + var actualGetter = actualCursor.GetGetter>(rightColumn); + while (expectedCursor.MoveNext() && actualCursor.MoveNext()) + { + expectedGetter(ref expected); + actualGetter(ref actual); + + Assert.Equal(expected.Length, actual.Length); + for (int i = 0; i < expected.Length; ++i) + Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i)); + } + } + } + + private void CompareSelectedR8VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6) + { + var leftColumn = left.Schema[leftColumnName]; + var rightColumn = right.Schema[rightColumnName]; + + using (var expectedCursor = left.GetRowCursor(leftColumn)) + using (var actualCursor = right.GetRowCursor(rightColumn)) + { + VBuffer expected = default; + VBuffer actual = default; + var expectedGetter = expectedCursor.GetGetter>(leftColumn); + var actualGetter = actualCursor.GetGetter>(rightColumn); + while (expectedCursor.MoveNext() && actualCursor.MoveNext()) + { + expectedGetter(ref expected); + actualGetter(ref actual); + + Assert.Equal(expected.Length, actual.Length); + for (int i = 0; i < expected.Length; ++i) + Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i), precision); + } + } + } + private void CompareSelectedR4VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6) { var leftColumn = left.Schema[leftColumnName]; diff --git a/test/data/type-conversion.txt b/test/data/type-conversion.txt new file mode 100644 index 0000000000..e440e5c842 --- /dev/null +++ b/test/data/type-conversion.txt @@ -0,0 +1 @@ +3 \ No newline at end of file