Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3112,7 +3112,7 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames,
}

string opType = "TreeEnsembleRegressor";
string scoreVarName = (Utils.Size(outputNames) == 2) ? outputNames[1] : outputNames[0]; // Get Score from PredictedLabel and/or Score columns
string scoreVarName = (Utils.Size(outputNames) >= 2) ? outputNames[1] : outputNames[0]; // Get Score from PredictedLabel and/or Score columns
var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType));

node.AddAttribute("post_transform", PostTransform.None.GetDescription());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputs, str
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputs) >= 1);
string opType = "LinearRegressor";
string scoreVarName = (Utils.Size(outputs) == 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns
string scoreVarName = (Utils.Size(outputs) >= 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What cases result in the output size being greater than 2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a calibrator is used, the probability column is added. More specifically, the OVA trainer always creates a probability output column, so the output count will be 3. An alternative would be to distinguish between trainers that need the probability column and those that don't inside OVA, but I think that would be more involved.

var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType));
// Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT'}
node.AddAttribute("post_transform", "NONE");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ public string[] SaveAsOnnxPreProcess(OnnxContext ctx, string featureColumn, bool
var clipNode = ctx.CreateNode(opType, new[] { clipInput, zeroVar }, new[] { outputs[i] }, ctx.GetNodeName(opType), "");
}
else
outputs[i] = predictorOutputNames[2];
outputs[i] = predictorOutputNames[1];
}
return outputs;
}
Expand Down Expand Up @@ -659,7 +659,7 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)

public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, true);
var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false);

string opType = "Concat";
var concatOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true);
Expand Down Expand Up @@ -794,22 +794,27 @@ public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string fe

opType = "Sum";
var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScores", true);
var sumNode = ctx.CreateNode(opType, probabilityOutputs, new[] { sumOutput }, ctx.GetNodeName(opType), "");
ctx.CreateNode(opType, probabilityOutputs, new[] { sumOutput }, ctx.GetNodeName(opType), "");

opType = "Cast";
var castOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumZero", true);
var castOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "CastOutput", true);
var castNode = ctx.CreateNode(opType, sumOutput, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType();
castNode.AddAttribute("to", t);

opType = "Not";
var notOutput = ctx.AddIntermediateVariable(null, "IsSumZero", true);
ctx.CreateNode(opType, castOutput, notOutput, ctx.GetNodeName(opType), "");

opType = "Cast";
var castIsZeroSumToFloat = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumZeroAsFloat", true);
var castIsZeroSumToFloatNode = ctx.CreateNode(opType, castOutput, castIsZeroSumToFloat, ctx.GetNodeName(opType), "");
var castIsZeroSumToFloatNode = ctx.CreateNode(opType, notOutput, castIsZeroSumToFloat, ctx.GetNodeName(opType), "");
var t1 = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
castIsZeroSumToFloatNode.AddAttribute("to", t1);

opType = "Sum";
var sumOutputNonZero = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScoresNonZero", true);
var sumOutputNonZeroNode = ctx.CreateNode(opType, new[] { sumOutput, castIsZeroSumToFloat },
ctx.CreateNode(opType, new[] { sumOutput, castIsZeroSumToFloat },
new[] { sumOutputNonZero }, ctx.GetNodeName(opType), "");

string[] divOutputs = new string[Predictors.Length];
Expand Down
17 changes: 16 additions & 1 deletion test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1319,8 +1319,22 @@ public void MulticlassTrainersOnnxConversionTest()
{
mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(),
mlContext.MulticlassClassification.Trainers.NaiveBayes(),
mlContext.MulticlassClassification.Trainers.OneVersusAll(
mlContext.BinaryClassification.Trainers.AveragedPerceptron()),
mlContext.MulticlassClassification.Trainers.OneVersusAll(
mlContext.BinaryClassification.Trainers.AveragedPerceptron(), useProbabilities:false),
mlContext.MulticlassClassification.Trainers.OneVersusAll(
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression()),
mlContext.MulticlassClassification.Trainers.OneVersusAll(
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), useProbabilities:false),
mlContext.MulticlassClassification.Trainers.OneVersusAll(
mlContext.BinaryClassification.Trainers.LinearSvm()),
mlContext.MulticlassClassification.Trainers.OneVersusAll(
mlContext.BinaryClassification.Trainers.LinearSvm(), useProbabilities:false),
mlContext.MulticlassClassification.Trainers.OneVersusAll(
mlContext.BinaryClassification.Trainers.FastForest()),
mlContext.MulticlassClassification.Trainers.OneVersusAll(
mlContext.BinaryClassification.Trainers.FastForest(), useProbabilities:false),
mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(),
mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated()
};
Expand Down Expand Up @@ -1357,7 +1371,8 @@ public void MulticlassTrainersOnnxConversionTest()
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedScalarColumns<uint>(transformedData.Schema[5].Name, outputNames[2], transformedData, onnxResult);
CompareSelectedScalarColumns<uint>(transformedData.Schema[5].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels
CompareSelectedR4VectorColumns(transformedData.Schema[6].Name, outputNames[3], transformedData, onnxResult, 4); //compare scores
}
}
Done();
Expand Down