From 65557b4ff1b7f20c05fe058b5300bcdcfda0c4bb Mon Sep 17 00:00:00 2001 From: Lynx1820 Date: Mon, 23 Mar 2020 16:21:26 -0700 Subject: [PATCH 1/3] fix for logistic regression --- .../MulticlassLogisticRegression.cs | 30 +-- ...nLogisticRegressionSaveModelToOnnxTest.txt | 243 +++++++++++------- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 14 +- 3 files changed, 173 insertions(+), 114 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs index e354877578..7235c3d7b9 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -979,40 +979,26 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureCol { Host.CheckValue(ctx, nameof(ctx)); - string predictedLabelInt64 = null; - string predictedLabelUint32 = null; - // REVIEW: What is the right way to get the name of the predicted column? - for (int i = 0; i < outputs.Length; i++) - { - if (outputs[i] != DefaultColumnNames.PredictedLabel) - continue; - predictedLabelUint32 = DefaultColumnNames.PredictedLabel; - predictedLabelInt64 = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "PredictedLabelInt64", true); - outputs[i] = predictedLabelInt64; - break; - } - - Host.CheckValue(predictedLabelInt64, nameof(predictedLabelInt64)); + string labels = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "Labels", true); string opType = "LinearClassifier"; - var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { labels, outputs[1] }, ctx.GetNodeName(opType)); node.AddAttribute("post_transform", GetOnnxPostTransform()); node.AddAttribute("multi_class", true); node.AddAttribute("coefficients", Weights.SelectMany(w => w.DenseValues())); node.AddAttribute("intercepts", Biases); node.AddAttribute("classlabels_ints", Enumerable.Range(1, NumberOfClasses).Select(x => (long)x)); + opType = "Unsqueeze"; + var unsqueezeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastNodeOutput"); + var unsqueezeNode = ctx.CreateNode(opType, labels, unsqueezeOutput, ctx.GetNodeName(opType), ""); + unsqueezeNode.AddAttribute("axes", new long[] { 1 }); + // Onnx outputs an Int64, but ML.NET outputs UInt32. So cast the Onnx output here opType = "Cast"; - var castNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.UInt32, "CastNodeOutput", true); - var castNode = ctx.CreateNode(opType, predictedLabelInt64, castNodeOutput, ctx.GetNodeName(opType), ""); + var castNode = ctx.CreateNode(opType, unsqueezeOutput, outputs[0], ctx.GetNodeName(opType), ""); var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType(); castNode.AddAttribute("to", t); - - opType = "Unsqueeze"; - var unsqueezeNode = ctx.CreateNode(opType, castNodeOutput, predictedLabelUint32, ctx.GetNodeName(opType), ""); - unsqueezeNode.AddAttribute("axes", new long[] { 0 }); - return true; } diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt index 36d8d01b7b..ae36051c5b 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt @@ -12,6 +12,31 @@ "output": [ "Features0" ], + "name": "Imputer", + "opType": "Imputer", + "attribute": [ + { + "name": "replaced_value_float", + "f": "NaN", + "type": "FLOAT" + }, + { + "name": "imputed_value_floats", + "floats": [ + 0 + ], + "type": "FLOATS" + } + ], + "domain": "ai.onnx.ml" + }, + { + "input": [ + "Features0" + ], + "output": [ + "Features1" + ], "name": "Scaler", "opType": "Scaler", "attribute": [ @@ -59,15 +84,15 @@ { "name": "keys_strings", "strings": [ - "NQ==", "Mw==", - "Ng==", - "NA==", "OA==", "MQ==", + "NQ==", "Mg==", - "Nw==", "MTA=", + "Ng==", + "NA==", + "Nw==", "OQ==" ], "type": "STRINGS" @@ -115,10 +140,10 @@ }, { "input": [ - "Features0" + "Features1" ], "output": [ - "PredictedLabelInt64", + "Labels", "Score" ], "name": "LinearClassifier", @@ -137,102 +162,102 @@ { "name": "coefficients", "floats": [ - -0.0871891156, - 0.209310874, - 0.747134566, - 0.364765137, - -0.377612084, - -0.6847462, + -0.5040307, + -1.18665814, 0, - -0.5566554, - -0.3849638, - -1.29262471, 0, + -0.4745431, + -0.0457207263, + -0.609427869, 0, - -0.479907274, - -0.08740093, - -0.5489706, + 0.945539951, + 0.9072637, + -0.146731973, 0, - 0.630316138, + 0.595698357, 0, + 0.924152553, 0, + -1.64307475, + -1.12142336, + -0.43682757, + -0.514539361, + -0.984152853, + -0.6729761, + -0.711649537, 0, - 0.07319626, - 0.171390951, - 0.6936194, + -0.0447186679, + 0.138106227, + 0.705147862, + 0.331640273, + -0.329754144, + -0.7123607, 0, + -0.46727252, + -0.543338358, + -1.06657445, + -0.1826125, 0, - -0.6189027, + -0.6470364, + -0.11790736, + -0.4401821, 0, - -0.732489467, - -0.71812433, - 0.2614429, - -0.4669126, - -0.250123739, - 1.01838875, - 0.7936676, + 0.9442133, + 0.9040651, 0, 0, - 0.8072781, + 1.12141752, + 0.6733171, 0, - 0.833407462, + 0.36669904, + 0.49009648, 0, - -1.67462111, - -1.19559848, - -0.553805768, - -0.5710498, - -0.7325714, - -0.5470721, - -0.7483947, + 0.277017027, 0, - -0.5655844, - -0.9892823, - -0.237264976, + 0.0377836041, + 0.272689134, + 0.7089771, 0, - -0.81984, - -0.0930810943, - -0.4526821, 0, + -0.626159668, 0, + -0.5391268, + -0.635246456, + 0.100393988, + -0.299498081, + -0.256104, 0, - 0.726712048, 0, - 1.12171924, - 0.323810369, - 0.245762676, - 0.07872447, - 0.939905643, - 0.923160553, + 0.520324647, 0, + 1.26542234, + 0.33962968, + 0.313578337, + 0.06361275, + 0.5025945, + 1.28040659, 0, - 1.10209334, - 0.704743862, + 0.8730278, + 0.06072715, 0, - 0.322121173, - 0.5064917, - 1.30212963, - 0, - 0.8623323, - 0.0155395176, - 0, - 0.192209348, - 0.262786478 + 0.1766175, + 0.272789866 ], "type": "FLOATS" }, { "name": "intercepts", "floats": [ - 1.23585367, - 1.68783426, - -0.8096311, - 1.35599542, - -1.59806383, - 2.57355452, - 1.03064489, - -1.67592752, - -1.40655541, - -2.39366078 + 1.73919559, + -1.43502069, + 2.63000965, + 1.271419, + 0.9587243, + -1.40365088, + -0.89705795, + 1.27946162, + -1.76300251, + -2.38004065 ], "type": "FLOATS" }, @@ -257,18 +282,20 @@ }, { "input": [ - "PredictedLabelInt64" + "Labels" ], "output": [ "CastNodeOutput" ], - "name": "Cast0", - "opType": "Cast", + "name": "Unsqueeze", + "opType": "Unsqueeze", "attribute": [ { - "name": "to", - "i": "12", - "type": "INT" + "name": "axes", + "ints": [ + "1" + ], + "type": "INTS" } ] }, @@ -279,15 +306,13 @@ "output": [ "PredictedLabel" ], - "name": "Unsqueeze", - "opType": "Unsqueeze", + "name": "Cast0", + "opType": "Cast", "attribute": [ { - "name": "axes", - "ints": [ - "0" - ], - "type": "INTS" + "name": "to", + "i": "12", + "type": "INT" } ] }, @@ -303,7 +328,7 @@ }, { "input": [ - "Features0" + "Features1" ], "output": [ "Features.output" @@ -344,15 +369,15 @@ { "name": "keys_strings", "strings": [ - "NQ==", "Mw==", - "Ng==", - "NA==", "OA==", "MQ==", + "NQ==", "Mg==", - "Nw==", "MTA=", + "Ng==", + "NA==", + "Nw==", "OQ==" ], "type": "STRINGS" @@ -540,6 +565,24 @@ } } }, + { + "name": "Features1", + "type": { + "tensorType": { + "elemType": 1, + "shape": { + "dim": [ + { + "dimValue": "-1" + }, + { + "dimValue": "8" + } + ] + } + } + } + }, { "name": "Label0", "type": { @@ -612,6 +655,24 @@ } } }, + { + "name": "CastNodeOutput", + "type": { + "tensorType": { + "elemType": 7, + "shape": { + "dim": [ + { + "dimValue": "-1" + }, + { + "dimValue": "1" + } + ] + } + } + } + }, { "name": "Label.output", "type": { diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index e992a730d9..68383b8ffa 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -736,7 +736,8 @@ public void MulticlassLogisticRegressionOnnxConversionTest() separatorChar: '\t', hasHeader: true); - var pipeline = mlContext.Transforms.NormalizeMinMax("Features"). + var pipeline = mlContext.Transforms.ReplaceMissingValues("Features"). + Append(mlContext.Transforms.NormalizeMinMax("Features")). Append(mlContext.Transforms.Conversion.MapValueToKey("Label")). Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(new LbfgsMaximumEntropyMulticlassTrainer.Options() { NumberOfThreads = 1 })); @@ -752,6 +753,17 @@ public void MulticlassLogisticRegressionOnnxConversionTest() SaveOnnxModel(onnxModel, onnxFilePath, onnxTextPath); + // Compare results produced by ML.NET and ONNX's runtime. + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxFilePath); + var onnxTransformer = onnxEstimator.Fit(data); + var onnxResult = onnxTransformer.Transform(data); + CompareSelectedColumns("PredictedLabel", "PredictedLabel", transformedData, onnxResult); + CompareSelectedColumns("Score", "Score", transformedData, onnxResult); + } + CheckEquality(subDir, onnxTextName, digitsOfPrecision: 2); Done(); } From e9b2a427110bf69cff46f13a23a01db7b0599459 Mon Sep 17 00:00:00 2001 From: Lynx1820 Date: Tue, 24 Mar 2020 10:18:04 -0700 Subject: [PATCH 2/3] rebase changes --- ...nLogisticRegressionSaveModelToOnnxTest.txt | 150 +++++++++--------- 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt index ae36051c5b..b8227a4171 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt @@ -84,15 +84,15 @@ { "name": "keys_strings", "strings": [ + "NQ==", "Mw==", + "Ng==", + "NA==", "OA==", "MQ==", - "NQ==", "Mg==", - "MTA=", - "Ng==", - "NA==", "Nw==", + "MTA=", "OQ==" ], "type": "STRINGS" @@ -162,102 +162,102 @@ { "name": "coefficients", "floats": [ - -0.5040307, - -1.18665814, + -0.03167667, + 0.143471345, + 0.712303758, + 0.352516025, + -0.301202744, + -0.741327, + 0, + -0.50539434, + -0.495568782, + -1.18026292, 0, 0, - -0.4745431, - -0.0457207263, - -0.609427869, + -0.498504966, + -0.0420997739, + -0.613307, 0, - 0.945539951, - 0.9072637, - -0.146731973, + 0.607902765, 0, - 0.595698357, + 0.141199484, 0, - 0.924152553, 0, - -1.64307475, - -1.12142336, - -0.43682757, - -0.514539361, - -0.984152853, - -0.6729761, - -0.711649537, + 0.211812049, + 0.7400253, 0, - -0.0447186679, - 0.138106227, - 0.705147862, - 0.331640273, - -0.329754144, - -0.7123607, 0, - -0.46727252, - -0.543338358, - -1.06657445, - -0.1826125, + -0.6818558, 0, - -0.6470364, - -0.11790736, - -0.4401821, + -0.551549, + -0.696967065, + 0.160799474, + -0.330957055, + -0.2750144, + 1.0280081, + 0.88234216, 0, - 0.9442133, - 0.9040651, 0, + 0.6424018, 0, - 1.12141752, - 0.6733171, + 0.8743826, 0, - 0.36669904, - 0.49009648, + -1.70732152, + -1.191878, + -0.5418798, + -0.600155354, + -0.740290344, + -0.674283, + -0.7527394, 0, - 0.277017027, + -0.5712943, + -1.04777157, + -0.226992026, 0, - 0.0377836041, - 0.272689134, - 0.7089771, + -0.706178248, + -0.145423546, + -0.465133369, 0, 0, - -0.626159668, 0, - -0.5391268, - -0.635246456, - 0.100393988, - -0.299498081, - -0.256104, + 0.7297528, 0, + 1.16632056, + 0.311890721, + 0.247274086, + 0.102029644, + 0.955604, + 0.9211974, 0, - 0.520324647, 0, - 1.26542234, - 0.33962968, - 0.313578337, - 0.06361275, - 0.5025945, - 1.28040659, + 1.15463662, + 0.6863172, 0, - 0.8730278, - 0.06072715, + 0.35739857, + 0.51222384, + 1.30172575, 0, - 0.1766175, - 0.272789866 + 0.8770747, + 0.0631457046, + 0, + 0.184382111, + 0.289009124 ], "type": "FLOATS" }, { "name": "intercepts", "floats": [ - 1.73919559, - -1.43502069, - 2.63000965, - 1.271419, - 0.9587243, - -1.40365088, - -0.89705795, - 1.27946162, - -1.76300251, - -2.38004065 + 1.24280047, + 1.713405, + -0.833424449, + 1.31703234, + -1.520917, + 2.636808, + 1.016353, + -1.69996321, + -1.4401859, + -2.43189263 ], "type": "FLOATS" }, @@ -369,15 +369,15 @@ { "name": "keys_strings", "strings": [ + "NQ==", "Mw==", + "Ng==", + "NA==", "OA==", "MQ==", - "NQ==", "Mg==", - "MTA=", - "Ng==", - "NA==", "Nw==", + "MTA=", "OQ==" ], "type": "STRINGS" From 6e2bf6a639c2f7234495f50758f226c4d2af6feb Mon Sep 17 00:00:00 2001 From: Lynx1820 Date: Tue, 24 Mar 2020 11:02:54 -0700 Subject: [PATCH 3/3] resolving comments --- .../LogisticRegression/MulticlassLogisticRegression.cs | 8 +++++--- ...lassificationLogisticRegressionSaveModelToOnnxTest.txt | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 7235c3d7b9..2c141c7520 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -978,11 +978,13 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, JToken input) private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); + Host.Assert(outputs[0] == DefaultColumnNames.PredictedLabel); + Host.Assert(outputs[1] == DefaultColumnNames.Score); - string labels = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "Labels", true); + string classifierLabelOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "ClassifierLabelOutput", true); string opType = "LinearClassifier"; - var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { labels, outputs[1] }, ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { classifierLabelOutput, outputs[1] }, ctx.GetNodeName(opType)); node.AddAttribute("post_transform", GetOnnxPostTransform()); node.AddAttribute("multi_class", true); node.AddAttribute("coefficients", Weights.SelectMany(w => w.DenseValues())); @@ -991,7 +993,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureCol opType = "Unsqueeze"; var unsqueezeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastNodeOutput"); - var unsqueezeNode = ctx.CreateNode(opType, labels, unsqueezeOutput, ctx.GetNodeName(opType), ""); + var unsqueezeNode = ctx.CreateNode(opType, classifierLabelOutput, unsqueezeOutput, ctx.GetNodeName(opType), ""); unsqueezeNode.AddAttribute("axes", new long[] { 1 }); // Onnx outputs an Int64, but ML.NET outputs UInt32. So cast the Onnx output here diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt index b8227a4171..799ca5125b 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt @@ -143,7 +143,7 @@ "Features1" ], "output": [ - "Labels", + "ClassifierLabelOutput", "Score" ], "name": "LinearClassifier", @@ -282,7 +282,7 @@ }, { "input": [ - "Labels" + "ClassifierLabelOutput" ], "output": [ "CastNodeOutput"