diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
index 3efc0e63e0..6183bb06c7 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
+++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
@@ -170,6 +170,16 @@ public OnnxNode CreateNode(string opType, string input, string output, string na
/// The initializer's ONNX name
public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true);
+ ///
+ /// Declares a global double tensor (an initializer) in the ONNX graph
+ ///
+ /// The doubles which are going to be added into the ONNX graph
+ /// The shape of the tensor that the doubles are arranged into
+ /// A string used as a seed to generate this initializer's name in the ONNX graph.
+ /// Whether a unique name should be picked for this initializer.
+ /// The initializer's ONNX name
+ public abstract string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true);
+
///
/// Call this function can declare a global string tensor
///
diff --git a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
index 508d419764..8105a81126 100644
--- a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
+++ b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
@@ -313,6 +313,17 @@ public override string AddInitializer(IEnumerable values, IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true)
+ { // Adds `values` as a double tensor initializer and returns the name AddVariable assigned to it.
+     _host.CheckValue(values, nameof(values));
+     // Seeding the product with 1L lets an empty dims (scalar shape) validate as one element instead of throwing InvalidOperationException from the seedless Aggregate.
+     if (dims != null)
+         _host.Check(dims.Aggregate(1L, (x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
+     name = AddVariable(name ?? "double", makeUniqueName);
+     _initializers.Add(OnnxUtils.MakeDouble(name, values, dims));
+     return name;
+ }
+
public override string AddInitializer(IEnumerable values, IEnumerable dims, string name = null, bool makeUniqueName = true)
{
_host.CheckValue(values, nameof(values));
diff --git a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
index 6735538e85..4978a8bd45 100644
--- a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
+++ b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
@@ -405,6 +405,20 @@ public static TensorProto MakeInt64s(string name, IEnumerable values, IEnu
return tensor;
}
+ // Builds a double TensorProto called `name` holding `values`. A null `dims` yields a 1-D tensor
+ // whose only dimension is values.Count(); otherwise `dims` is used as the tensor shape as given.
+ public static TensorProto MakeDouble(string name, IEnumerable values, IEnumerable dims = null)
+ {
+     var result = new TensorProto { Name = name, DataType = (int)TensorProto.Types.DataType.Double };
+     result.DoubleData.AddRange(values);
+
+     if (dims == null)
+         result.Dims.Add(values.Count());
+     else
+         result.Dims.AddRange(dims);
+     return result;
+ }
+
// Make float scalar in ONNX from native C# number
public static TensorProto MakeFloat(string name, float value)
{
diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
index ec9056cef0..20c2d8be37 100644
--- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
@@ -10,6 +10,7 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
+using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
@@ -222,7 +223,8 @@ internal static CommonOutputs.MulticlassClassificationOutput TrainMulticlassNaiv
///
public sealed class NaiveBayesMulticlassModelParameters :
ModelParametersBase>,
- IValueMapper
+ IValueMapper,
+ ISingleCanSaveOnnx
{
internal const string LoaderSignature = "MultiClassNaiveBayesPred";
private static VersionInfo GetVersionInfo()
@@ -252,6 +254,8 @@ private static VersionInfo GetVersionInfo()
DataViewType IValueMapper.OutputType => _outputType;
+ bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; // Unconditionally true: ctx is not inspected.
+
///
/// Get the label histogram.
///
@@ -383,6 +387,155 @@ ValueMapper IValueMapper.GetMapper()
return (ValueMapper)(Delegate)del;
}
+ ///
+ /// Creates an Onnx inferencing model by vectorizing and following this predictor's scoring logic (the original cross-reference target is missing here - TODO restore). Writes the cast ArgMax result to outputNames[0] and the transposed per-label sums to outputNames[1]; always returns true.
+ ///
+ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
+ {
+ float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length];
+ float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];
+ // Flatten the jagged _featureHistogram row by row, then tile _labelHistogram once per entry of _featureHistogram[0]. NOTE(review): the tiling copies _featureHistogram.Length items, so it assumes that equals _labelHistogram.Length - confirm.
+ for (int i = 0; i < _featureHistogram.Length; i++)
+ {
+ Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
+ }
+ for (int i = 0; i < _featureHistogram[0].Length; i++)
+ {
+ Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
+ }
+ // Scalar constants and histogram tensors registered as graph initializers; labelHistogram is the first _labelHistogram.Length entries of the tiled array.
+ var one = ctx.AddInitializer(1.0f, "one");
+ var zero = ctx.AddInitializer(0.0f, "zero");
+ var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount");
+ var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
+ var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");
+
+ var featureHistogramName = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
+ var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
+ var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");
+ // isFeaturePresent = Cast(Greater(featureColumn, zero)) to the type derived from DataKind.Single; LogMul multiplies it into each log term below.
+ var greaterOutput = ctx.AddIntermediateVariable(null, "greaterOutput", true);
+ var opType = "Greater";
+ ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "Cast";
+ var isFeaturePresent = ctx.AddIntermediateVariable(null, "isFeaturePresent", true);
+ var node = ctx.CreateNode(opType, greaterOutput, isFeaturePresent, ctx.GetNodeName(opType), "");
+ var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
+ node.AddAttribute("to", t);
+
+ // Initial log-probability term: logOutput = Log(Div(labelHistogram, totalTrainingCount)).
+ opType = "Div";
+ var divOutput = ctx.AddIntermediateVariable(null, "DivOutput", true);
+ ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "Log";
+ var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true);
+ ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");
+
+ // log1: logOutput1 = Log(featureHistogram + one) * isFeaturePresent.
+ opType = "Sum";
+ var sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
+ ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
+
+ var logOutput1 = ctx.AddIntermediateVariable(null, "LogOutput", true);
+ LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);
+
+ // log2: logOutput2 = Log(Transpose(labelHistogramExpanded) + labelCount) * isFeaturePresent.
+ opType = "Transpose";
+ var labelHistogramTrans = ctx.AddIntermediateVariable(null, "transpose", true);
+ ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");
+ // absentFeatureCount = Transpose(labelHistogramExpanded) - featureHistogram; it feeds the log3 term further down.
+ opType = "Sub";
+ var absentFeatureCount = ctx.AddIntermediateVariable(null, "AbsentFeatureCounts", true);
+ ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");
+
+ opType = "Sum";
+ sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
+ ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
+
+ var logOutput2 = ctx.AddIntermediateVariable(null, "LogOutput", true);
+ LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);
+
+ // log3: logOutput3 = Log(absentFeatureCount + one) * isFeaturePresent.
+ opType = "Sum";
+ sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
+ ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
+
+ var logOutput3 = ctx.AddIntermediateVariable(null, "LogOutput", true);
+ LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);
+
+ // result: logProb = log1 - log2 and absentFeatureLogProb = log3 - log2; each is then reduced with ReduceSum over axes = { 1 }.
+ opType = "Sub";
+ var logProb = ctx.AddIntermediateVariable(null, "LogProb", true);
+ ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");
+
+ opType = "Sub";
+ var absentFeatureLogProb = ctx.AddIntermediateVariable(null, "AbsentFeatureLogProb", true);
+ ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");
+
+ opType = "ReduceSum";
+ var logProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
+ node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
+ long[] list = { 1 };
+ node.AddAttribute("axes", list);
+
+ opType = "ReduceSum";
+ var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
+ node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
+ node.AddAttribute("axes", list);
+ // Cast the _absentFeaturesLogProb initializer to the DataKind.Single-derived type (presumably it is stored as doubles - confirm) and subtract the reduced absent-feature term from it.
+ opType = "Cast";
+ var castOutput = ctx.AddIntermediateVariable(null, "CastOutput2", true);
+ node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
+ t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
+ node.AddAttribute("to", t);
+
+ opType = "Sub";
+ var subOutput = ctx.AddIntermediateVariable(null, "SubOutput", true);
+ ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");
+ // Final per-label value: Sum(subOutput, logProbReduceSum, logOutput).
+ opType = "Sum";
+ sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
+ ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
+ // The Transpose node writes straight to outputNames[1]. NOTE(review): transposeOutput is declared but never referenced afterwards.
+ opType = "Transpose";
+ var transposeOutput = ctx.AddIntermediateVariable(null, "TransposeOutput", true);
+ ctx.CreateNode(opType, new[] { sumOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");
+ // ArgMax over the same sums; no attribute is set on this node, so the operator's defaults apply.
+ opType = "ArgMax";
+ var scoreIndex = ctx.AddIntermediateVariable(null, "ScoreIndex", true);
+ ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
+
+ opType = "Cast";
+ castOutput = ctx.AddIntermediateVariable(null, "CastOutput3", true);
+ node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
+ t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
+ node.AddAttribute("to", t);
+
+ // Add one to the ArgMax index, cast to the DataKind.UInt32-derived type and write it to outputNames[0]. NOTE(review): presumably shifts a zero-based index to a one-based key value - confirm.
+ opType = "Sum";
+ sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
+ ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "Cast";
+ node = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
+ t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
+ node.AddAttribute("to", t);
+
+ return true;
+ }
+
+ // Appends a Log node over `input` and a Mul node of that result with `isFeaturePresent`, writing to `output`.
+ private void LogMul(OnnxContext ctx, string input, string isFeaturePresent, string output)
+ {
+     const string logOp = "Log", mulOp = "Mul";
+     var logResult = ctx.AddIntermediateVariable(null, "LogOutput", true);
+     ctx.CreateNode(logOp, input, logResult, ctx.GetNodeName(logOp), "");
+
+     ctx.CreateNode(mulOp, new[] { logResult, isFeaturePresent }, new[] { output }, ctx.GetNodeName(mulOp), "");
+ }
+
private void ComputeLabelProbabilityFromFeature(double labelOccurrenceCount, int labelIndex, int featureIndex,
float featureValue, ref double logProb, ref double absentFeatureLogProb)
{
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index 7e8454229c..0ab47d8a27 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -1177,6 +1177,7 @@ public void MulticlassTrainersOnnxConversionTest()
List> estimators = new List>()
{
mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(),
+ mlContext.MulticlassClassification.Trainers.NaiveBayes(),
mlContext.MulticlassClassification.Trainers.OneVersusAll(
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), useProbabilities:false),
mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(),