diff --git a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
index c2c4b1eaea..77b27f3a92 100644
--- a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
+++ b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
@@ -227,7 +227,7 @@ private static TensorProto.Types.DataType ConvertToTensorProtoType(Type rawType)
             var dataType = TensorProto.Types.DataType.Undefined;
 
             if (rawType == typeof(bool))
-                dataType = TensorProto.Types.DataType.Float;
+                dataType = TensorProto.Types.DataType.Bool;
             else if (rawType == typeof(ReadOnlyMemory<char>))
                 dataType = TensorProto.Types.DataType.String;
             else if (rawType == typeof(sbyte))
@@ -305,7 +305,7 @@ public static ModelProto MakeModel(List nodes, string producerName, s
             model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion;
             model.ModelVersion = modelVersion;
             model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 });
-            model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 7 });
+            model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 9 });
             model.Graph = new GraphProto();
             var graph = model.Graph;
             graph.Node.Add(nodes);
diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
index a8058ead78..5df5cb5a6b 100644
--- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
@@ -10,6 +10,7 @@
 using Microsoft.ML.CommandLine;
 using Microsoft.ML.Data;
 using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Model.OnnxConverter;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Transforms;
 
@@ -140,7 +141,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
 
         private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
 
-        private sealed class Mapper : OneToOneMapperBase
+        private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
         {
             private readonly MissingValueIndicatorTransformer _parent;
             private readonly ColInfo[] _infos;
@@ -426,6 +427,57 @@ private void FillValues(int srcLength, ref VBuffer dst, List indices,
                     dst = editor.Commit();
                 }
             }
+
+            /// <summary>
+            /// Always returns true; columns with unsupported types are dropped one by one in SaveAsOnnx.
+            /// </summary>
+            public bool CanSaveOnnx(OnnxContext ctx) => true;
+
+            /// <summary>
+            /// Exports every column pair of this mapper to the given ONNX context. An output column is
+            /// removed from the context when its input is absent or its input type cannot be exported.
+            /// </summary>
+            public void SaveAsOnnx(OnnxContext ctx)
+            {
+                Host.CheckValue(ctx, nameof(ctx));
+
+                for (int iinfo = 0; iinfo < _infos.Length; ++iinfo)
+                {
+                    ColInfo info = _infos[iinfo];
+                    string inputColumnName = info.InputColumnName;
+                    if (!ctx.ContainsColumn(inputColumnName))
+                    {
+                        ctx.RemoveColumn(info.Name, false);
+                        continue;
+                    }
+
+                    if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(inputColumnName),
+                        ctx.AddIntermediateVariable(info.OutputType, info.Name)))
+                    {
+                        ctx.RemoveColumn(info.Name, true);
+                    }
+                }
+            }
+
+            /// <summary>
+            /// Emits a single ONNX IsNaN node reading srcVariableName and writing dstVariableName.
+            /// Returns false, emitting nothing, when the column's (item) type is not float.
+            /// </summary>
+            private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
+            {
+                var inputType = _infos[iinfo].InputType;
+                Type rawType = (inputType is VectorDataViewType vectorType) ? vectorType.ItemType.RawType : inputType.RawType;
+
+                if (rawType != typeof(float))
+                    return false;
+
+                // IsNaN already yields a boolean tensor, so its output is written straight to the
+                // destination variable; no extra intermediate variable is required.
+                const string opType = "IsNaN";
+                ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
+
+                return true;
+            }
         }
     }
 
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt
index f0795a1f13..0e43749793 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt
@@ -459,7 +459,7 @@
         "name": "PredictedLabel0",
         "type": {
           "tensorType": {
-            "elemType": "FLOAT",
+            "elemType": "BOOL",
             "shape": {
               "dim": [
                 {
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt
index 22aee806af..e0decf5739 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt
@@ -786,7 +786,7 @@
         "name": "PredictedLabel0",
         "type": {
           "tensorType": {
-            "elemType": "FLOAT",
+            "elemType": "BOOL",
             "shape": {
               "dim": [
                 {
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt
index 68335b20ad..5d88daca32 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt
@@ -414,7 +414,7 @@
         "name": "Label",
         "type": {
           "tensorType": {
-            "elemType": "FLOAT",
+            "elemType": "BOOL",
             "shape": {
               "dim": [
                 {
@@ -470,7 +470,7 @@
         "name": "Label0",
         "type": {
           "tensorType": {
-            "elemType": "FLOAT",
+            "elemType": "BOOL",
             "shape": {
               "dim": [
                 {
@@ -542,7 +542,7 @@
         "name": "PredictedLabel0",
         "type": {
           "tensorType": {
-            "elemType": "FLOAT",
+            "elemType": "BOOL",
             "shape": {
               "dim": [
                 {
diff --git a/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt b/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt
new file mode 100644
index 0000000000..964d073ed9
--- /dev/null
+++ b/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt
@@ -0,0 +1,163 @@
+{
+  "irVersion": "3",
+  "producerName": "ML.NET",
+  "producerVersion": "##VERSION##",
+  "domain": "machinelearning.dotnet",
+  "graph": {
+    "node": [
+      {
+        "input": [
+          "Features"
+        ],
+        "output": [
+          "MissingIndicator"
+        ],
+        "name": "IsNaN",
+        "opType": "IsNaN"
+      },
+      {
+        "input": [
+          "MissingIndicator"
+        ],
+        "output": [
+          "MissingIndicator0"
+        ],
+        "name": "Cast",
+        "opType": "Cast",
+        "attribute": [
+          {
+            "name": "to",
+            "i": "6",
+            "type": "INT"
+          }
+        ]
+      },
+      {
+        "input": [
+          "Features"
+        ],
+        "output": [
+          "Features0"
+        ],
+        "name": "Identity",
+        "opType": "Identity"
+      },
+      {
+        "input": [
+          "MissingIndicator0"
+        ],
+        "output": [
+          "MissingIndicator1"
+        ],
+        "name": "Identity0",
+        "opType": "Identity"
+      }
+    ],
+    "name": "model",
+    "input": [
+      {
+        "name": "Features",
+        "type": {
+          "tensorType": {
+            "elemType": "FLOAT",
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "3"
+                }
+              ]
+            }
+          }
+        }
+      }
+    ],
+    "output": [
+      {
+        "name": "Features0",
+        "type": {
+          "tensorType": {
+            "elemType": "FLOAT",
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "3"
+                }
+              ]
+            }
+          }
+        }
+      },
+      {
+        "name": "MissingIndicator1",
+        "type": {
+          "tensorType": {
+            "elemType": "INT32",
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "3"
+                }
+              ]
+            }
+          }
+        }
+      }
+    ],
+    "valueInfo": [
+      {
+        "name": "MissingIndicator",
+        "type": {
+          "tensorType": {
+            "elemType": "BOOL",
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "3"
+                }
+              ]
+            }
+          }
+        }
+      },
+      {
+        "name": "MissingIndicator0",
+        "type": {
+          "tensorType": {
+            "elemType": "INT32",
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "3"
+                }
+              ]
+            }
+          }
+        }
+      }
+    ]
+  },
+  "opsetImport": [
+    {
+      "domain": "ai.onnx.ml",
+      "version": "1"
+    },
+    {
+      "version": "9"
+    }
+  ]
+}
\ No newline at end of file
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index 138d49f5d4..9ea36bdadd 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -41,6 +41,15 @@ public OnnxConversionTest(ITestOutputHelper output) : base(output)
         {
         }
 
+        /// <summary>
+        /// Returns true when the process is 64-bit and, on Linux, the libc version check against 2.23
+        /// passes. Tests use this to decide whether to run the exported model through the ONNX runtime.
+        /// </summary>
+        private bool IsOnnxRuntimeSupported()
+        {
+            return Environment.Is64BitProcess && (!RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || AttributeHelpers.CheckLibcVersionGreaterThanMinimum(new System.Version(2, 23)));
+        }
+
         /// <summary>
         /// In this test, we convert a trained into ONNX file and then
         /// call to evaluate that file. The outputs of are checked against the original
@@ -740,6 +749,60 @@ public void OnnxTypeConversionTest()
             }
         }
 
+        /// <summary>
+        /// Converts an IndicateMissingValues pipeline to ONNX, checks the exported graph against the
+        /// baseline and, where the ONNX runtime is supported, compares its output with ML.NET's.
+        /// </summary>
+        [Fact]
+        public void IndicateMissingValuesOnnxConversionTest()
+        {
+            var mlContext = new MLContext(seed: 1);
+
+            var samples = new List<DataPoint>()
+            {
+                new DataPoint() { Features = new float[3] {1, 1, 0}, },
+                new DataPoint() { Features = new float[3] {0, float.NaN, 1}, },
+                new DataPoint() { Features = new float[3] {-1, float.NaN, float.PositiveInfinity}, },
+            };
+            var dataView = mlContext.Data.LoadFromEnumerable(samples);
+
+            // IsNaN outputs a boolean tensor. Support for this has been added in the latest version
+            // of the ONNX runtime, but that hasn't been released yet.
+            // So we need to convert its type to Int32 until then.
+            // The ConvertType part of the pipeline can be removed once we pick up a new release of the ONNX runtime.
+
+            var pipeline = mlContext.Transforms.IndicateMissingValues(new[] { new InputOutputColumnPair("MissingIndicator", "Features"), })
+                .Append(mlContext.Transforms.Conversion.ConvertType("MissingIndicator", outputKind: DataKind.Int32));
+
+            var model = pipeline.Fit(dataView);
+            var transformedData = model.Transform(dataView);
+            var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
+
+            var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms");
+            var onnxFileName = "IndicateMissingValues.onnx";
+            var onnxTextName = "IndicateMissingValues.txt";
+            var onnxModelPath = GetOutputPath(onnxFileName);
+            var onnxTextPath = GetOutputPath(subDir, onnxTextName);
+
+            SaveOnnxModel(onnxModel, onnxModelPath, onnxTextPath);
+
+            // Compare results produced by ML.NET and ONNX's runtime.
+            if (IsOnnxRuntimeSupported())
+            {
+                // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
+                string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+                string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+                var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
+                var onnxTransformer = onnxEstimator.Fit(dataView);
+                var onnxResult = onnxTransformer.Transform(dataView);
+                // outputNames[0] is the pass-through Features column; outputNames[1] is the indicator column.
+                CompareSelectedVectorColumns(model.LastTransformer.ColumnPairs[0].outputColumnName, outputNames[1], transformedData, onnxResult);
+            }
+
+            CheckEquality(subDir, onnxTextName, parseOption: NumberParseOption.UseSingle);
+            Done();
+        }
+
         private void CreateDummyExamplesToMakeComplierHappy()
         {
             var dummyExample = new BreastCancerFeatureVector() { Features = null };