diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 375e2a8fb1..a1f28897ed 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -1924,7 +1924,7 @@ public override ICalibrator CreateCalibrator(IChannel ch) /// [n], if x > [n] /// /// - public sealed class IsotonicCalibrator : ICalibrator, ICanSaveInBinaryFormat + public sealed class IsotonicCalibrator : ICalibrator, ICanSaveInBinaryFormat, ISingleCanSaveOnnx { internal const string LoaderSignature = "PAVCaliExec"; internal const string RegistrationName = "PAVCalibrator"; @@ -1958,6 +1958,11 @@ private static VersionInfo GetVersionInfo() /// Values of PAV intervals. /// public readonly ImmutableArray Values; + /// + /// Bool required by the interface ISingleCanSaveOnnx, returns true if + /// and only if calibrator can be exported in ONNX. + /// + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; /// /// Initializes a new instance of . @@ -2115,6 +2120,187 @@ private float FindValue(float score) float t = (score - Maxes[pos - 1]) / (Mins[pos] - Maxes[pos - 1]); return Values[pos - 1] + t * (Values[pos] - Values[pos - 1]); } + + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) + { + _host.CheckValue(ctx, nameof(ctx)); + _host.CheckValue(outputNames, nameof(outputNames)); + _host.Check(Utils.Size(outputNames) == 2); + + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, "IsotonicCalibrator"); + + var minsLengthVar = ctx.AddInitializer(Mins.Length, "MinsLength"); + var minsLengthMinusOneVar = ctx.AddInitializer(Mins.Length - 1, "MinsLengthMinusOne"); + var maxesLengthVar = ctx.AddInitializer(Maxes.Length, "MaxesLength"); + var minToReturnVar = ctx.AddInitializer((float)1e-15, "MinToReturn"); + var maxToReturnVar = ctx.AddInitializer((float)(1 - 1e-15), "MaxToReturn"); + var minsVar = ctx.AddInitializer(Mins, new long[] { Mins.Length, 1 }, "Mins"); + var maxesVar = ctx.AddInitializer(Maxes, new long[] { Maxes.Length, 1 }, "Maxes"); + var valuesVar = ctx.AddInitializer(Values, new long[] { Values.Length, 1 }, "Values"); + var minsZeroVar = ctx.AddInitializer(Mins[0], "MinsZero"); + var maxesPMinusOneVar = ctx.AddInitializer(Maxes[Mins.Length - 1], "MaxesPMinusOne"); + var zeroVar = ctx.AddInitializer(0, "Zero"); + var oneVar = ctx.AddInitializer(1, "One"); + + //The isotonic regression optimization problem is defined by: + // min(sum(w_i, (y[i] - y_[j])^2)) + // subject to y_[i] <= y_[j] whenever X[i] <= X[j] (non-decreasing) + // and min(y_) = y_min, max(y_) = y_max + // where: + // *y[i] are inputs(real numbers) + // *y_[i] are fitted + // *X specifies the order.If X is non-decreasing then y_ is non-decreasing. + // *w[i] are optional strictly positive weights(default to 1.0) + + // Goal: Given output, calculate prob + + // --- STEP 1: implement if-then-else logic for (p = Mins.Length): ------------------------------------ + // If p == 0, Return 0 + // If score < Mins[0], Return prob = Values[0] + // If score > Maxes[p-1], Return prob = Values[p-1] + // Else, continue + + // Get Values[0] + string opType = "GatherElements"; + var valuesZeroOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "valuesZeroOutput"); + OnnxNode node = ctx.CreateNode(opType, new[] { valuesVar, zeroVar }, new[] { valuesZeroOutput }, ctx.GetNodeName(opType), ""); + + // Get Values[p-1] = Values[Mins.Length-1] + opType = "GatherElements"; + var valuesPMinusOneOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "valuesPMinusOneOutput"); + node = ctx.CreateNode(opType, new[] { valuesVar, minsLengthMinusOneVar }, new[] { valuesPMinusOneOutput }, ctx.GetNodeName(opType), ""); + + opType = "Equal"; + var pEqualToZeroOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "pEqualToZeroOutput"); + node = ctx.CreateNode(opType, new[] { minsLengthVar, zeroVar }, new[] { pEqualToZeroOutput }, ctx.GetNodeName(opType), ""); + + opType = "Less"; + var scoreLessThenMinsZeroOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "scoreLessThenMinsZeroOutput"); + node = ctx.CreateNode(opType, new[] { outputNames[0], minsZeroVar }, new[] { scoreLessThenMinsZeroOutput }, ctx.GetNodeName(opType), ""); + + opType = "Greater"; + var scoreGreaterThenMaxesPMinusOneOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "scoreGreaterThenMaxesPMinusOneOutput"); + node = ctx.CreateNode(opType, new[] { outputNames[0], maxesPMinusOneVar }, new[] { scoreGreaterThenMaxesPMinusOneOutput }, ctx.GetNodeName(opType), ""); + + // Implement if statements + // To-do + + // --- STEP 2: calculate pos, which is the index of the given score in the already-sorted Maxes ------------------------------------ + // AKA: Find closest element to score in maxes + + // scoreRepeatedAsVectorOutput, which has score repeated in all indices, length of vector is same as that of Maxes + // Calculate with mul_broadcast element-wise binary multiplication + // Note: score = outputNames[0] + var shapeOnesAsVector = new long[] { Maxes.Length }; + var onesAsVector = new List(); + for(int i = 0; i < Maxes.Length; i++) + onesAsVector.Add(1.0f); + var onesAsVectorVar = ctx.AddInitializer(onesAsVector, shapeOnesAsVector, "OnesAsVector"); + + opType = "Mul"; + var scoreRepeatedAsVectorOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1), "scoreRepeatedAsVectorOutput"); + node = ctx.CreateNode(opType, new[] { onesAsVectorVar, outputNames[0] }, new[] { scoreRepeatedAsVectorOutput }, ctx.GetNodeName(opType), ""); + + // Subtract scoreRepeatedAsVectorOutput from Maxes + opType = "Sub"; + var subVectorsOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1), "scoreAsVectorOutput"); + node = ctx.CreateNode(opType, new[] { maxesVar, scoreRepeatedAsVectorOutput }, new[] { subVectorsOutput }, ctx.GetNodeName(opType), ""); + + // Square values in subVectorsOutput by multiplying it with itself + opType = "Mul"; + var squaredVectorOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1), "squaredVectorOutput"); + node = ctx.CreateNode(opType, new[] { subVectorsOutput, subVectorsOutput }, new[] { squaredVectorOutput }, ctx.GetNodeName(opType), ""); + + // Return index of given score, or if given score doesn't exist, return "logical index" + // of given score, which is the index of the first element greater than the given score. + opType = "ArgMin"; + var posOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "posOutput"); + node = ctx.CreateNode(opType, new[] { squaredVectorOutput }, new[] { posOutput }, ctx.GetNodeName(opType), ""); + + // --- STEP 3: if score >= Mins[pos], then prob = Values[pos] ------------------------------------ + + // Get Mins[pos] + opType = "GatherElements"; + var minsPosOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "minsPosOutput"); + node = ctx.CreateNode(opType, new[] { minsVar, posOutput }, new[] { minsPosOutput }, ctx.GetNodeName(opType), ""); + + // Get Values[pos] + opType = "GatherElements"; + var valuesPosOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "valuesPosOutput"); + node = ctx.CreateNode(opType, new[] { valuesVar, posOutput }, new[] { valuesPosOutput }, ctx.GetNodeName(opType), ""); + + opType = "GreaterOrEqual"; + var scoreGreaterThanEqualToMinsPosOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "scoreGreaterThanEqualToMinsPosOutput"); + node = ctx.CreateNode(opType, new[] { outputNames[0], posOutput }, new[] { scoreGreaterThanEqualToMinsPosOutput }, ctx.GetNodeName(opType), ""); + + // Implement if statements + // To-do + + // --- STEP 4: calculate (score - Maxes[pos - 1]) / (Mins[pos] - Maxes[pos - 1]) ------------------------------------ + // score: outputNames[0] + + opType = "Sub"; + var posMinusOneOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "posMinusOneOutput"); + node = ctx.CreateNode(opType, new[] { posOutput, oneVar }, new[] { posMinusOneOutput }, ctx.GetNodeName(opType), ""); + + opType = "GatherElements"; + var maxesPosMinusOneOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "maxesPosMinusOneOutput"); + node = ctx.CreateNode(opType, new[] { maxesVar, posMinusOneOutput }, new[] { maxesPosMinusOneOutput }, ctx.GetNodeName(opType), ""); + + opType = "Sub"; + var subNode1Output = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeUpperOutput"); + node = ctx.CreateNode(opType, new[] { outputNames[0], maxesPosMinusOneOutput }, new[] { subNode1Output }, ctx.GetNodeName(opType), ""); + + opType = "Sub"; + var subNode2Output = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeLowerOutput"); + node = ctx.CreateNode(opType, new[] { minsPosOutput, maxesPosMinusOneOutput }, new[] { subNode2Output }, ctx.GetNodeName(opType), ""); + + opType = "Div"; + var tNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "divNodeOutput"); + node = ctx.CreateNode(opType, new[] { subNode1Output, subNode2Output }, new[] { tNodeOutput }, ctx.GetNodeName(opType), ""); + + // --- STEP 5: calculate and return prob = Values[pos - 1] + t * (Values[pos] - Values[pos - 1]); ------------------------------------ + + opType = "GatherElements"; + var valuesPosMinusOneOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "valuesPosMinusOneOutput"); + node = ctx.CreateNode(opType, new[] { valuesVar, posMinusOneOutput }, new[] { valuesPosMinusOneOutput }, ctx.GetNodeName(opType), ""); + + opType = "Sub"; + var subNode3Output = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNode3Output"); + node = ctx.CreateNode(opType, new[] { valuesPosOutput, valuesPosMinusOneOutput }, new[] { subNode3Output }, ctx.GetNodeName(opType), ""); + + opType = "Mul"; + var mulNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "mulNodeOutput"); + node = ctx.CreateNode(opType, new[] { tNodeOutput, subNode3Output }, new[] { mulNodeOutput }, ctx.GetNodeName(opType), ""); + + opType = "Add"; + var probabilityNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "probabilityNodeOutput"); + node = ctx.CreateNode(opType, new[] { valuesPosMinusOneOutput, mulNodeOutput }, new[] { probabilityNodeOutput }, ctx.GetNodeName(opType), ""); + + // --- STEP 6: continue with logic ------------------------------------ + // if (prob < MinToReturn) + // return MinToReturn; + // if (prob > MaxToReturn) + // return MaxToReturn; + // return prob + + opType = "Less"; + var probLessThanMinToReturnOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "probLessThanMinToReturnOutput"); + node = ctx.CreateNode(opType, new[] { probabilityNodeOutput, minToReturnVar }, new[] { probLessThanMinToReturnOutput }, ctx.GetNodeName(opType), ""); + + opType = "Greater"; + var probGreaterThanMaxToReturnOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "probGreaterThanMaxToReturnOutput"); + node = ctx.CreateNode(opType, new[] { probabilityNodeOutput, maxToReturnVar }, new[] { probGreaterThanMaxToReturnOutput }, ctx.GetNodeName(opType), ""); + + // Implement if statements + // To-do + + opType = "Identity"; + node = ctx.CreateNode(opType, new[] { probabilityNodeOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), ""); + + return true; + } } internal static class Calibrate diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 56d50ac702..bccb3306fa 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -338,6 +338,13 @@ public void NaiveCalibratorOnnxConversionTest() ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX")); } + [Fact] + public void IsotonicCalibratorOnnxConversionTest() + { + CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Isotonic(), + ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX")); + } + class CalibratorInput { public bool Label { get; set; }