diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 3efc0e63e0..6183bb06c7 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -170,6 +170,16 @@ public OnnxNode CreateNode(string opType, string input, string output, string na /// The initializer's ONNX name public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); + /// + /// Call this function can declare a global double tensor + /// + /// The doubles which are going to be added into the ONNX graph + /// The shape that the doubles + /// A string used as a seed to generate this initializer's name in the ONNX graph. + /// Whether a unique name should be picked for this initializer. + /// The initializer's ONNX name + public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true); + /// /// Call this function can declare a global string tensor /// diff --git a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs index 508d419764..8105a81126 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs @@ -313,6 +313,17 @@ public override string AddInitializer(IEnumerable values, IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true) + { + _host.CheckValue(values, nameof(values)); + if (dims != null) + _host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size"); + + name = AddVariable(name ?? "double", makeUniqueName); + _initializers.Add(OnnxUtils.MakeDouble(name, values, dims)); + return name; + } + public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true) { _host.CheckValue(values, nameof(values)); diff --git a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs index 6735538e85..4978a8bd45 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs @@ -405,6 +405,20 @@ public static TensorProto MakeInt64s(string name, IEnumerable values, IEnu return tensor; } + // Make double vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor. + public static TensorProto MakeDouble(string name, IEnumerable values, IEnumerable dims = null) + { + var tensor = new TensorProto(); + tensor.Name = name; + tensor.DataType = (int)TensorProto.Types.DataType.Double; + tensor.DoubleData.AddRange(values); + if (dims != null) + tensor.Dims.AddRange(dims); + else + tensor.Dims.Add(values.Count()); + return tensor; + } + // Make float scalar in ONNX from native C# number public static TensorProto MakeFloat(string name, float value) { diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs index ec9056cef0..20c2d8be37 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs @@ -10,6 +10,7 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Trainers; @@ -222,7 +223,8 @@ internal static CommonOutputs.MulticlassClassificationOutput TrainMulticlassNaiv /// public sealed class NaiveBayesMulticlassModelParameters : ModelParametersBase>, - IValueMapper + IValueMapper, + ISingleCanSaveOnnx { internal const string LoaderSignature = "MultiClassNaiveBayesPred"; private static VersionInfo GetVersionInfo() @@ -252,6 +254,8 @@ private static VersionInfo GetVersionInfo() DataViewType IValueMapper.OutputType => _outputType; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; + /// /// Get the label histogram. /// @@ -383,6 +387,155 @@ ValueMapper IValueMapper.GetMapper() return (ValueMapper)(Delegate)del; } + /// + /// Creates an Onnx inferencing model by vectorizing and following the logic found in + /// + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) + { + float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length]; + float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length]; + + for (int i = 0; i < _featureHistogram.Length; i++) + { + Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length); + } + for (int i = 0; i < _featureHistogram[0].Length; i++) + { + Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length); + } + + var one = ctx.AddInitializer(1.0f, "one"); + var zero = ctx.AddInitializer(0.0f, "zero"); + var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount"); + var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount"); + var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram"); + + var featureHistogramName = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram"); + var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded"); + var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb"); + + var greaterOutput = ctx.AddIntermediateVariable(null, "greaterOutput", true); + var opType = "Greater"; + ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), ""); + + opType = "Cast"; + var isFeaturePresent = ctx.AddIntermediateVariable(null, "isFeaturePresent", true); + var node = ctx.CreateNode(opType, greaterOutput, isFeaturePresent, ctx.GetNodeName(opType), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); + node.AddAttribute("to", t); + + //initialize logProb + opType = "Div"; + var divOutput = ctx.AddIntermediateVariable(null, "DivOutput", true); + ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), ""); + + opType = "Log"; + var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true); + ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), ""); + + //log1 + opType = "Sum"; + var sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true); + ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), ""); + + var logOutput1 = ctx.AddIntermediateVariable(null, "LogOutput", true); + LogMul(ctx, sumOutput, isFeaturePresent, logOutput1); + + //log2 + opType = "Transpose"; + var labelHistogramTrans = ctx.AddIntermediateVariable(null, "transpose", true); + ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), ""); + + opType = "Sub"; + var absentFeatureCount = ctx.AddIntermediateVariable(null, "AbsentFeatureCounts", true); + ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), ""); + + opType = "Sum"; + sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true); + ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), ""); + + var logOutput2 = ctx.AddIntermediateVariable(null, "LogOutput", true); + LogMul(ctx, sumOutput, isFeaturePresent, logOutput2); + + //log3 + opType = "Sum"; + sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true); + ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), ""); + + var logOutput3 = ctx.AddIntermediateVariable(null, "LogOutput", true); + LogMul(ctx, sumOutput, isFeaturePresent, logOutput3); + + //result + opType = "Sub"; + var logProb = ctx.AddIntermediateVariable(null, "LogProb", true); + ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), ""); + + opType = "Sub"; + var absentFeatureLogProb = ctx.AddIntermediateVariable(null, "AbsentFeatureLogProb", true); + ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), ""); + + opType = "ReduceSum"; + var logProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true); + node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), ""); + long[] list = { 1 }; + node.AddAttribute("axes", list); + + opType = "ReduceSum"; + var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true); + node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), ""); + node.AddAttribute("axes", list); + + opType = "Cast"; + var castOutput = ctx.AddIntermediateVariable(null, "CastOutput2", true); + node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), ""); + t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); + node.AddAttribute("to", t); + + opType = "Sub"; + var subOutput = ctx.AddIntermediateVariable(null, "SubOutput", true); + ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), ""); + + opType = "Sum"; + sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true); + ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), ""); + + opType = "Transpose"; + var transposeOutput = ctx.AddIntermediateVariable(null, "TransposeOutput", true); + ctx.CreateNode(opType, new[] { sumOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), ""); + + opType = "ArgMax"; + var scoreIndex = ctx.AddIntermediateVariable(null, "ScoreIndex", true); + ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), ""); + + opType = "Cast"; + castOutput = ctx.AddIntermediateVariable(null, "CastOutput3", true); + node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), ""); + t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); + node.AddAttribute("to", t); + + //log3 + opType = "Sum"; + sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true); + ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), ""); + + opType = "Cast"; + node = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), ""); + t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType(); + node.AddAttribute("to", t); + + return true; + } + + private void LogMul(OnnxContext ctx, string input, string isFeaturePresent, string output) + { + var opType = "Log"; + var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true); + ctx.CreateNode(opType, input, logOutput, ctx.GetNodeName(opType), ""); + + opType = "Mul"; + ctx.CreateNode(opType, new[] { logOutput, isFeaturePresent }, new[] { output }, ctx.GetNodeName(opType), ""); + } + private void ComputeLabelProbabilityFromFeature(double labelOccurrenceCount, int labelIndex, int featureIndex, float featureValue, ref double logProb, ref double absentFeatureLogProb) { diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 7e8454229c..0ab47d8a27 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1177,6 +1177,7 @@ public void MulticlassTrainersOnnxConversionTest() List> estimators = new List>() { mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(), + mlContext.MulticlassClassification.Trainers.NaiveBayes(), mlContext.MulticlassClassification.Trainers.OneVersusAll( mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), useProbabilities:false), mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(),