diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index 375e2a8fb1..a1f28897ed 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -1924,7 +1924,7 @@ public override ICalibrator CreateCalibrator(IChannel ch)
/// - [n], if x > [n]
///
///
- public sealed class IsotonicCalibrator : ICalibrator, ICanSaveInBinaryFormat
+ public sealed class IsotonicCalibrator : ICalibrator, ICanSaveInBinaryFormat, ISingleCanSaveOnnx
{
internal const string LoaderSignature = "PAVCaliExec";
internal const string RegistrationName = "PAVCalibrator";
@@ -1958,6 +1958,11 @@ private static VersionInfo GetVersionInfo()
/// Values of PAV intervals.
///
public readonly ImmutableArray Values;
+ ///
+ /// Bool required by the interface ISingleCanSaveOnnx, returns true if
+ /// and only if calibrator can be exported in ONNX.
+ ///
+ bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
///
/// Initializes a new instance of .
@@ -2115,6 +2120,187 @@ private float FindValue(float score)
float t = (score - Maxes[pos - 1]) / (Mins[pos] - Maxes[pos - 1]);
return Values[pos - 1] + t * (Values[pos] - Values[pos - 1]);
}
+
+ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
+ {
+ _host.CheckValue(ctx, nameof(ctx));
+ _host.CheckValue(outputNames, nameof(outputNames));
+ _host.Check(Utils.Size(outputNames) == 2);
+
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, "IsotonicCalibrator");
+
+ var minsLengthVar = ctx.AddInitializer(Mins.Length, "MinsLength");
+ var minsLengthMinusOneVar = ctx.AddInitializer(Mins.Length - 1, "MinsLengthMinusOne");
+ var maxesLengthVar = ctx.AddInitializer(Maxes.Length, "MaxesLength");
+ var minToReturnVar = ctx.AddInitializer((float)1e-15, "MinToReturn");
+ var maxToReturnVar = ctx.AddInitializer((float)(1 - 1e-15), "MaxToReturn");
+ var minsVar = ctx.AddInitializer(Mins, new long[] { Mins.Length, 1 }, "Mins");
+ var maxesVar = ctx.AddInitializer(Maxes, new long[] { Maxes.Length, 1 }, "Maxes");
+ var valuesVar = ctx.AddInitializer(Values, new long[] { Values.Length, 1 }, "Values");
+ var minsZeroVar = ctx.AddInitializer(Mins[0], "MinsZero");
+ var maxesPMinusOneVar = ctx.AddInitializer(Maxes[Mins.Length - 1], "MaxesPMinusOne");
+ var zeroVar = ctx.AddInitializer(0, "Zero");
+ var oneVar = ctx.AddInitializer(1, "One");
+
+ //The isotonic regression optimization problem is defined by:
+ // min(sum(w_i, (y[i] - y_[j])^2))
+ // subject to y_[i] <= y_[j] whenever X[i] <= X[j] (non-decreasing)
+ // and min(y_) = y_min, max(y_) = y_max
+ // where:
+ // *y[i] are inputs(real numbers)
+ // *y_[i] are fitted
+ // *X specifies the order.If X is non-decreasing then y_ is non-decreasing.
+ // *w[i] are optional strictly positive weights(default to 1.0)
+
+ // Goal: Given output, calculate prob
+
+ // --- STEP 1: implement if-then-else logic for (p = Mins.Length): ------------------------------------
+ // If p == 0, Return 0
+ // If score < Mins[0], Return prob = Values[0]
+ // If score > Maxes[p-1], Return prob = Values[p-1]
+ // Else, continue
+
+ // Get Values[0]
+ string opType = "GatherElements";
+ var valuesZeroOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "valuesZeroOutput");
+ OnnxNode node = ctx.CreateNode(opType, new[] { valuesVar, zeroVar }, new[] { valuesZeroOutput }, ctx.GetNodeName(opType), "");
+
+ // Get Values[p-1] = Values[Mins.Length-1]
+ opType = "GatherElements";
+ var valuesPMinusOneOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "valuesPMinusOneOutput");
+ node = ctx.CreateNode(opType, new[] { valuesVar, minsLengthMinusOneVar }, new[] { valuesPMinusOneOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "Equal";
+ var pEqualToZeroOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "pEqualToZeroOutput");
+ node = ctx.CreateNode(opType, new[] { minsLengthVar, zeroVar }, new[] { pEqualToZeroOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "Less";
+ var scoreLessThenMinsZeroOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "scoreLessThenMinsZeroOutput");
+ node = ctx.CreateNode(opType, new[] { outputNames[0], minsZeroVar }, new[] { scoreLessThenMinsZeroOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "Greater";
+ var scoreGreaterThenMaxesPMinusOneOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "scoreGreaterThenMaxesPMinusOneOutput");
+ node = ctx.CreateNode(opType, new[] { outputNames[0], maxesPMinusOneVar }, new[] { scoreGreaterThenMaxesPMinusOneOutput }, ctx.GetNodeName(opType), "");
+
+ // Implement if statements
+ // To-do
+
+ // --- STEP 2: calculate pos, which is the index of the given score in the already-sorted Maxes ------------------------------------
+ // AKA: Find closest element to score in maxes
+
+ // scoreRepeatedAsVectorOutput, which has score repeated in all indices, length of vector is same as that of Maxes
+ // Calculate with mul_broadcast element-wise binary multiplication
+ // Note: score = outputNames[0]
+ var shapeOnesAsVector = new long[] { Maxes.Length };
+ var onesAsVector = new List();
+ for(int i = 0; i < Maxes.Length; i++)
+ onesAsVector.Add(1.0f);
+ var onesAsVectorVar = ctx.AddInitializer(onesAsVector, shapeOnesAsVector, "OnesAsVector");
+
+ opType = "Mul";
+ var scoreRepeatedAsVectorOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1), "scoreRepeatedAsVectorOutput");
+ node = ctx.CreateNode(opType, new[] { onesAsVectorVar, outputNames[0] }, new[] { scoreRepeatedAsVectorOutput }, ctx.GetNodeName(opType), "");
+
+ // Subtract scoreRepeatedAsVectorOutput from Maxes
+ opType = "Sub";
+ var subVectorsOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1), "scoreAsVectorOutput");
+ node = ctx.CreateNode(opType, new[] { maxesVar, scoreRepeatedAsVectorOutput }, new[] { subVectorsOutput }, ctx.GetNodeName(opType), "");
+
+ // Square values in subVectorsOutput by multiplying it with itself
+ opType = "Mul";
+ var squaredVectorOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1), "squaredVectorOutput");
+ node = ctx.CreateNode(opType, new[] { subVectorsOutput, subVectorsOutput }, new[] { squaredVectorOutput }, ctx.GetNodeName(opType), "");
+
+ // Return index of given score, or if given score doesn't exist, return "logical index"
+ // of given score, which is the index of the first element greater than the given score.
+ opType = "ArgMin";
+ var posOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "posOutput");
+ node = ctx.CreateNode(opType, new[] { squaredVectorOutput }, new[] { posOutput }, ctx.GetNodeName(opType), "");
+
+ // --- STEP 3: if score >= Mins[pos], then prob = Values[pos] ------------------------------------
+
+ // Get Mins[pos]
+ opType = "GatherElements";
+ var minsPosOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "minsPosOutput");
+ node = ctx.CreateNode(opType, new[] { minsVar, posOutput }, new[] { minsPosOutput }, ctx.GetNodeName(opType), "");
+
+ // Get Values[pos]
+ opType = "GatherElements";
+ var valuesPosOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "valuesPosOutput");
+ node = ctx.CreateNode(opType, new[] { valuesVar, posOutput }, new[] { valuesPosOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "GreaterOrEqual";
+ var scoreGreaterThanEqualToMinsPosOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "scoreGreaterThanEqualToMinsPosOutput");
+ node = ctx.CreateNode(opType, new[] { outputNames[0], posOutput }, new[] { scoreGreaterThanEqualToMinsPosOutput }, ctx.GetNodeName(opType), "");
+
+ // Implement if statements
+ // To-do
+
+ // --- STEP 4: calculate (score - Maxes[pos - 1]) / (Mins[pos] - Maxes[pos - 1]) ------------------------------------
+ // score: outputNames[0]
+
+ opType = "Sub";
+ var posMinusOneOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "posMinusOneOutput");
+ node = ctx.CreateNode(opType, new[] { posOutput, oneVar }, new[] { posMinusOneOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "GatherElements";
+ var maxesPosMinusOneOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "maxesPosMinusOneOutput");
+ node = ctx.CreateNode(opType, new[] { maxesVar, posMinusOneOutput }, new[] { maxesPosMinusOneOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "Sub";
+ var subNode1Output = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeUpperOutput");
+ node = ctx.CreateNode(opType, new[] { outputNames[0], maxesPosMinusOneOutput }, new[] { subNode1Output }, ctx.GetNodeName(opType), "");
+
+ opType = "Sub";
+ var subNode2Output = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeLowerOutput");
+ node = ctx.CreateNode(opType, new[] { minsPosOutput, maxesPosMinusOneOutput }, new[] { subNode2Output }, ctx.GetNodeName(opType), "");
+
+ opType = "Div";
+ var tNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "divNodeOutput");
+ node = ctx.CreateNode(opType, new[] { subNode1Output, subNode2Output }, new[] { tNodeOutput }, ctx.GetNodeName(opType), "");
+
+ // --- STEP 5: calculate and return prob = Values[pos - 1] + t * (Values[pos] - Values[pos - 1]); ------------------------------------
+
+ opType = "GatherElements";
+ var valuesPosMinusOneOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "valuesPosMinusOneOutput");
+ node = ctx.CreateNode(opType, new[] { valuesVar, posMinusOneOutput }, new[] { valuesPosMinusOneOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "Sub";
+ var subNode3Output = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNode3Output");
+ node = ctx.CreateNode(opType, new[] { valuesPosOutput, valuesPosMinusOneOutput }, new[] { subNode3Output }, ctx.GetNodeName(opType), "");
+
+ opType = "Mul";
+ var mulNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "mulNodeOutput");
+ node = ctx.CreateNode(opType, new[] { tNodeOutput, subNode3Output }, new[] { mulNodeOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "Add";
+ var probabilityNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "probabilityNodeOutput");
+ node = ctx.CreateNode(opType, new[] { valuesPosMinusOneOutput, mulNodeOutput }, new[] { probabilityNodeOutput }, ctx.GetNodeName(opType), "");
+
+ // --- STEP 6: continue with logic ------------------------------------
+ // if (prob < MinToReturn)
+ // return MinToReturn;
+ // if (prob > MaxToReturn)
+ // return MaxToReturn;
+ // return prob
+
+ opType = "Less";
+ var probLessThanMinToReturnOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "probLessThanMinToReturnOutput");
+ node = ctx.CreateNode(opType, new[] { probabilityNodeOutput, minToReturnVar }, new[] { probLessThanMinToReturnOutput }, ctx.GetNodeName(opType), "");
+
+ opType = "Greater";
+ var probGreaterThanMaxToReturnOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "probGreaterThanMaxToReturnOutput");
+ node = ctx.CreateNode(opType, new[] { probabilityNodeOutput, maxToReturnVar }, new[] { probGreaterThanMaxToReturnOutput }, ctx.GetNodeName(opType), "");
+
+ // Implement if statements
+ // To-do
+
+ opType = "Identity";
+ node = ctx.CreateNode(opType, new[] { probabilityNodeOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");
+
+ return true;
+ }
}
internal static class Calibrate
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index 56d50ac702..bccb3306fa 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -338,6 +338,13 @@ public void NaiveCalibratorOnnxConversionTest()
ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX"));
}
+ [Fact]
+ public void IsotonicCalibratorOnnxConversionTest()
+ {
+ CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Isotonic(),
+ ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX"));
+ }
+
class CalibratorInput
{
public bool Label { get; set; }