Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@ ICalibrator ICalibratorTrainer.FinishTraining(IChannel ch)
/// <summary>
/// The naive binning-based calibrator.
/// </summary>
public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat
public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat, ISingleCanSaveOnnx
{
internal const string LoaderSignature = "NaiveCaliExec";
internal const string RegistrationName = "NaiveCalibrator";
Expand All @@ -1174,6 +1174,12 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(NaiveCalibrator).Assembly.FullName);
}

/// <summary>
/// Bool required by the interface ISingleCanSaveOnnx, returns true if
/// and only if calibrator can be exported in ONNX.
/// </summary>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;

private readonly IHost _host;

/// <summary> The bin size.</summary>
Expand Down Expand Up @@ -1280,6 +1286,45 @@ internal static int GetBinIdx(float output, float min, float binSize, int numBin
return binIdx;
}

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
_host.CheckValue(ctx, nameof(ctx));
_host.CheckValue(outputNames, nameof(outputNames));
_host.Check(Utils.Size(outputNames) == 2);
// outputNames[0] refers to the name of the Score column, which is the input of this graph
// outputNames[1] refers to the name of the Probability column, which is the final output of this graph

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "NaiveCalibrator");

string opType = "Sub";
var minVar = ctx.AddInitializer(Min, "Min");
var subNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeOutput");
var node = ctx.CreateNode(opType, new[] { outputNames[0], minVar }, new[] { subNodeOutput }, ctx.GetNodeName(opType), "");

opType = "Div";
var binSizeVar = ctx.AddInitializer(BinSize, "BinSize");
var divNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "binIndexOutput");
node = ctx.CreateNode(opType, new[] { subNodeOutput, binSizeVar }, new[] { divNodeOutput }, ctx.GetNodeName(opType), "");

opType = "Cast";
var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "castOutput");
node = ctx.CreateNode(opType, divNodeOutput, castOutput, ctx.GetNodeName(opType), "");
var toTypeInt = typeof(long);
node.AddAttribute("to", toTypeInt);

opType = "Clip";
var zeroVar = ctx.AddInitializer(0, "Zero");
var numBinsMinusOneVar = ctx.AddInitializer(_binProbs.Length-1, "NumBinsMinusOne");
var binIndexOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "binIndexOutput");
node = ctx.CreateNode(opType, new[] { castOutput, zeroVar, numBinsMinusOneVar }, new[] { binIndexOutput }, ctx.GetNodeName(opType), "");

opType = "GatherElements";
var binProbabilitiesVar = ctx.AddInitializer(_binProbs, new long[] { _binProbs.Length, 1 }, "BinProbabilities");
node = ctx.CreateNode(opType, new[] { binProbabilitiesVar, binIndexOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");

return true;
}
}

/// <summary>
Expand Down
95 changes: 57 additions & 38 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,7 @@ public void TestVectorWhiteningOnnxConversionTest()
Done();
}

[Fact]
public void PlattCalibratorOnnxConversionTest()
private (MLContext, IDataView, List<IEstimator<ITransformer>>, EstimatorChain<NormalizingTransformer>) GetEstimatorsForOnnxConversionTests()
{
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath("breast-cancer.txt");
Expand All @@ -289,70 +288,90 @@ public void PlattCalibratorOnnxConversionTest()

var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
Append(mlContext.Transforms.NormalizeMinMax("Features"));
return (mlContext, dataView, estimators, initialPipeline);
}

private void CommonCalibratorOnnxConversionTest(MLContext mlContext, IDataView dataView,
List<IEstimator<ITransformer>> estimators, EstimatorChain<NormalizingTransformer> initialPipeline,
IEstimator<ITransformer> calibrator, IEstimator<ITransformer> calibratorNonStandard)
{
// Step 1: Test calibrator with binary prediction trainer
foreach (var estimator in estimators)
{
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt());
var onnxFileName = $"{estimator}-WithPlattCalibrator.onnx";

TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
var pipelineEstimators = initialPipeline.Append(estimator).Append(calibrator);
var onnxFileName = $"{estimator}-With-{calibrator}.onnx";
TestPipeline(pipelineEstimators, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
}

// Step 2: Test calibrator without any binary prediction trainer
IDataView dataSoloCalibrator = mlContext.Data.LoadFromEnumerable(GetCalibratorTestData());
var onnxFileNameSoloCalibrator = $"{calibrator}-SoloCalibrator.onnx";
TestPipeline(calibrator, dataSoloCalibrator, onnxFileNameSoloCalibrator, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

// Step 3: Test calibrator with a non-default Score column name and without any binary prediction trainer
IDataView dataSoloCalibratorNonStandard = mlContext.Data.LoadFromEnumerable(GetCalibratorTestDataNonStandard());
var onnxFileNameSoloCalibratorNonStandard = $"{calibratorNonStandard}-SoloCalibrator-NonStandard.onnx";
TestPipeline(calibratorNonStandard, dataSoloCalibratorNonStandard, onnxFileNameSoloCalibratorNonStandard, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

Done();
}

class PlattModelInput
[Fact]
public void PlattCalibratorOnnxConversionTest()
{
var (mlContext, dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();
CommonCalibratorOnnxConversionTest(mlContext, dataView, estimators, initialPipeline,
mlContext.BinaryClassification.Calibrators.Platt(),
mlContext.BinaryClassification.Calibrators.Platt(scoreColumnName: "ScoreX"));
}

[Fact]
public void FixedPlattCalibratorOnnxConversionTest()
{
// Below, FixedPlattCalibrator is utilized by defining slope and offset in Platt's constructor with sample values.
var (mlContext, dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();
CommonCalibratorOnnxConversionTest(mlContext, dataView, estimators, initialPipeline,
mlContext.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f),
mlContext.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f, scoreColumnName: "ScoreX"));
}

[Fact]
public void NaiveCalibratorOnnxConversionTest()
{
var (mlContext, dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();
CommonCalibratorOnnxConversionTest(mlContext, dataView, estimators, initialPipeline,
mlContext.BinaryClassification.Calibrators.Naive(),
mlContext.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX"));
}

class CalibratorInput
{
public bool Label { get; set; }
public float Score { get; set; }
}

class PlattModelInput2
class CalibratorInputNonStandard
{
public bool Label { get; set; }
public float ScoreX { get; set; }
}

static IEnumerable<PlattModelInput> PlattGetData()
static IEnumerable<CalibratorInput> GetCalibratorTestData()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
yield return new CalibratorInput { Score = i, Label = i % 2 == 0 };
}
}

static IEnumerable<PlattModelInput2> PlattGetData2()
static IEnumerable<CalibratorInputNonStandard> GetCalibratorTestDataNonStandard()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput2 { ScoreX = i, Label = i % 2 == 0 };
yield return new CalibratorInputNonStandard { ScoreX = i, Label = i % 2 == 0 };
}
}

[Fact]
public void PlattCalibratorOnnxConversionTest2()
{
// Test PlattCalibrator without any binary prediction trainer
var mlContext = new MLContext(seed: 0);

IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData());

var pipeline = mlContext.BinaryClassification.Calibrators
.Platt();
var onnxFileName = $"{pipeline}.onnx";

TestPipeline(pipeline, data, onnxFileName, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

// Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer
IDataView data2 = mlContext.Data.LoadFromEnumerable(PlattGetData2());

var pipeline2 = mlContext.BinaryClassification.Calibrators
.Platt(scoreColumnName: "ScoreX");
var onnxFileName2 = $"{pipeline2}.onnx";

TestPipeline(pipeline2, data2, onnxFileName2, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

Done();
}

[Fact]
public void TextNormalizingOnnxConversionTest()
{
Expand Down