diff --git a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs index 46105e242a..34415ed0de 100644 --- a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs +++ b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs @@ -143,7 +143,7 @@ public CalibratorTransformer Fit(IDataView input) { var calibrator = (TICalibrator)CalibratorUtils.TrainCalibrator(Host, ch, _calibratorTrainer, input, LabelColumn.Name, ScoreColumn.Name, WeightColumn.Name); - return Create(Host, calibrator); + return Create(Host, calibrator, ScoreColumn.Name); } } @@ -151,7 +151,7 @@ public CalibratorTransformer Fit(IDataView input) /// Implemented by deriving classes that create a concrete calibrator. /// [BestFriend] - private protected abstract CalibratorTransformer Create(IHostEnvironment env, TICalibrator calibrator); + private protected abstract CalibratorTransformer Create(IHostEnvironment env, TICalibrator calibrator, string scoreColumnName); } /// @@ -167,12 +167,14 @@ public abstract class CalibratorTransformer : RowToRowTransformerB { private readonly TICalibrator _calibrator; private readonly string _loaderSignature; + private readonly string _scoreColumnName; - private protected CalibratorTransformer(IHostEnvironment env, TICalibrator calibrator, string loaderSignature) + private protected CalibratorTransformer(IHostEnvironment env, TICalibrator calibrator, string loaderSignature, string scoreColumnName) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer))) { _loaderSignature = loaderSignature; _calibrator = calibrator; + _scoreColumnName = scoreColumnName; } // Factory method for SignatureLoadModel. @@ -186,7 +188,16 @@ private protected CalibratorTransformer(IHostEnvironment env, ModelLoadContext c // *** Binary format *** // model: _calibrator + // scoreColumnName: _scoreColumnName ctx.LoadModel(env, out _calibrator, "Calibrator"); + if (ctx.Header.ModelVerWritten >= 0x00010002) + { + _scoreColumnName = ctx.LoadString(); + } + else + { + _scoreColumnName = DefaultColumnNames.Score; + } } string ISingleFeaturePredictionTransformer.FeatureColumnName => DefaultColumnNames.Score; @@ -205,16 +216,19 @@ private protected override void SaveModel(ModelSaveContext ctx) // *** Binary format *** // model: _calibrator + // scoreColumnName: _scoreColumnName ctx.SaveModel(_calibrator, "Calibrator"); + ctx.SaveString(_scoreColumnName); } - private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, _calibrator, schema); + private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, _calibrator, schema, _scoreColumnName); private protected VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "CALTRANS", - verWrittenCur: 0x00010001, // Initial + // verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Added _scoreColumnName verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: _loaderSignature, @@ -227,18 +241,20 @@ private sealed class Mapper : MapperBase, ISaveAsOnnx private TCalibrator _calibrator; private readonly int _scoreColIndex; private CalibratorTransformer _parent; + private string _scoreColumnName; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _calibrator is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false; - internal Mapper(CalibratorTransformer parent, TCalibrator calibrator, DataViewSchema inputSchema) : + internal Mapper(CalibratorTransformer parent, TCalibrator calibrator, DataViewSchema inputSchema, string scoreColumnName) : base(parent.Host, inputSchema, parent) { _calibrator = calibrator; _parent = parent; - _scoreColIndex = inputSchema.GetColumnOrNull(DefaultColumnNames.Score)?.Index ?? -1; + _scoreColumnName = scoreColumnName; + _scoreColIndex = inputSchema.GetColumnOrNull(_scoreColumnName)?.Index ?? -1; - parent.Host.Check(_scoreColIndex > 0, "The data to calibrate contains no 'Score' column"); + parent.Host.Check(_scoreColIndex >= 0, "The data to calibrate contains no \'" + scoreColumnName + "\' column."); } private protected override Func GetDependenciesCore(Func activeOutput) @@ -337,8 +353,8 @@ internal PlattCalibratorEstimator(IHostEnvironment env, } [BestFriend] - private protected override CalibratorTransformer Create(IHostEnvironment env, PlattCalibrator calibrator) - => new PlattCalibratorTransformer(env, calibrator); + private protected override CalibratorTransformer Create(IHostEnvironment env, PlattCalibrator calibrator, string scoreColumnName) + => new PlattCalibratorTransformer(env, calibrator, scoreColumnName); } /// @@ -374,8 +390,8 @@ internal FixedPlattCalibratorEstimator(IHostEnvironment env, } [BestFriend] - private protected override CalibratorTransformer Create(IHostEnvironment env, PlattCalibrator calibrator) - => new PlattCalibratorTransformer(env, calibrator); + private protected override CalibratorTransformer Create(IHostEnvironment env, PlattCalibrator calibrator, string scoreColumnName) + => new PlattCalibratorTransformer(env, calibrator, scoreColumnName); } /// @@ -385,8 +401,8 @@ public sealed class PlattCalibratorTransformer : CalibratorTransformer Create(IHostEnvironment env, NaiveCalibrator calibrator) - => new NaiveCalibratorTransformer(env, calibrator); + private protected override CalibratorTransformer Create(IHostEnvironment env, NaiveCalibrator calibrator, string scoreColumnName) + => new NaiveCalibratorTransformer(env, calibrator, scoreColumnName); } /// @@ -436,8 +452,8 @@ public sealed class NaiveCalibratorTransformer : CalibratorTransformer Create(IHostEnvironment env, IsotonicCalibrator calibrator) - => new IsotonicCalibratorTransformer(env, calibrator); + private protected override CalibratorTransformer Create(IHostEnvironment env, IsotonicCalibrator calibrator, string scoreColumnName) + => new IsotonicCalibratorTransformer(env, calibrator, scoreColumnName); } @@ -486,8 +502,8 @@ public sealed class IsotonicCalibratorTransformer : CalibratorTransformer PlattGetData() { for (int i = 0; i < 100; i++) @@ -313,6 +319,14 @@ static IEnumerable PlattGetData() } } + static IEnumerable PlattGetData2() + { + for (int i = 0; i < 100; i++) + { + yield return new PlattModelInput2 { ScoreX = i, Label = i % 2 == 0 }; + } + } + [Fact] public void PlattCalibratorOnnxConversionTest2() { @@ -327,6 +341,15 @@ public void PlattCalibratorOnnxConversionTest2() TestPipeline(pipeline, data, onnxFileName, new ColumnComparison[] { new ColumnComparison("Probability", 3) }); + // Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer + IDataView data2 = mlContext.Data.LoadFromEnumerable(PlattGetData2()); + + var pipeline2 = mlContext.BinaryClassification.Calibrators + .Platt(scoreColumnName: "ScoreX"); + var onnxFileName2 = $"{pipeline2}.onnx"; + + TestPipeline(pipeline2, data2, onnxFileName2, new ColumnComparison[] { new ColumnComparison("Probability", 3) }); + Done(); } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs index 3d3ea98fc3..4be66af65a 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.IO; using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.Trainers; @@ -137,5 +138,143 @@ private void CheckValidCalibratedData(IDataView scoredData, ITransformer transfo } } + /// + /// Test to confirm calibrator estimators work with classes + /// where order of label and score columns are reversed, and + /// where name of score column is different than the default. + /// + [Fact] + public void TestNonStandardCalibratorEstimatorClasses() + { + var mlContext = new MLContext(0); + // Store different possible variations of calibrator data classes. + IDataView[] dataArray = new IDataView[] + { + mlContext.Data.LoadFromEnumerable( + new CalibratorTestInputReversedOrder[] + { + new CalibratorTestInputReversedOrder { Score = 10, Label = true }, + new CalibratorTestInputReversedOrder { Score = 15, Label = false } + }), + mlContext.Data.LoadFromEnumerable( + new CalibratorTestInputUniqueScoreColumnName[] + { + new CalibratorTestInputUniqueScoreColumnName { Label = true, ScoreX = 10 }, + new CalibratorTestInputUniqueScoreColumnName { Label = false, ScoreX = 15 } + }), + mlContext.Data.LoadFromEnumerable( + new CalibratorTestInputReversedOrderAndUniqueScoreColumnName[] + { + new CalibratorTestInputReversedOrderAndUniqueScoreColumnName { ScoreX = 10, Label = true }, + new CalibratorTestInputReversedOrderAndUniqueScoreColumnName { ScoreX = 15, Label = false } + }) + }; + + // When label and/or score columns are different from their default names ("Label" and "Score", respectively), they + // need to be manually defined as done below. + // Successful training of estimators and transforming with transformers indicate correct label and score columns + // have been found. + for (int i = 0; i < dataArray.Length; i++) + { + // Test PlattCalibratorEstimator + var calibratorPlattEstimator = new PlattCalibratorEstimator(Env, + scoreColumnName: i > 0 ? "ScoreX" : DefaultColumnNames.Score); + var calibratorPlattTransformer = calibratorPlattEstimator.Fit(dataArray[i]); + calibratorPlattTransformer.Transform(dataArray[i]); + + // Test FixedPlattCalibratorEstimator + var calibratorFixedPlattEstimator = new FixedPlattCalibratorEstimator(Env, + scoreColumn: i > 0 ? "ScoreX" : DefaultColumnNames.Score); + var calibratorFixedPlattTransformer = calibratorFixedPlattEstimator.Fit(dataArray[i]); + calibratorFixedPlattTransformer.Transform(dataArray[i]); + + // Test NaiveCalibratorEstimator + var calibratorNaiveEstimator = new NaiveCalibratorEstimator(Env, + scoreColumn: i > 0 ? "ScoreX" : DefaultColumnNames.Score); + var calibratorNaiveTransformer = calibratorNaiveEstimator.Fit(dataArray[i]); + calibratorNaiveTransformer.Transform(dataArray[i]); + + // Test IsotonicCalibratorEstimator + var calibratorIsotonicEstimator = new IsotonicCalibratorEstimator(Env, + scoreColumn: i > 0 ? "ScoreX" : DefaultColumnNames.Score); + var calibratorIsotonicTransformer = calibratorIsotonicEstimator.Fit(dataArray[i]); + calibratorIsotonicTransformer.Transform(dataArray[i]); + } + } + + /// + /// Test class where the column order of the label and score + /// columns are reversed (by default, label column is before + /// that of score column). + /// + private sealed class CalibratorTestInputReversedOrder + { + public float Score { get; set; } + public bool Label { get; set; } + } + + /// + /// Test class where name of score column is different than + /// the default column name of "Score". + /// + private sealed class CalibratorTestInputUniqueScoreColumnName + { + public bool Label { get; set; } + public float ScoreX { get; set; } + } + + /// + /// Test class where the column order of the label and score + /// columns are reversed (by default, label column is before + /// that of score column), and where name of score column is + /// different than the default column name of "Score". + /// + private sealed class CalibratorTestInputReversedOrderAndUniqueScoreColumnName + { + public float ScoreX { get; set; } + public bool Label { get; set; } + } + + /// + /// Test to check backwards compatibility of calibrator estimators + /// trained before the current version of VerWritten: 0x00010001. + /// + [Fact] + public void TestCalibratorEstimatorBackwardsCompatibility() + { + // The legacy model being loaded below was trained and saved with + // version as such: + /* + * var mlContext = new MLContext(seed: 1); + * var calibratorTestData = GetCalibratorTestData(); + * var plattCalibratorEstimator = new PlattCalibratorEstimator(Env); + * var plattCalibratorTransformer = plattCalibratorEstimator.Fit(calibratorTestData.ScoredData); + * mlContext.Model.Save(plattCalibratorTransformer, calibratorTestData.ScoredData.Schema, "calibrator-model_VerWritten_0x00010001xyz.zip"); + */ + + var modelPath = GetDataPath("backcompat", "Calibrator_Model_VerWritten_0x00010001.zip"); + ITransformer oldPlattCalibratorTransformer; + using (var fs = File.OpenRead(modelPath)) + oldPlattCalibratorTransformer = ML.Model.Load(fs, out var schema); + + var calibratorTestData = GetCalibratorTestData(); + var newPlattCalibratorEstimator = new PlattCalibratorEstimator(Env); + var newPlattCalibratorTransformer = newPlattCalibratorEstimator.Fit(calibratorTestData.ScoredData); + + // Check that both models produce the same output + var oldCalibratedData = oldPlattCalibratorTransformer.Transform(calibratorTestData.ScoredData).Preview(); + var newCalibratedData = newPlattCalibratorTransformer.Transform(calibratorTestData.ScoredData).Preview(); + + // Check first that the produced schemas and outputs are of the same size + Assert.True(oldCalibratedData.RowView.Length == newCalibratedData.RowView.Length); + Assert.True(oldCalibratedData.ColumnView.Length == newCalibratedData.ColumnView.Length); + + // Then check the produced probabilities (5th value corresponds to probabilities) for + // equality, within rounding error. + for (int i = 0; i < 10; i++) + Assert.True((float)oldCalibratedData.RowView[i].Values[5].Value == (float)newCalibratedData.RowView[i].Values[5].Value); + + Done(); + } } } diff --git a/test/data/backcompat/Calibrator_Model_VerWritten_0x00010001.zip b/test/data/backcompat/Calibrator_Model_VerWritten_0x00010001.zip new file mode 100644 index 0000000000..0078d0a8b6 Binary files /dev/null and b/test/data/backcompat/Calibrator_Model_VerWritten_0x00010001.zip differ