diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 87ba1278b4..2f947fab24 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -1750,7 +1750,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0], slopVar }, new[] { mulNodeOutput }, ctx.GetNodeName(opType), ""); opType = "Add"; - var betaVar = ctx.AddInitializer(-0.0000001f, "Slope"); + var betaVar = ctx.AddInitializer((float)(-Offset), "Offset"); var linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true); node = ctx.CreateNode(opType, new[] { mulNodeOutput, betaVar }, new[] { linearOutput }, ctx.GetNodeName(opType), ""); diff --git a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs index 2579f086be..46105e242a 100644 --- a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs +++ b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs @@ -8,6 +8,7 @@ using Microsoft.ML; using Microsoft.ML.Calibrators; using Microsoft.ML.Data; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Trainers; @@ -220,13 +221,15 @@ private protected VersionInfo GetVersionInfo() loaderAssemblyName: typeof(CalibratorTransformer<>).Assembly.FullName); } - private sealed class Mapper : MapperBase + private sealed class Mapper : MapperBase, ISaveAsOnnx where TCalibrator : class, ICalibrator { private TCalibrator _calibrator; private readonly int _scoreColIndex; private CalibratorTransformer _parent; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _calibrator is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false; + internal Mapper(CalibratorTransformer parent, TCalibrator calibrator, DataViewSchema inputSchema) : base(parent.Host, inputSchema, parent) { @@ -243,6 +246,20 @@ private protected override Func GetDependenciesCore(Func a private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); + void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) + { + var scoreName = InputSchema[_scoreColIndex].Name; + var probabilityName = GetOutputColumnsCore()[0].Name; + Host.CheckValue(ctx, nameof(ctx)); + if (_calibrator is ISingleCanSaveOnnx onnx) + { + Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX."); + scoreName = ctx.GetVariableName(scoreName); + probabilityName = ctx.AddIntermediateVariable(NumberDataViewType.Single, probabilityName); + onnx.SaveAsOnnx(ctx, new[] { scoreName, probabilityName }, ""); // No need for featureColumn + } + } + protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() { var builder = new DataViewSchema.Annotations.Builder(); diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt index 5f6a356b72..be7ad46506 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt @@ -393,7 +393,7 @@ { "input": [ "MulNodeOutput", - "Slope0" + "Offset" ], "output": [ "linearOutput" @@ -489,9 +489,9 @@ { "dataType": 1, "floatData": [ - -1E-07 + 0 ], - "name": "Slope0" + "name": "Offset" } ], "input": [ diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt index bf0715cc25..1f7f5ad1fd 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt @@ -719,7 +719,7 @@ { "input": [ "MulNodeOutput", - "Slope0" + "Offset" ], "output": [ "linearOutput" @@ -815,9 +815,9 @@ { "dataType": 1, "floatData": [ - -1E-07 + 0 ], - "name": "Slope0" + "name": "Offset" } ], "input": [ diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt index 2df8da4d81..02446da076 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt @@ -346,7 +346,7 @@ { "input": [ "MulNodeOutput", - "Slope0" + "Offset" ], "output": [ "linearOutput" @@ -482,9 +482,9 @@ { "dataType": 1, "floatData": [ - -1E-07 + 0 ], - "name": "Slope0" + "name": "Offset" } ], "input": [ diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 240b7c0c87..fffcadb329 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -343,6 +343,109 @@ public void TestVectorWhiteningOnnxConversionTest() Done(); } + [Fact] + public void PlattCalibratorOnnxConversionTest() + { + var mlContext = new MLContext(seed: 1); + string dataPath = GetDataPath("breast-cancer.txt"); + // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). + var dataView = mlContext.Data.LoadFromTextFile(dataPath, separatorChar: '\t', hasHeader: true); + List> estimators = new List>() + { + mlContext.BinaryClassification.Trainers.AveragedPerceptron(), + mlContext.BinaryClassification.Trainers.FastForest(), + mlContext.BinaryClassification.Trainers.FastTree(), + mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), + mlContext.BinaryClassification.Trainers.LinearSvm(), + mlContext.BinaryClassification.Trainers.Prior(), + mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(), + mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(), + mlContext.BinaryClassification.Trainers.SgdCalibrated(), + mlContext.BinaryClassification.Trainers.SgdNonCalibrated(), + mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(), + }; + if (Environment.Is64BitProcess) + { + estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm()); + } + + var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features"). + Append(mlContext.Transforms.NormalizeMinMax("Features")); + foreach (var estimator in estimators) + { + var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt()); + var model = pipeline.Fit(dataView); + var outputSchema = model.GetOutputSchema(dataView.Schema); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + + // Compare model scores produced by ML.NET and ONNX's runtime. + if (IsOnnxRuntimeSupported()) + { + var onnxFileName = $"{estimator.ToString()}-WithPlattCalibrator.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + SaveOnnxModel(onnxModel, onnxModelPath, null); + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores + CompareSelectedScalarColumns(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels + CompareSelectedR4ScalarColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities + } + } + Done(); + } + + class PlattModelInput + { + public bool Label { get; set; } + public float Score { get; set; } + } + + static IEnumerable PlattGetData() + { + for (int i = 0; i < 100; i++) + { + yield return new PlattModelInput { Score = i, Label = i % 2 == 0 }; + } + } + + [Fact] + public void PlattCalibratorOnnxConversionTest2() + { + // Test PlattCalibrator without any binary prediction trainer + var mlContext = new MLContext(seed: 0); + + IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData()); + + var calibratorEstimator = mlContext.BinaryClassification.Calibrators + .Platt(); + + var calibratorTransformer = calibratorEstimator.Fit(data); + var transformedData = calibratorTransformer.Transform(data); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(calibratorTransformer, data); + + // Compare model scores produced by ML.NET and ONNX's runtime. + if (IsOnnxRuntimeSupported()) + { + var onnxFileName = $"{calibratorTransformer.ToString()}.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + SaveOnnxModel(onnxModel, onnxModelPath, null); + + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(data); + var onnxResult = onnxTransformer.Transform(data); + CompareSelectedR4ScalarColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities + } + Done(); + } + private class DataPoint { [VectorType(3)]