diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index d45f3ab469..e61a61e59c 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -3112,7 +3112,7 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, } string opType = "TreeEnsembleRegressor"; - string scoreVarName = (Utils.Size(outputNames) == 2) ? outputNames[1] : outputNames[0]; // Get Score from PredictedLabel and/or Score columns + string scoreVarName = (Utils.Size(outputNames) >= 2) ? outputNames[1] : outputNames[0]; // Get Score from PredictedLabel and/or Score columns var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType)); node.AddAttribute("post_transform", PostTransform.None.GetDescription()); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs index d2dea6f60e..8f1745b2a0 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs @@ -138,7 +138,7 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputs, str Host.CheckValue(ctx, nameof(ctx)); Host.Check(Utils.Size(outputs) >= 1); string opType = "LinearRegressor"; - string scoreVarName = (Utils.Size(outputs) == 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns + string scoreVarName = (Utils.Size(outputs) >= 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType)); // Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT} node.AddAttribute("post_transform", "NONE"); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 185da7a4eb..54f74abad0 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -563,7 +563,7 @@ public string[] SaveAsOnnxPreProcess(OnnxContext ctx, string featureColumn, bool var clipNode = ctx.CreateNode(opType, new[] { clipInput, zeroVar }, new[] { outputs[i] }, ctx.GetNodeName(opType), ""); } else - outputs[i] = predictorOutputNames[2]; + outputs[i] = predictorOutputNames[1]; } return outputs; } @@ -659,7 +659,7 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { - var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, true); + var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false); string opType = "Concat"; var concatOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true); @@ -794,22 +794,27 @@ public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string fe opType = "Sum"; var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScores", true); - var sumNode = ctx.CreateNode(opType, probabilityOutputs, new[] { sumOutput }, ctx.GetNodeName(opType), ""); + ctx.CreateNode(opType, probabilityOutputs, new[] { sumOutput }, ctx.GetNodeName(opType), ""); opType = "Cast"; - var castOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumZero", true); + var castOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "CastOutput", true); var castNode = ctx.CreateNode(opType, sumOutput, castOutput, ctx.GetNodeName(opType), ""); var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType(); castNode.AddAttribute("to", t); + opType = "Not"; + var notOutput = ctx.AddIntermediateVariable(null, "IsSumZero", true); + ctx.CreateNode(opType, castOutput, notOutput, ctx.GetNodeName(opType), ""); + + opType = "Cast"; var castIsZeroSumToFloat = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumZeroAsFloat", true); - var castIsZeroSumToFloatNode = ctx.CreateNode(opType, castOutput, castIsZeroSumToFloat, ctx.GetNodeName(opType), ""); + var castIsZeroSumToFloatNode = ctx.CreateNode(opType, notOutput, castIsZeroSumToFloat, ctx.GetNodeName(opType), ""); var t1 = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); castIsZeroSumToFloatNode.AddAttribute("to", t1); opType = "Sum"; var sumOutputNonZero = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScoresNonZero", true); - var sumOutputNonZeroNode = ctx.CreateNode(opType, new[] { sumOutput, castIsZeroSumToFloat }, + ctx.CreateNode(opType, new[] { sumOutput, castIsZeroSumToFloat }, new[] { sumOutputNonZero }, ctx.GetNodeName(opType), ""); string[] divOutputs = new string[Predictors.Length]; diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index e22b82133f..b52196a56e 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1319,8 +1319,22 @@ public void MulticlassTrainersOnnxConversionTest() { mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(), mlContext.MulticlassClassification.Trainers.NaiveBayes(), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.AveragedPerceptron()), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.AveragedPerceptron(), useProbabilities:false), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression()), mlContext.MulticlassClassification.Trainers.OneVersusAll( mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), useProbabilities:false), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.LinearSvm()), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.LinearSvm(), useProbabilities:false), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.FastForest()), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.FastForest(), useProbabilities:false), mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(), mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated() }; @@ -1357,7 +1371,8 @@ public void MulticlassTrainersOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedScalarColumns(transformedData.Schema[5].Name, outputNames[2], transformedData, onnxResult); + CompareSelectedScalarColumns(transformedData.Schema[5].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels + CompareSelectedR4VectorColumns(transformedData.Schema[6].Name, outputNames[3], transformedData, onnxResult, 4); //compare scores } } Done();