Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 38 additions & 22 deletions src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,15 @@ public CalibratorTransformer<TICalibrator> Fit(IDataView input)
{
var calibrator = (TICalibrator)CalibratorUtils.TrainCalibrator(Host, ch,
_calibratorTrainer, input, LabelColumn.Name, ScoreColumn.Name, WeightColumn.Name);
return Create(Host, calibrator);
return Create(Host, calibrator, ScoreColumn.Name);
}
}

/// <summary>
/// Implemented by deriving classes that create a concrete calibrator.
/// </summary>
[BestFriend]
private protected abstract CalibratorTransformer<TICalibrator> Create(IHostEnvironment env, TICalibrator calibrator);
private protected abstract CalibratorTransformer<TICalibrator> Create(IHostEnvironment env, TICalibrator calibrator, string scoreColumnName);
}

/// <summary>
Expand All @@ -167,12 +167,14 @@ public abstract class CalibratorTransformer<TICalibrator> : RowToRowTransformerB
{
private readonly TICalibrator _calibrator;
private readonly string _loaderSignature;
private readonly string _scoreColumnName;

private protected CalibratorTransformer(IHostEnvironment env, TICalibrator calibrator, string loaderSignature)
private protected CalibratorTransformer(IHostEnvironment env, TICalibrator calibrator, string loaderSignature, string scoreColumnName)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer<TICalibrator>)))
{
_loaderSignature = loaderSignature;
_calibrator = calibrator;
_scoreColumnName = scoreColumnName;
}

// Factory method for SignatureLoadModel.
Expand All @@ -186,7 +188,16 @@ private protected CalibratorTransformer(IHostEnvironment env, ModelLoadContext c

// *** Binary format ***
// model: _calibrator
// scoreColumnName: _scoreColumnName
ctx.LoadModel<TICalibrator, SignatureLoadModel>(env, out _calibrator, "Calibrator");
if (ctx.Header.ModelVerWritten >= 0x00010002)
{
_scoreColumnName = ctx.LoadString();
}
else
{
_scoreColumnName = DefaultColumnNames.Score;
}
}

// The calibrator reads its input from the score column it was configured with
// (the same column the row mapper looks up), so report that name instead of
// always reporting the default score column name.
string ISingleFeaturePredictionTransformer<TICalibrator>.FeatureColumnName => _scoreColumnName;
Expand All @@ -205,16 +216,19 @@ private protected override void SaveModel(ModelSaveContext ctx)

// *** Binary format ***
// model: _calibrator
// scoreColumnName: _scoreColumnName
ctx.SaveModel(_calibrator, "Calibrator");
ctx.SaveString(_scoreColumnName);
}

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper<TICalibrator>(this, _calibrator, schema);
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper<TICalibrator>(this, _calibrator, schema, _scoreColumnName);

private protected VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "CALTRANS",
verWrittenCur: 0x00010001, // Initial
// verWrittenCur: 0x00010001, // Initial
verWrittenCur: 0x00010002, // Added _scoreColumnName
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: _loaderSignature,
Expand All @@ -227,18 +241,20 @@ private sealed class Mapper<TCalibrator> : MapperBase, ISaveAsOnnx
private TCalibrator _calibrator;
private readonly int _scoreColIndex;
private CalibratorTransformer<TCalibrator> _parent;
private string _scoreColumnName;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _calibrator is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;

internal Mapper(CalibratorTransformer<TCalibrator> parent, TCalibrator calibrator, DataViewSchema inputSchema) :
internal Mapper(CalibratorTransformer<TCalibrator> parent, TCalibrator calibrator, DataViewSchema inputSchema, string scoreColumnName) :
base(parent.Host, inputSchema, parent)
{
_calibrator = calibrator;
_parent = parent;

_scoreColIndex = inputSchema.GetColumnOrNull(DefaultColumnNames.Score)?.Index ?? -1;
_scoreColumnName = scoreColumnName;
_scoreColIndex = inputSchema.GetColumnOrNull(_scoreColumnName)?.Index ?? -1;

parent.Host.Check(_scoreColIndex > 0, "The data to calibrate contains no 'Score' column");
parent.Host.Check(_scoreColIndex >= 0, "The data to calibrate contains no \'" + scoreColumnName + "\' column.");
}

private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
Expand Down Expand Up @@ -337,8 +353,8 @@ internal PlattCalibratorEstimator(IHostEnvironment env,
}

[BestFriend]
private protected override CalibratorTransformer<PlattCalibrator> Create(IHostEnvironment env, PlattCalibrator calibrator)
=> new PlattCalibratorTransformer(env, calibrator);
private protected override CalibratorTransformer<PlattCalibrator> Create(IHostEnvironment env, PlattCalibrator calibrator, string scoreColumnName)
=> new PlattCalibratorTransformer(env, calibrator, scoreColumnName);
}

/// <summary>
Expand Down Expand Up @@ -374,8 +390,8 @@ internal FixedPlattCalibratorEstimator(IHostEnvironment env,
}

[BestFriend]
private protected override CalibratorTransformer<PlattCalibrator> Create(IHostEnvironment env, PlattCalibrator calibrator)
=> new PlattCalibratorTransformer(env, calibrator);
private protected override CalibratorTransformer<PlattCalibrator> Create(IHostEnvironment env, PlattCalibrator calibrator, string scoreColumnName)
=> new PlattCalibratorTransformer(env, calibrator, scoreColumnName);
}

/// <summary>
Expand All @@ -385,8 +401,8 @@ public sealed class PlattCalibratorTransformer : CalibratorTransformer<PlattCali
{
internal const string LoadName = "PlattCalibratTransf";

internal PlattCalibratorTransformer(IHostEnvironment env, PlattCalibrator calibrator)
: base(env, calibrator, LoadName)
internal PlattCalibratorTransformer(IHostEnvironment env, PlattCalibrator calibrator, string scoreColumnName)
: base(env, calibrator, LoadName, scoreColumnName)
{
}

Expand Down Expand Up @@ -425,8 +441,8 @@ internal NaiveCalibratorEstimator(IHostEnvironment env,
}

[BestFriend]
private protected override CalibratorTransformer<NaiveCalibrator> Create(IHostEnvironment env, NaiveCalibrator calibrator)
=> new NaiveCalibratorTransformer(env, calibrator);
private protected override CalibratorTransformer<NaiveCalibrator> Create(IHostEnvironment env, NaiveCalibrator calibrator, string scoreColumnName)
=> new NaiveCalibratorTransformer(env, calibrator, scoreColumnName);
}

/// <summary>
Expand All @@ -436,8 +452,8 @@ public sealed class NaiveCalibratorTransformer : CalibratorTransformer<NaiveCali
{
internal const string LoadName = "NaiveCalibratTransf";

internal NaiveCalibratorTransformer(IHostEnvironment env, NaiveCalibrator calibrator)
: base(env, calibrator, LoadName)
internal NaiveCalibratorTransformer(IHostEnvironment env, NaiveCalibrator calibrator, string scoreColumnName)
: base(env, calibrator, LoadName, scoreColumnName)
{
}

Expand Down Expand Up @@ -474,8 +490,8 @@ internal IsotonicCalibratorEstimator(IHostEnvironment env,
}

[BestFriend]
private protected override CalibratorTransformer<IsotonicCalibrator> Create(IHostEnvironment env, IsotonicCalibrator calibrator)
=> new IsotonicCalibratorTransformer(env, calibrator);
private protected override CalibratorTransformer<IsotonicCalibrator> Create(IHostEnvironment env, IsotonicCalibrator calibrator, string scoreColumnName)
=> new IsotonicCalibratorTransformer(env, calibrator, scoreColumnName);

}

Expand All @@ -486,8 +502,8 @@ public sealed class IsotonicCalibratorTransformer : CalibratorTransformer<Isoton
{
internal const string LoadName = "PavCalibratTransf";

internal IsotonicCalibratorTransformer(IHostEnvironment env, IsotonicCalibrator calibrator)
: base(env, calibrator, LoadName)
internal IsotonicCalibratorTransformer(IHostEnvironment env, IsotonicCalibrator calibrator, string scoreColumnName)
: base(env, calibrator, LoadName, scoreColumnName)
{
}

Expand Down
23 changes: 23 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,12 @@ class PlattModelInput
public float Score { get; set; }
}

// Input row whose score is exposed through a property named ScoreX
// rather than Score, alongside a boolean Label.
class PlattModelInput2
{
public bool Label { get; set; }
public float ScoreX { get; set; }
}

static IEnumerable<PlattModelInput> PlattGetData()
{
for (int i = 0; i < 100; i++)
Expand All @@ -313,6 +319,14 @@ static IEnumerable<PlattModelInput> PlattGetData()
}
}

// Lazily yields 100 rows: ScoreX runs from 0 through 99 and Label is true
// for every even row index.
static IEnumerable<PlattModelInput2> PlattGetData2()
{
    const int rowCount = 100;
    int index = 0;
    while (index < rowCount)
    {
        bool isEven = index % 2 == 0;
        yield return new PlattModelInput2 { ScoreX = index, Label = isEven };
        index++;
    }
}

[Fact]
public void PlattCalibratorOnnxConversionTest2()
{
Expand All @@ -327,6 +341,15 @@ public void PlattCalibratorOnnxConversionTest2()

TestPipeline(pipeline, data, onnxFileName, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

// Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer
IDataView data2 = mlContext.Data.LoadFromEnumerable(PlattGetData2());

var pipeline2 = mlContext.BinaryClassification.Calibrators
.Platt(scoreColumnName: "ScoreX");
var onnxFileName2 = $"{pipeline2}.onnx";

TestPipeline(pipeline2, data2, onnxFileName2, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

Done();
}

Expand Down
139 changes: 139 additions & 0 deletions test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.IO;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
Expand Down Expand Up @@ -137,5 +138,143 @@ private void CheckValidCalibratedData(IDataView scoredData, ITransformer transfo
}
}

/// <summary>
/// Verifies that each calibrator estimator can be fitted and then applied when
/// the score column is declared before the label column and/or is named
/// something other than the default "Score".
/// </summary>
[Fact]
public void TestNonStandardCalibratorEstimatorClasses()
{
    var mlContext = new MLContext(0);

    // One data view per non-standard calibrator input class.
    var inputs = new IDataView[]
    {
        mlContext.Data.LoadFromEnumerable<CalibratorTestInputReversedOrder>(
            new[]
            {
                new CalibratorTestInputReversedOrder { Score = 10, Label = true },
                new CalibratorTestInputReversedOrder { Score = 15, Label = false }
            }),
        mlContext.Data.LoadFromEnumerable<CalibratorTestInputUniqueScoreColumnName>(
            new[]
            {
                new CalibratorTestInputUniqueScoreColumnName { Label = true, ScoreX = 10 },
                new CalibratorTestInputUniqueScoreColumnName { Label = false, ScoreX = 15 }
            }),
        mlContext.Data.LoadFromEnumerable<CalibratorTestInputReversedOrderAndUniqueScoreColumnName>(
            new[]
            {
                new CalibratorTestInputReversedOrderAndUniqueScoreColumnName { ScoreX = 10, Label = true },
                new CalibratorTestInputReversedOrderAndUniqueScoreColumnName { ScoreX = 15, Label = false }
            })
    };

    // A label or score column that does not use its default name ("Label" and
    // "Score", respectively) has to be named explicitly when the estimator is
    // built. Fitting and transforming without an exception indicates that the
    // correct label and score columns were found.
    for (int index = 0; index < inputs.Length; index++)
    {
        IDataView data = inputs[index];

        // Only the first data view keeps the default score column name.
        string scoreColumnName = index > 0 ? "ScoreX" : DefaultColumnNames.Score;

        // PlattCalibratorEstimator
        new PlattCalibratorEstimator(Env, scoreColumnName: scoreColumnName)
            .Fit(data)
            .Transform(data);

        // FixedPlattCalibratorEstimator
        new FixedPlattCalibratorEstimator(Env, scoreColumn: scoreColumnName)
            .Fit(data)
            .Transform(data);

        // NaiveCalibratorEstimator
        new NaiveCalibratorEstimator(Env, scoreColumn: scoreColumnName)
            .Fit(data)
            .Transform(data);

        // IsotonicCalibratorEstimator
        new IsotonicCalibratorEstimator(Env, scoreColumn: scoreColumnName)
            .Fit(data)
            .Transform(data);
    }
}

/// <summary>
/// Calibrator test input that declares the score column before the label
/// column, the reverse of the default layout in which the label column
/// comes first. The property order is deliberate: it is what this class
/// exists to test, so do not reorder the members.
/// </summary>
private sealed class CalibratorTestInputReversedOrder
{
public float Score { get; set; }
public bool Label { get; set; }
}

/// <summary>
/// Calibrator test input whose score column is named "ScoreX" instead of
/// the default column name "Score". The label column is declared first,
/// as in the default layout.
/// </summary>
private sealed class CalibratorTestInputUniqueScoreColumnName
{
public bool Label { get; set; }
public float ScoreX { get; set; }
}

/// <summary>
/// Calibrator test input that combines both non-standard traits: the score
/// column is declared before the label column (by default the label column
/// comes first), and the score column is named "ScoreX" instead of the
/// default column name "Score". The property order is deliberate, so do
/// not reorder the members.
/// </summary>
private sealed class CalibratorTestInputReversedOrderAndUniqueScoreColumnName
{
public float ScoreX { get; set; }
public bool Label { get; set; }
}

/// <summary>
/// Checks backwards compatibility of calibrator transformers: a Platt
/// calibrator model that was saved with VerWritten 0x00010001 must still
/// load, and must produce the same probabilities as a transformer that is
/// freshly trained on the same data.
/// </summary>
[Fact]
public void TestCalibratorEstimatorBackwardsCompatibility()
{
    // The legacy model being loaded below was trained and saved with
    // version as such:
    /*
     * var mlContext = new MLContext(seed: 1);
     * var calibratorTestData = GetCalibratorTestData();
     * var plattCalibratorEstimator = new PlattCalibratorEstimator(Env);
     * var plattCalibratorTransformer = plattCalibratorEstimator.Fit(calibratorTestData.ScoredData);
     * mlContext.Model.Save(plattCalibratorTransformer, calibratorTestData.ScoredData.Schema, "calibrator-model_VerWritten_0x00010001xyz.zip");
     */

    // Number of leading rows whose probabilities are compared.
    const int rowsToCompare = 10;
    // Position of the probability value within each previewed row.
    const int probabilityIndex = 5;
    // Number of decimal places to which the probabilities must agree.
    const int precision = 6;

    var modelPath = GetDataPath("backcompat", "Calibrator_Model_VerWritten_0x00010001.zip");
    ITransformer oldPlattCalibratorTransformer;
    using (var fs = File.OpenRead(modelPath))
    {
        // The schema stored with the model is not needed by this test.
        oldPlattCalibratorTransformer = ML.Model.Load(fs, out _);
    }

    var calibratorTestData = GetCalibratorTestData();
    var newPlattCalibratorEstimator = new PlattCalibratorEstimator(Env);
    var newPlattCalibratorTransformer = newPlattCalibratorEstimator.Fit(calibratorTestData.ScoredData);

    // Apply both models to the same scored data.
    var oldCalibratedData = oldPlattCalibratorTransformer.Transform(calibratorTestData.ScoredData).Preview();
    var newCalibratedData = newPlattCalibratorTransformer.Transform(calibratorTestData.ScoredData).Preview();

    // Check first that the produced schemas and outputs are of the same size.
    Assert.Equal(oldCalibratedData.RowView.Length, newCalibratedData.RowView.Length);
    Assert.Equal(oldCalibratedData.ColumnView.Length, newCalibratedData.ColumnView.Length);

    // Fail with a clear message, rather than an index error, if the preview is
    // shorter than the number of rows compared below.
    Assert.True(newCalibratedData.RowView.Length >= rowsToCompare,
        "The calibrated preview contains fewer rows than the test compares.");

    // Then check the produced probabilities for equality, within rounding error.
    for (int i = 0; i < rowsToCompare; i++)
    {
        double oldProbability = (float)oldCalibratedData.RowView[i].Values[probabilityIndex].Value;
        double newProbability = (float)newCalibratedData.RowView[i].Values[probabilityIndex].Value;
        Assert.Equal(oldProbability, newProbability, precision);
    }

    Done();
}
}
}
Binary file not shown.