Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ public OnnxNode CreateNode(string opType, string input, string output, string na
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function to declare a global double tensor (an initializer) in the ONNX graph.
/// </summary>
/// <param name="values">The doubles which are going to be added into the ONNX graph</param>
/// <param name="dims">The shape of the tensor the doubles are stored in. If null, a 1-D tensor is assumed.</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<double> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function can declare a global string tensor
/// </summary>
Expand Down
11 changes: 11 additions & 0 deletions src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,17 @@ public override string AddInitializer(IEnumerable<long> values, IEnumerable<long
return name;
}

/// <summary>
/// Declares a global double tensor initializer in the ONNX graph and returns its generated name.
/// </summary>
public override string AddInitializer(IEnumerable<double> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
{
    _host.CheckValue(values, nameof(values));

    // When an explicit shape is supplied, its element count must agree with the value count.
    if (dims != null)
    {
        var declaredSize = dims.Aggregate((x, y) => x * y);
        _host.Check(declaredSize == values.Count(), "Number of elements doesn't match tensor size");
    }

    var initializerName = AddVariable(name ?? "double", makeUniqueName);
    _initializers.Add(OnnxUtils.MakeDouble(initializerName, values, dims));
    return initializerName;
}

public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
{
_host.CheckValue(values, nameof(values));
Expand Down
14 changes: 14 additions & 0 deletions src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,20 @@ public static TensorProto MakeInt64s(string name, IEnumerable<long> values, IEnu
return tensor;
}

// Make double vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
// Make double vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
public static TensorProto MakeDouble(string name, IEnumerable<double> values, IEnumerable<long> dims = null)
{
    var tensor = new TensorProto();
    tensor.Name = name;
    tensor.DataType = (int)TensorProto.Types.DataType.Double;
    tensor.DoubleData.AddRange(values);
    if (dims != null)
        tensor.Dims.AddRange(dims);
    else
        // Use the number of elements already copied into the tensor rather than values.Count(),
        // which would enumerate a deferred IEnumerable a second time (and could disagree with
        // what was actually stored if the source is not repeatable).
        tensor.Dims.Add(tensor.DoubleData.Count);
    return tensor;
}

// Make float scalar in ONNX from native C# number
public static TensorProto MakeFloat(string name, float value)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;

Expand Down Expand Up @@ -222,7 +223,8 @@ internal static CommonOutputs.MulticlassClassificationOutput TrainMulticlassNaiv
/// </summary>
public sealed class NaiveBayesMulticlassModelParameters :
ModelParametersBase<VBuffer<float>>,
IValueMapper
IValueMapper,
ISingleCanSaveOnnx
{
internal const string LoaderSignature = "MultiClassNaiveBayesPred";
private static VersionInfo GetVersionInfo()
Expand Down Expand Up @@ -252,6 +254,8 @@ private static VersionInfo GetVersionInfo()

DataViewType IValueMapper.OutputType => _outputType;

// This predictor always supports ONNX export; the graph itself is emitted by SaveAsOnnx.
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;

/// <summary>
/// Get the label histogram.
/// </summary>
Expand Down Expand Up @@ -383,6 +387,155 @@ ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
return (ValueMapper<TIn, TOut>)(Delegate)del;
}

/// <summary>
/// Creates an Onnx inferencing model by vectorizing and following the logic found in <see cref="Map"/>
/// </summary>
/// <param name="ctx">The ONNX graph-building context that all nodes and initializers are added to.</param>
/// <param name="outputNames">Output variable names; outputNames[0] receives the predicted label, outputNames[1] the per-class scores.</param>
/// <param name="featureColumn">Name of the ONNX variable holding the input feature vector.</param>
/// <returns>Always true, indicating the model was written into the context.</returns>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
// Flatten the per-label feature histograms into a single array, and tile the label
// histogram once per feature, so both can be exported as 2-D initializers.
float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length];
float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];

for (int i = 0; i < _featureHistogram.Length; i++)
{
Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
}
// NOTE(review): each iteration copies _featureHistogram.Length elements out of _labelHistogram,
// which is only in-bounds if _featureHistogram.Length == _labelHistogram.Length (one feature
// histogram per label) — confirm that invariant holds.
for (int i = 0; i < _featureHistogram[0].Length; i++)
{
Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
}

// Scalar and tensor constants referenced throughout the graph.
var one = ctx.AddInitializer(1.0f, "one");
var zero = ctx.AddInitializer(0.0f, "zero");
var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount");
var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
// The first _labelHistogram.Length entries of the tiled array hold one full copy of the label histogram.
var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");

var featureHistogramName = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");

// isFeaturePresent = Cast(feature > 0): 1.0 where a feature fires, 0.0 elsewhere.
var greaterOutput = ctx.AddIntermediateVariable(null, "greaterOutput", true);
var opType = "Greater";
ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");

opType = "Cast";
var isFeaturePresent = ctx.AddIntermediateVariable(null, "isFeaturePresent", true);
var node = ctx.CreateNode(opType, greaterOutput, isFeaturePresent, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
node.AddAttribute("to", t);

// Class priors: logOutput = Log(labelHistogram / totalTrainingCount).
opType = "Div";
var divOutput = ctx.AddIntermediateVariable(null, "DivOutput", true);
ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");

opType = "Log";
var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true);
ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");

// log1: Log(featureHistogram + 1), masked by feature presence.
opType = "Sum";
var sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

var logOutput1 = ctx.AddIntermediateVariable(null, "LogOutput", true);
LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);

// log2: Log(labelHistogram^T + labelCount), masked by feature presence.
opType = "Transpose";
var labelHistogramTrans = ctx.AddIntermediateVariable(null, "transpose", true);
ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");

opType = "Sub";
var absentFeatureCount = ctx.AddIntermediateVariable(null, "AbsentFeatureCounts", true);
ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");

opType = "Sum";
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

var logOutput2 = ctx.AddIntermediateVariable(null, "LogOutput", true);
LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);

// log3: Log(absentFeatureCount + 1), masked by feature presence.
opType = "Sum";
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

var logOutput3 = ctx.AddIntermediateVariable(null, "LogOutput", true);
LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);

// Combine the per-feature terms: logProb = log1 - log2, absentFeatureLogProb = log3 - log2.
opType = "Sub";
var logProb = ctx.AddIntermediateVariable(null, "LogProb", true);
ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");

opType = "Sub";
var absentFeatureLogProb = ctx.AddIntermediateVariable(null, "AbsentFeatureLogProb", true);
ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");

// Sum both terms over the feature axis (axis 1).
opType = "ReduceSum";
var logProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
long[] list = { 1 };
node.AddAttribute("axes", list);

opType = "ReduceSum";
var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
node.AddAttribute("axes", list);

// Score = priors + sum(logProb) + (learned absent-feature log-prob - sum(absentFeatureLogProb)).
opType = "Cast";
var castOutput = ctx.AddIntermediateVariable(null, "CastOutput2", true);
node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
node.AddAttribute("to", t);

opType = "Sub";
var subOutput = ctx.AddIntermediateVariable(null, "SubOutput", true);
ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");

opType = "Sum";
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

// Emit the score vector through outputNames[1].
// NOTE(review): transposeOutput is declared but never used — the Transpose writes directly
// to outputNames[1]; the dead local looks like leftover scaffolding. Confirm and remove.
opType = "Transpose";
var transposeOutput = ctx.AddIntermediateVariable(null, "TransposeOutput", true);
ctx.CreateNode(opType, new[] { sumOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");

// Predicted label: index of the maximum score (ArgMax with default attributes).
opType = "ArgMax";
var scoreIndex = ctx.AddIntermediateVariable(null, "ScoreIndex", true);
ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");

opType = "Cast";
castOutput = ctx.AddIntermediateVariable(null, "CastOutput3", true);
node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
node.AddAttribute("to", t);

// Add one to the zero-based argmax — presumably to produce the one-based key value
// expected for the predicted-label column; TODO confirm against the non-ONNX scorer.
opType = "Sum";
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

// Final cast of the predicted label to UInt32 into outputNames[0].
opType = "Cast";
node = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
node.AddAttribute("to", t);

return true;
}

// Emits output = Log(input) * isFeaturePresent, i.e. the log of the input masked by feature presence.
private void LogMul(OnnxContext ctx, string input, string isFeaturePresent, string output)
{
    var logResult = ctx.AddIntermediateVariable(null, "LogOutput", true);
    ctx.CreateNode("Log", input, logResult, ctx.GetNodeName("Log"), "");
    ctx.CreateNode("Mul", new[] { logResult, isFeaturePresent }, new[] { output }, ctx.GetNodeName("Mul"), "");
}

private void ComputeLabelProbabilityFromFeature(double labelOccurrenceCount, int labelIndex, int featureIndex,
float featureValue, ref double logProb, ref double absentFeatureLogProb)
{
Expand Down
1 change: 1 addition & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,7 @@ public void MulticlassTrainersOnnxConversionTest()
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>()
{
mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(),
mlContext.MulticlassClassification.Trainers.NaiveBayes(),
mlContext.MulticlassClassification.Trainers.OneVersusAll(
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), useProbabilities:false),
mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(),
Expand Down