From fc1925ceaef44845d10d7788c0a1f544a5219fba Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Sat, 8 Feb 2020 01:06:29 -0800 Subject: [PATCH 1/3] Changed Binarizer node to be cast to the type of the predicted label column's data type --- src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index ba015ebe6f..26d1d3831f 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -216,10 +216,11 @@ the predicted label is based on the sign of the score. node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[1]), binarizerOutput, ctx.GetNodeName(opType)); node.AddAttribute("threshold", 0.0); } + opType = "Cast"; node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), ""); - var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType(); - node.AddAttribute("to", t); + var predictedLabelCol = OutputSchema.GetColumnOrNull(outColumnNames[0]); + node.AddAttribute("to", predictedLabelCol.HasValue ? predictedLabelCol.Value.Type.RawType : typeof(bool)); } private protected override IDataTransform ApplyToDataCore(IHostEnvironment env, IDataView newSource) From fb8aefb19d3149566365c3baf06709f99d557d2e Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Mon, 10 Feb 2020 11:36:54 -0800 Subject: [PATCH 2/3] Fixed threshold comparison, addressed review comments and updated resulting baseline changes --- .../Scorers/BinaryClassifierScorer.cs | 19 ++++--------------- .../ExcludeVariablesInOnnxConversion.txt | 3 +-- .../BreastCancer/ModelWithLessIO.txt | 3 +-- .../BreastCancer/OneHotBagPipeline.txt | 3 +-- 4 files changed, 7 insertions(+), 21 deletions(-) diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 26d1d3831f..a546ce693e 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -198,29 +198,18 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx) for (int iinfo = 0; iinfo < Bindings.InfoCount; ++iinfo) outColumnNames[iinfo] = Bindings.GetColumnName(Bindings.MapIinfoToCol(iinfo)); - /* If the probability column was generated, then the classification threshold is set to 0.5. Otherwise, - the predicted label is based on the sign of the score. - */ string opType = "Binarizer"; OnnxNode node; var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true); - if (Bindings.InfoCount >= 3) - { - Host.Assert(ctx.ContainsColumn(outColumnNames[2])); - node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[2]), binarizerOutput, ctx.GetNodeName(opType)); - node.AddAttribute("threshold", 0.5); - } - else - { - node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[1]), binarizerOutput, ctx.GetNodeName(opType)); - node.AddAttribute("threshold", 0.0); - } + node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[1]), binarizerOutput, ctx.GetNodeName(opType)); + node.AddAttribute("threshold", _threshold); opType = "Cast"; node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), ""); var predictedLabelCol = OutputSchema.GetColumnOrNull(outColumnNames[0]); - node.AddAttribute("to", predictedLabelCol.HasValue ? predictedLabelCol.Value.Type.RawType : typeof(bool)); + Host.Assert(predictedLabelCol.HasValue); + node.AddAttribute("to", predictedLabelCol.Value.Type.RawType); } private protected override IDataTransform ApplyToDataCore(IHostEnvironment env, IDataView newSource) diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt index 50ed57a58c..b9c6508af7 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt @@ -413,7 +413,7 @@ }, { "input": [ - "Probability" + "Score" ], "output": [ "BinarizerOutput" @@ -423,7 +423,6 @@ "attribute": [ { "name": "threshold", - "f": 0.5, "type": "FLOAT" } ], diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt index d1ef2ebd05..9cb2ba4450 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt @@ -739,7 +739,7 @@ }, { "input": [ - "Probability" + "Score" ], "output": [ "BinarizerOutput" @@ -749,7 +749,6 @@ "attribute": [ { "name": "threshold", - "f": 0.5, "type": "FLOAT" } ], diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt index f33f3276ea..586220cd6c 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt @@ -366,7 +366,7 @@ }, { "input": [ - "Probability" + "Score" ], "output": [ "BinarizerOutput" @@ -376,7 +376,6 @@ "attribute": [ { "name": "threshold", - "f": 0.5, "type": "FLOAT" } ], From 08eb58345aee5c7067337d1bc613a57709af2628 Mon Sep 17 00:00:00 2001 From: "Harish S. Kulkarni" Date: Tue, 11 Feb 2020 10:16:20 -0800 Subject: [PATCH 3/3] Changed outColumnName to handle the case of threshold being assigned to the probability column --- .../Scorers/BinaryClassifierScorer.cs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index a546ce693e..d6b3516456 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Diagnostics; using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; @@ -202,7 +203,15 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx) OnnxNode node; var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true); - node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[1]), binarizerOutput, ctx.GetNodeName(opType)); + string scoreColumn; + if (Bindings.RowMapper.OutputSchema[Bindings.ScoreColumnIndex].Name == "Score") + scoreColumn = outColumnNames[1]; + else + { + Host.Assert(Bindings.InfoCount >= 3); + scoreColumn = outColumnNames[2]; + } + node = ctx.CreateNode(opType, ctx.GetVariableName(scoreColumn), binarizerOutput, ctx.GetNodeName(opType)); node.AddAttribute("threshold", _threshold); opType = "Cast";