diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 3602ba9adc..cf3669b630 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -978,13 +978,35 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureCol { Host.CheckValue(ctx, nameof(ctx)); + string predictedLabelInt64 = null; + string predictedLabelUint32 = null; + // REVIEW: What is the right way to get the name of the predicted column? For now the entry equal to DefaultColumnNames.PredictedLabel is located and that slot of the outputs array is redirected to an Int64 intermediate variable. + for (int i = 0; i < outputs.Length; i++) + { + if (outputs[i] != DefaultColumnNames.PredictedLabel) + continue; + predictedLabelUint32 = DefaultColumnNames.PredictedLabel; + predictedLabelInt64 = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "PredictedLabelInt64", true); + outputs[i] = predictedLabelInt64; + break; + } + + Host.CheckValue(predictedLabelInt64, nameof(predictedLabelInt64)); + string opType = "LinearClassifier"; var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType)); node.AddAttribute("post_transform", GetOnnxPostTransform()); node.AddAttribute("multi_class", true); node.AddAttribute("coefficients", Weights.SelectMany(w => w.DenseValues())); node.AddAttribute("intercepts", Biases); - node.AddAttribute("classlabels_ints", Enumerable.Range(0, NumberOfClasses).Select(x => (long)x)); + node.AddAttribute("classlabels_ints", Enumerable.Range(1, NumberOfClasses).Select(x => (long)x)); + + // The LinearClassifier label now lands in the Int64 intermediate variable, but ML.NET outputs UInt32. 
So cast the Onnx output here + opType = "Cast"; + var castNode = ctx.CreateNode(opType, predictedLabelInt64, predictedLabelUint32, ctx.GetNodeName(opType), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType(); + castNode.AddAttribute("to", t); + return true; } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 7ab0482acf..1874c5ab92 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -16,6 +16,7 @@ using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Model.Pfa; using Microsoft.ML.Runtime; using Microsoft.ML.Trainers; @@ -244,7 +245,8 @@ public sealed class OneVersusAllModelParameters : IValueMapper, ICanSaveInSourceCode, ICanSaveInTextFormat, - ISingleCanSavePfa + ISingleCanSavePfa, + ISingleCanSaveOnnx { internal const string LoaderSignature = "OVAExec"; internal const string RegistrationName = "OVAPredictor"; @@ -490,7 +492,11 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) } } - private abstract class ImplBase : ISingleCanSavePfa + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _impl.CanSaveOnnx(ctx); + + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) => _impl.SaveAsOnnx(ctx, outputNames, featureColumn); + + private abstract class ImplBase : ISingleCanSavePfa, ISingleCanSaveOnnx { public OutputFormula OutputFormula; public abstract DataViewType InputType { get; } @@ -499,6 +505,10 @@ private abstract class ImplBase : ISingleCanSavePfa public abstract ValueMapper, VBuffer> GetMapper(); public abstract JToken 
SaveAsPfa(BoundPfaContext ctx, JToken input); + public bool CanSaveOnnx(OnnxContext ctx) => Predictors.All(pred => (pred as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true); + + public abstract bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn); + protected bool IsValid(IValueMapper mapper, ref VectorDataViewType inputType) { Contracts.AssertValueOrNull(mapper); @@ -521,6 +531,71 @@ protected bool IsValid(IValueMapper mapper, ref VectorDataViewType inputType) } return true; } + + // Exports every predictor in Predictors to ONNX and returns, per predictor, the name of the variable to combine: + // the predictor's third output (declared from DefaultColumnNames.Probability), passed through a Clip node with min 0 when clipToZero is set. + public string[] SaveAsOnnxPreProcess(OnnxContext ctx, string featureColumn, bool clipToZero) + { + string[] outputs = new string[Predictors.Length]; + + string[] localOutputNames = { DefaultColumnNames.PredictedLabel, DefaultColumnNames.Score, DefaultColumnNames.Probability }; + + for (int i = 0; i < Predictors.Length; i++) + { + var predictorOutputNames = new string[localOutputNames.Length]; + + predictorOutputNames[0] = ctx.AddIntermediateVariable(NumberDataViewType.UInt32, $"{DefaultColumnNames.PredictedLabel}_{i}", true); + predictorOutputNames[1] = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"{DefaultColumnNames.Score}_{i}", true); + predictorOutputNames[2] = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"{DefaultColumnNames.Probability}_{i}", true); + + string clipInput = predictorOutputNames[2]; + + var pred = Predictors[i] as ISingleCanSaveOnnx; + Contracts.AssertValue(pred); + pred.SaveAsOnnx(ctx, predictorOutputNames, featureColumn); + + if (clipToZero) + { + var clipOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"ClipOutput_{i}", true); + outputs[i] = clipOutput; + + string opType = "Clip"; + var clipNode = ctx.CreateNode(opType, clipInput, outputs[i], ctx.GetNodeName(opType), ""); + clipNode.AddAttribute("min", 0.0); + } + else + outputs[i] = predictorOutputNames[2]; + } + return outputs; + } + + // Appends the nodes shared by all OVA variants: ArgMax over inputName, Add of one and Cast to UInt32 into outputNames[0], + // followed by a Max node from inputName into outputNames[1]. Requires at least two entries in outputNames. + public void SaveAsOnnxPostProcess(OnnxContext ctx, string inputName, string[] outputNames) + { + 
Contracts.Assert(outputNames.Length >= 2); + + string opType; + opType = "ArgMax"; + // ONNX ArgMax emits int64 indices, so the intermediate must be declared Int64 (it also feeds the Int64 Add below). + var argMaxOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "ArgMaxOutput", true); + var argMaxNode = ctx.CreateNode(opType, inputName, argMaxOutput, ctx.GetNodeName(opType), ""); + argMaxNode.AddAttribute("keepdims", 0); + + // Shift the zero-based ArgMax index by one before it is cast to UInt32. + opType = "Add"; + var one = ctx.AddInitializer(1); + var addOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "AddOutput", true); + var addNode = ctx.CreateNode(opType, new[] { argMaxOutput, one }, new[] { addOutput }, ctx.GetNodeName(opType), ""); + + opType = "Cast"; + var castToUint32Node = ctx.CreateNode(opType, addOutput, outputNames[0], ctx.GetNodeName(opType), ""); + var t2 = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType(); + castToUint32Node.AddAttribute("to", t2); + + opType = "Max"; + ctx.CreateNode(opType, inputName, outputNames[1], ctx.GetNodeName(opType), ""); + } } private sealed class ImplRaw : ImplBase @@ -586,6 +661,21 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) JObject jobj = null; return jobj.AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Double)).AddReturn("new", rootObjects); } + + // Raw variant: concatenates the clipped per-predictor outputs and appends the shared ArgMax/Max tail. + public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) + { + var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, true); + + string opType = "Concat"; + var concatOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true); + var concatNode = ctx.CreateNode(opType, probabilityOutputs, new[] { concatOutput }, ctx.GetNodeName(opType), ""); + concatNode.AddAttribute("axis", 0); + + base.SaveAsOnnxPostProcess(ctx, concatOutput, outputNames); + + return true; + } } private sealed class ImplDist : ImplBase @@ -699,6 +789,59 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) var factorVar = ctx.DeclareVar(null, PfaUtils.Call("/", 1.0, PfaUtils.Call("a.sum", resultVar))); return 
PfaUtils.Call("la.scale", resultVar, factorVar); } + + // Probability variant: divides each clipped per-predictor output by their sum, then appends the shared ArgMax/Max tail. + public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) + { + Contracts.Assert(outputNames.Length >= 2); + + string opType; + var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, true); + + opType = "Sum"; + var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScores", true); + var sumNode = ctx.CreateNode(opType, probabilityOutputs, new[] { sumOutput }, ctx.GetNodeName(opType), ""); + + // Cast(float to bool) is true when the sum is NON-zero, so it has to be negated to obtain the real "sum is zero" flag; + // adding that flag (as a 0/1 float) to the sum turns a zero denominator into 1 and leaves non-zero sums untouched. + opType = "Cast"; + var castOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumNonZero", true); + var castNode = ctx.CreateNode(opType, sumOutput, castOutput, ctx.GetNodeName(opType), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType(); + castNode.AddAttribute("to", t); + + opType = "Not"; + var isSumZero = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumZero", true); + ctx.CreateNode(opType, castOutput, isSumZero, ctx.GetNodeName(opType), ""); + + opType = "Cast"; + var castIsZeroSumToFloat = ctx.AddIntermediateVariable(NumberDataViewType.Single, "IsSumZeroAsFloat", true); + var castIsZeroSumToFloatNode = ctx.CreateNode(opType, isSumZero, castIsZeroSumToFloat, ctx.GetNodeName(opType), ""); + var t1 = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); + castIsZeroSumToFloatNode.AddAttribute("to", t1); + + opType = "Sum"; + var sumOutputNonZero = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScoresNonZero", true); + var sumOutputNonZeroNode = ctx.CreateNode(opType, new[] { sumOutput, castIsZeroSumToFloat }, + new[] { sumOutputNonZero }, ctx.GetNodeName(opType), ""); + + string[] divOutputs = new string[Predictors.Length]; + for (int i = 0; i < Predictors.Length; i++) + { + opType = "Div"; + divOutputs[i] = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"DivOutput_{i}", true); + ctx.CreateNode(opType, new[] { probabilityOutputs[i], sumOutputNonZero }, new[] { divOutputs[i] }, ctx.GetNodeName(opType), ""); + } + + opType = "Concat"; + var concatOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true); + var 
concatNode = ctx.CreateNode(opType, divOutputs, new[] { concatOutput }, ctx.GetNodeName(opType), ""); + concatNode.AddAttribute("axis", 0); + + base.SaveAsOnnxPostProcess(ctx, concatOutput, outputNames); + + return true; + } } private sealed class ImplSoftmax : ImplBase @@ -768,6 +911,38 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) { throw new NotImplementedException("Softmax's PFA exporter is not implemented yet."); } + + // Softmax variant: exponentiates the concatenated per-predictor outputs and divides by their ReduceSum before the shared ArgMax/Max tail. + // NOTE(review): SaveAsOnnxPreProcess returns each predictor's Probability variable even when clipToZero is false; confirm raw scores are not intended here. + public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) + { + Contracts.Assert(outputNames.Length >= 2); + + var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false); + + string opType; + opType = "Concat"; + var concatOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true); + var concatNode = ctx.CreateNode(opType, probabilityOutputs, new[] { concatOutput }, ctx.GetNodeName(opType), ""); + concatNode.AddAttribute("axis", 0); + + opType = "Exp"; + var expOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ExpOutput", true); + var expNode = ctx.CreateNode(opType, concatOutput, expOutput, ctx.GetNodeName(opType), ""); + + opType = "ReduceSum"; + var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOutput", true); + var sumNode = ctx.CreateNode(opType, expOutput, sumOutput, ctx.GetNodeName(opType), ""); + sumNode.AddAttribute("keepdims", 0); + + opType = "Div"; + var divOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "DivOutput", true); + var divNode = ctx.CreateNode(opType, new[] { expOutput, sumOutput }, new[] { divOutput }, ctx.GetNodeName(opType), ""); + + base.SaveAsOnnxPostProcess(ctx, divOutput, outputNames); + + return true; + } } } } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt 
b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt index 4686df5dba..12fabdd840 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt @@ -118,7 +118,7 @@ "Features0" ], "output": [ - "PredictedLabel", + "PredictedLabelInt64", "Score" ], "name": "LinearClassifier", @@ -239,7 +239,6 @@ { "name": "classlabels_ints", "ints": [ - "0", "1", "2", "3", @@ -248,13 +247,31 @@ "6", "7", "8", - "9" + "9", + "10" ], "type": "INTS" } ], "domain": "ai.onnx.ml" }, + { + "input": [ + "PredictedLabelInt64" + ], + "output": [ + "PredictedLabel" + ], + "name": "Cast0", + "opType": "Cast", + "attribute": [ + { + "name": "to", + "i": "12", + "type": "INT" + } + ] + }, { "input": [ "Label0" diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 17a04d7fc7..b70e59d34d 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -19,6 +19,7 @@ using Microsoft.ML.TestFrameworkCommon.Attributes; using Microsoft.ML.Tools; using Microsoft.ML.Trainers; +using Microsoft.ML.Trainers.LightGbm; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.Onnx; using Microsoft.ML.Transforms.Text; @@ -1076,6 +1077,61 @@ private void KeyToValueOnnxConversionTest() Done(); } + [Fact] + void MulticlassTrainersOnnxConversionTest() + { + var mlContext = new MLContext(seed: 1); + + string dataPath = GetDataPath("breast-cancer.txt"); + var dataView = mlContext.Data.LoadFromTextFile(dataPath, separatorChar: '\t', hasHeader: true); + + List> estimators = new List>() + { + mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(), + mlContext.MulticlassClassification.Trainers.OneVersusAll( 
+ mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), useProbabilities:false), + mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(), + mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated() + }; + + if (Environment.Is64BitProcess) + { + estimators.Add(mlContext.MulticlassClassification.Trainers.LightGbm()); + estimators.Add(mlContext.MulticlassClassification.Trainers.LightGbm( + new LightGbmMulticlassTrainer.Options { UseSoftmax = true })); + } + + var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features") + .Append(mlContext.Transforms.NormalizeMinMax("Features")) + .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")); + + foreach (var estimator in estimators) + { + var pipeline = initialPipeline.Append(estimator); + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + var onnxFileName = $"{estimator.ToString()}.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + + SaveOnnxModel(onnxModel, onnxModelPath, null); + + // Compare the results produced by ML.NET with those of the ONNX runtime (only when IsOnnxRuntimeSupported() reports true). + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model on the same data that trained the ML.NET pipeline. NOTE(review): the comparison below addresses the ML.NET column by schema position 5 and the ONNX column by output index 2; presumably both are the predicted label, confirm. 
+ string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedScalarColumns(transformedData.Schema[5].Name, outputNames[2], transformedData, onnxResult); + } + } + Done(); + } + private void CreateDummyExamplesToMakeComplierHappy() { var dummyExample = new BreastCancerFeatureVector() { Features = null }; @@ -1141,34 +1197,6 @@ private void CompareSelectedVectorColumns(string leftColumnName, string right } } - private void CompareSelectedScalarColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right) - { - var leftColumn = left.Schema[leftColumnName]; - var rightColumn = right.Schema[rightColumnName]; - - using (var expectedCursor = left.GetRowCursor(leftColumn)) - using (var actualCursor = right.GetRowCursor(rightColumn)) - { - T expected = default; - VBuffer actual = default; - var expectedGetter = expectedCursor.GetGetter(leftColumn); - var actualGetter = actualCursor.GetGetter>(rightColumn); - while (expectedCursor.MoveNext() && actualCursor.MoveNext()) - { - expectedGetter(ref expected); - actualGetter(ref actual); - var actualVal = actual.GetItemOrDefault(0); - - Assert.Equal(1, actual.Length); - - if (typeof(T) == typeof(ReadOnlyMemory)) - Assert.Equal(expected.ToString(), actualVal.ToString()); - else - Assert.Equal(expected, actualVal); - } - } - } - private void CompareSelectedR8VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6) { var leftColumn = left.Schema[leftColumnName]; @@ -1248,6 +1276,34 @@ private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightC } } + private void 
CompareSelectedScalarColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right) + { + var leftColumn = left.Schema[leftColumnName]; + var rightColumn = right.Schema[rightColumnName]; + + using (var expectedCursor = left.GetRowCursor(leftColumn)) + using (var actualCursor = right.GetRowCursor(rightColumn)) + { + T expected = default; + VBuffer actual = default; + var expectedGetter = expectedCursor.GetGetter(leftColumn); + var actualGetter = actualCursor.GetGetter>(rightColumn); + while (expectedCursor.MoveNext() && actualCursor.MoveNext()) + { + expectedGetter(ref expected); + actualGetter(ref actual); + var actualVal = actual.GetItemOrDefault(0); + + Assert.Equal(1, actual.Length); + + if (typeof(T) == typeof(ReadOnlyMemory)) + Assert.Equal(expected.ToString(), actualVal.ToString()); + else + Assert.Equal(expected, actualVal); + } + } + } + private void SaveOnnxModel(ModelProto model, string binaryFormatPath, string textFormatPath) { DeleteOutputPath(binaryFormatPath); // Clean if such a file exists.