diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 1757d241aa..3855729066 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -1742,12 +1742,17 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu _host.CheckValue(scoreProbablityColumnNames, nameof(scoreProbablityColumnNames)); _host.Check(Utils.Size(scoreProbablityColumnNames) == 2); - string opType = "Affine"; - string linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true); - var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0] }, - new[] { linearOutput }, ctx.GetNodeName(opType), ""); - node.AddAttribute("alpha", Slope * -1); - node.AddAttribute("beta", -0.0000001); + // The Affine operator is no longer supported in the v11 opset. + // So we have to decompose it using Mul and Add + string opType = "Mul"; + var slopVar = ctx.AddInitializer((float)(-Slope), "Slope"); + var mulNodeOutput = ctx.AddIntermediateVariable(null, "MulNodeOutput", true); + var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0], slopVar }, new[] { mulNodeOutput }, ctx.GetNodeName(opType), ""); + + opType = "Add"; + var betaVar = ctx.AddInitializer(-0.0000001f, "Slope"); + var linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true); + node = ctx.CreateNode(opType, new[] { mulNodeOutput, betaVar }, new[] { linearOutput }, ctx.GetNodeName(opType), ""); opType = "Sigmoid"; node = ctx.CreateNode(opType, new[] { linearOutput }, diff --git a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs index 1b759ff9b6..be78dab3dc 100644 --- a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Data; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -442,11 +443,12 @@ private static bool AreRangesValid(int[][] slotsMin, int[][] slotsMax) private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema); - private sealed class Mapper : OneToOneMapperBase + private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { private readonly SlotsDroppingTransformer _parent; private readonly int[] _cols; private readonly DataViewType[] _srcTypes; + private readonly DataViewType[] _rawTypes; private readonly DataViewType[] _dstTypes; private readonly SlotDropper[] _slotDropper; // Track if all the slots of the column are to be dropped. @@ -459,6 +461,7 @@ public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema) _parent = parent; _cols = new int[_parent.ColumnPairs.Length]; _srcTypes = new DataViewType[_parent.ColumnPairs.Length]; + _rawTypes = new DataViewType[_parent.ColumnPairs.Length]; _dstTypes = new DataViewType[_parent.ColumnPairs.Length]; _slotDropper = new SlotDropper[_parent.ColumnPairs.Length]; _suppressed = new bool[_parent.ColumnPairs.Length]; @@ -471,8 +474,8 @@ public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema) _srcTypes[i] = inputSchema[_cols[i]].Type; VectorDataViewType srcVectorType = _srcTypes[i] as VectorDataViewType; - DataViewType itemType = srcVectorType?.ItemType ?? _srcTypes[i]; - if (!IsValidColumnType(itemType)) + _rawTypes[i] = srcVectorType?.ItemType ?? _srcTypes[i]; + if (!IsValidColumnType(_rawTypes[i])) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); int valueCount = srcVectorType?.Size ?? 1; @@ -868,6 +871,57 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() } return result; } + + public bool CanSaveOnnx(OnnxContext ctx) => true; + + public void SaveAsOnnx(OnnxContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + + for (int iinfo = 0; iinfo < _cols.Length; ++iinfo) + { + string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName; + if (!ctx.ContainsColumn(inputColumnName)) + continue; + + string srcVariableName = ctx.GetVariableName(inputColumnName); + string dstVariableName = ctx.AddIntermediateVariable(_dstTypes[iinfo], _parent.ColumnPairs[iinfo].outputColumnName); + if (!SaveAsOnnxCore(ctx, iinfo, srcVariableName, dstVariableName)) + ctx.RemoveColumn(dstVariableName); + } + } + + public bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) + { + string opType; + if (_srcTypes[iinfo] is VectorDataViewType) + { + opType = "GatherElements"; + IEnumerable slots = _slotDropper[iinfo].GetPreservedSlots(); + var slotsVar = ctx.AddInitializer(slots, new long[] { 1, slots.Count() }, "PreservedSlots"); + var node = ctx.CreateNode(opType, new[] { srcVariableName, slotsVar }, new[] { dstVariableName }, ctx.GetNodeName(opType), ""); + node.AddAttribute("axis", 1); + } + else + { + string constVal; + long[] dims = { 1, 1 }; + float[] floatVals = { 0.0f }; + long[] keyVals = { 0 }; + string[] stringVals = { "" }; + if (_rawTypes[iinfo] is TextDataViewType) + constVal = ctx.AddInitializer(stringVals, dims); + else if (_rawTypes[iinfo] is KeyDataViewType) + constVal = ctx.AddInitializer(keyVals, dims); + else + constVal = ctx.AddInitializer(floatVals, dims); + + opType = "Identity"; + ctx.CreateNode(opType, constVal, dstVariableName, ctx.GetNodeName(opType), ""); + } + return true; + } + } } } diff --git a/src/Microsoft.ML.Data/Utilities/SlotDropper.cs b/src/Microsoft.ML.Data/Utilities/SlotDropper.cs index f078bc206e..7ba3e6564f 100644 --- a/src/Microsoft.ML.Data/Utilities/SlotDropper.cs +++ b/src/Microsoft.ML.Data/Utilities/SlotDropper.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; @@ -16,6 +18,7 @@ namespace Microsoft.ML.Internal.Internallearn internal sealed class SlotDropper { private readonly int[] _lengthReduction; + private readonly int _srcLength; /// /// Returns -1 for non vector and unknown length vectors. @@ -43,6 +46,7 @@ public SlotDropper(int srcLength, int[] slotsMin, int[] slotsMax) SlotsMin = slotsMin; SlotsMax = slotsMax; + _srcLength = srcLength; _lengthReduction = ComputeLengthReduction(); Contracts.Check(SlotsMin.Length == _lengthReduction.Length); @@ -212,5 +216,16 @@ public void DropSlots(ref VBuffer src, ref VBuffer dst) dst = editor.CommitTruncated(iiDst); } + + public IEnumerable GetPreservedSlots() + { + var slots = Enumerable.Range(0, _srcLength); + var droppedSlots = Enumerable.Range(SlotsMin[0], SlotsMax[0] - SlotsMin[0] + 1); + for (int i = 1; i < SlotsMin.Length; i++) + { + droppedSlots = droppedSlots.Concat(Enumerable.Range(SlotsMin[i], SlotsMax[i] - SlotsMin[i] + 1)); + } + return slots.Except(droppedSlots).Select(i=>(long)i); + } } } diff --git a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs index b43f2d78f3..6735538e85 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs @@ -305,7 +305,7 @@ public static ModelProto MakeModel(List nodes, string producerName, s model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion; model.ModelVersion = modelVersion; model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 2 }); - model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 9 }); + model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 11 }); model.Graph = new GraphProto(); var graph = model.Graph; graph.Node.Add(nodes); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs index cd50b8a3fa..764fa4806d 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -1004,10 +1004,18 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureCol // Onnx outputs an Int64, but ML.NET outputs UInt32. So cast the Onnx output here opType = "Cast"; - var castNode = ctx.CreateNode(opType, predictedLabelInt64, predictedLabelUint32, ctx.GetNodeName(opType), ""); + var castNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.UInt32, "CastNodeOutput", true); + var castNode = ctx.CreateNode(opType, predictedLabelInt64, castNodeOutput, ctx.GetNodeName(opType), ""); var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType(); castNode.AddAttribute("to", t); + // The predictedLabel is a scalar. But the onnx output of ML.NET output expects a [1x1] tensor for output. So reshape it here + opType = "Reshape"; + long[] shape = { 1, 1 }; + long[] shapeDim = { 2 }; + var shapeVar = ctx.AddInitializer(shape, shapeDim, "ShapeVar"); + var reshapeNode = ctx.CreateNode(opType, new[] { castNodeOutput, shapeVar }, new[] { predictedLabelUint32 }, ctx.GetNodeName(opType), ""); + return true; } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index f8980c0e0e..185da7a4eb 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -559,8 +559,8 @@ public string[] SaveAsOnnxPreProcess(OnnxContext ctx, string featureColumn, bool outputs[i] = clipOutput; string opType = "Clip"; - var clipNode = ctx.CreateNode(opType, clipInput, outputs[i], ctx.GetNodeName(opType), ""); - clipNode.AddAttribute("min", 0.0); + var zeroVar = ctx.AddInitializer(0.0f, "Zero"); + var clipNode = ctx.CreateNode(opType, new[] { clipInput, zeroVar }, new[] { outputs[i] }, ctx.GetNodeName(opType), ""); } else outputs[i] = predictorOutputNames[2]; diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs index 8217b58436..682b5660e4 100644 --- a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs @@ -30,7 +30,7 @@ namespace Microsoft.ML.Transforms /// | Does this estimator need to look at the data to train its parameters? | Yes | /// | Input column data type | Vector or scalar of numeric, [text](xref:Microsoft.ML.Data.TextDataViewType) or [key](xref:Microsoft.ML.Data.KeyDataViewType) data types| /// | Output column data type | Same as the input column| - /// | Exportable to ONNX | No | + /// | Exportable to ONNX | Yes | /// /// This transform uses a set of aggregators to count the number of values for each slot (vector element) /// that are non-default and non-missing (for the definitions of default and missing, refer to the remarks section diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index f06d519fa8..d4229a80bf 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -32,7 +32,7 @@ namespace Microsoft.ML.Transforms /// | Does this estimator need to look at the data to train its parameters? | Yes | /// | Input column data type | Vector or scalar of numeric, [text](xref:Microsoft.ML.Data.TextDataViewType) or [key](xref:Microsoft.ML.Data.KeyDataViewType) data types| /// | Output column data type | Same as the input column| - /// | Exportable to ONNX | No | + /// | Exportable to ONNX | Yes | /// /// Formally, the mutual information can be written as: /// diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt index c7ddd56e30..5f6a356b72 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt @@ -381,25 +381,25 @@ }, { "input": [ - "Score" + "Score", + "Slope" + ], + "output": [ + "MulNodeOutput" + ], + "name": "Mul", + "opType": "Mul" + }, + { + "input": [ + "MulNodeOutput", + "Slope0" ], "output": [ "linearOutput" ], - "name": "Affine", - "opType": "Affine", - "attribute": [ - { - "name": "alpha", - "f": 0.4, - "type": "FLOAT" - }, - { - "name": "beta", - "f": -1E-07, - "type": "FLOAT" - } - ] + "name": "Add", + "opType": "Add" }, { "input": [ @@ -478,6 +478,22 @@ } ], "name": "A Simple Pipeline", + "initializer": [ + { + "dataType": 1, + "floatData": [ + 0.4 + ], + "name": "Slope" + }, + { + "dataType": 1, + "floatData": [ + -1E-07 + ], + "name": "Slope0" + } + ], "input": [ { "name": "F1", @@ -671,7 +687,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt index 30b71b981d..3ff37b8f20 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt @@ -526,7 +526,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt index 6fb0858914..74174b3fa5 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt @@ -270,7 +270,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt index 166b713fe7..bf0715cc25 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt @@ -707,25 +707,25 @@ }, { "input": [ - "Score" + "Score", + "Slope" + ], + "output": [ + "MulNodeOutput" + ], + "name": "Mul", + "opType": "Mul" + }, + { + "input": [ + "MulNodeOutput", + "Slope0" ], "output": [ "linearOutput" ], - "name": "Affine", - "opType": "Affine", - "attribute": [ - { - "name": "alpha", - "f": 0.4, - "type": "FLOAT" - }, - { - "name": "beta", - "f": -1E-07, - "type": "FLOAT" - } - ] + "name": "Add", + "opType": "Add" }, { "input": [ @@ -804,6 +804,22 @@ } ], "name": "modelWithLessIO", + "initializer": [ + { + "dataType": 1, + "floatData": [ + 0.4 + ], + "name": "Slope" + }, + { + "dataType": 1, + "floatData": [ + -1E-07 + ], + "name": "Slope0" + } + ], "input": [ { "name": "F1", @@ -961,7 +977,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt index d007cd95c0..2df8da4d81 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt @@ -334,25 +334,25 @@ }, { "input": [ - "Score" + "Score", + "Slope" + ], + "output": [ + "MulNodeOutput" + ], + "name": "Mul", + "opType": "Mul" + }, + { + "input": [ + "MulNodeOutput", + "Slope0" ], "output": [ "linearOutput" ], - "name": "Affine", - "opType": "Affine", - "attribute": [ - { - "name": "alpha", - "f": 0.4, - "type": "FLOAT" - }, - { - "name": "beta", - "f": -1E-07, - "type": "FLOAT" - } - ] + "name": "Add", + "opType": "Add" }, { "input": [ @@ -471,6 +471,22 @@ } ], "name": "model", + "initializer": [ + { + "dataType": 1, + "floatData": [ + 0.4 + ], + "name": "Slope" + }, + { + "dataType": 1, + "floatData": [ + -1E-07 + ], + "name": "Slope0" + } + ], "input": [ { "name": "Label", @@ -736,7 +752,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt index b675d685b2..b8f150889b 100644 --- a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt +++ b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt @@ -349,7 +349,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt index 12fabdd840..2a91f31721 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt @@ -260,7 +260,7 @@ "PredictedLabelInt64" ], "output": [ - "PredictedLabel" + "CastNodeOutput" ], "name": "Cast0", "opType": "Cast", @@ -272,6 +272,17 @@ } ] }, + { + "input": [ + "CastNodeOutput", + "ShapeVar" + ], + "output": [ + "PredictedLabel" + ], + "name": "Reshape", + "opType": "Reshape" + }, { "input": [ "Label0" @@ -314,6 +325,19 @@ } ], "name": "model", + "initializer": [ + { + "dims": [ + "2" + ], + "dataType": 7, + "int64Data": [ + "1", + "1" + ], + "name": "ShapeVar" + } + ], "input": [ { "name": "Label", @@ -471,7 +495,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastForestRegressionTrainer.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastForestRegressionTrainer.txt index 1c0dcf3ec3..d8b66879a6 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastForestRegressionTrainer.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastForestRegressionTrainer.txt @@ -39485,7 +39485,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer.txt index e514c9cfc6..5671bd2658 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer.txt @@ -39465,7 +39465,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer.txt index 6a92d519e6..927a594aec 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer.txt @@ -39475,7 +39475,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.LbfgsPoissonRegressionTrainer.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.LbfgsPoissonRegressionTrainer.txt index 6d693da4ab..df384fee0b 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.LbfgsPoissonRegressionTrainer.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.LbfgsPoissonRegressionTrainer.txt @@ -215,7 +215,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.LightGbm.LightGbmRegressionTrainer.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.LightGbm.LightGbmRegressionTrainer.txt index 49e40cbc75..cc943d2eab 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.LightGbm.LightGbmRegressionTrainer.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.LightGbm.LightGbmRegressionTrainer.txt @@ -38425,7 +38425,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.OlsTrainer.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.OlsTrainer.txt index 227faf490e..880224d211 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.OlsTrainer.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.OlsTrainer.txt @@ -205,7 +205,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.OnlineGradientDescentTrainer.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.OnlineGradientDescentTrainer.txt index 49baf9eb1c..7fe033b02e 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.OnlineGradientDescentTrainer.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.OnlineGradientDescentTrainer.txt @@ -205,7 +205,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.SdcaRegressionTrainer.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.SdcaRegressionTrainer.txt index cda9d947bb..ad998d6832 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.SdcaRegressionTrainer.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/Microsoft.ML.Trainers.SdcaRegressionTrainer.txt @@ -205,7 +205,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt index 6fb0858914..74174b3fa5 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt @@ -270,7 +270,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt b/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt index a819b997a5..3524a1e97b 100644 --- a/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt +++ b/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt @@ -157,7 +157,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt b/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt index 9393234794..3fbfd9ef02 100644 --- a/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt +++ b/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt @@ -1116,7 +1116,7 @@ "version": "2" }, { - "version": "9" + "version": "11" } ] } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 12688e1392..7e8454229c 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1254,6 +1254,58 @@ public void CopyColumnsOnnxTest() Done(); } + [Fact] + public void FeatureSelectionOnnxTest() + { + var mlContext = new MLContext(seed: 1); + + string dataPath = GetDataPath("breast-cancer.txt"); + + var dataView = ML.Data.LoadFromTextFile(dataPath, new[] { + new TextLoader.Column("ScalarFloat", DataKind.Single, 6), + new TextLoader.Column("VectorFloat", DataKind.Single, 1, 4), + new TextLoader.Column("VectorDouble", DataKind.Double, 4, 8), + new TextLoader.Column("Label", DataKind.Boolean, 0) + }); + + var columns = new[] { + new CountFeatureSelectingEstimator.ColumnOptions("FeatureSelectDouble", "VectorDouble", count: 1), + new CountFeatureSelectingEstimator.ColumnOptions("ScalFeatureSelectMissing690", "ScalarFloat", count: 690), + new CountFeatureSelectingEstimator.ColumnOptions("ScalFeatureSelectMissing100", "ScalarFloat", count: 100), + new CountFeatureSelectingEstimator.ColumnOptions("VecFeatureSelectMissing690", "VectorDouble", count: 690), + new CountFeatureSelectingEstimator.ColumnOptions("VecFeatureSelectMissing100", "VectorDouble", count: 100) + }; + var pipeline = ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("FeatureSelect", "VectorFloat", count: 1) + .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount(columns)) + .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("FeatureSelectMIScalarFloat", "ScalarFloat")) + .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("FeatureSelectMIVectorFloat", "VectorFloat")); + + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + + var onnxFileName = "countfeatures.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + + SaveOnnxModel(onnxModel, onnxModelPath, null); + + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedR4ScalarColumns("FeatureSelectMIScalarFloat", "FeatureSelectMIScalarFloat0", transformedData, onnxResult); + CompareSelectedR4VectorColumns("FeatureSelectMIVectorFloat", "FeatureSelectMIVectorFloat0", transformedData, onnxResult); + CompareSelectedR4ScalarColumns("ScalFeatureSelectMissing690", "ScalFeatureSelectMissing6900", transformedData, onnxResult); + CompareSelectedR8VectorColumns("VecFeatureSelectMissing690", "VecFeatureSelectMissing6900", transformedData, onnxResult); + } + Done(); + } + + private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right) { var leftColumn = left.Schema[leftColumnName];