diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 31933dd680..d86aaa320e 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -370,6 +370,7 @@ public sealed class BinaryPredictionTransformer : SingleFeaturePredictio { internal readonly string ThresholdColumn; internal readonly float Threshold; + internal readonly string LabelColumnName; [BestFriend] internal BinaryPredictionTransformer(IHostEnvironment env, TModel model, DataViewSchema inputSchema, string featureColumn, @@ -383,6 +384,17 @@ internal BinaryPredictionTransformer(IHostEnvironment env, TModel model, DataVie SetScorer(); } + internal BinaryPredictionTransformer(IHostEnvironment env, TModel model, DataViewSchema inputSchema, string featureColumn, string labelColumn, + float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), model, inputSchema, featureColumn) + { + Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn)); + Threshold = threshold; + ThresholdColumn = thresholdColumn; + LabelColumnName = labelColumn; + + SetScorer(); + } internal BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), ctx) { @@ -409,7 +421,7 @@ private void InitializationLogic(ModelLoadContext ctx, out float threshold, out private void SetScorer() { - var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumnName); + var schema = new RoleMappedSchema(TrainSchema, LabelColumnName, FeatureColumnName); var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index b7f6c4da87..7419c13d42 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -426,9 +426,16 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema var mapper = ValueMapper as ISingleCanSaveOnnx; Contracts.CheckValue(mapper, nameof(mapper)); - Contracts.Assert(schema.Feature.HasValue); Contracts.Assert(Utils.Size(outputNames) == 3); // Predicted Label, Score and Probablity. + // Prior doesn't have a feature column and uses the training label column to determine predicted labels + if (!schema.Feature.HasValue) + { + Contracts.Assert(schema.Label.HasValue); + var labelColumnName = schema.Label.Value.Name; + return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(labelColumnName)); + } + var featName = schema.Feature.Value.Name; if (!ctx.ContainsColumn(featName)) return false; @@ -511,7 +518,7 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto public IEnumerable> GetInputColumnRoles() { - yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature?.Name); + yield return (InputRoleMappedSchema.Feature.HasValue) ? RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature?.Name) : RoleMappedSchema.ColumnRole.Label.Bind(InputRoleMappedSchema.Label?.Name); } private Delegate[] CreateGetters(DataViewRow input, bool[] active) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs index 9f90651436..6fd462528c 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs @@ -8,6 +8,7 @@ using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Trainers; @@ -240,9 +241,9 @@ internal PriorTrainer(IHostEnvironment env, String labelColumn, String weightCol /// public BinaryPredictionTransformer Fit(IDataView input) { - RoleMappedData trainRoles = new RoleMappedData(input, feature: null, label: _labelColumnName, weight: _weightColumnName); + RoleMappedData trainRoles = new RoleMappedData(input, label: _labelColumnName, feature: null, weight: _weightColumnName); var pred = ((ITrainer)this).Train(new TrainContext(trainRoles)); - return new BinaryPredictionTransformer(_host, pred, input.Schema, featureColumn: null); + return new BinaryPredictionTransformer(_host, pred, input.Schema, featureColumn: null, labelColumn: _labelColumnName); } private PriorModelParameters Train(TrainContext context) @@ -330,7 +331,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) public sealed class PriorModelParameters : ModelParametersBase, IDistPredictorProducing, - IValueMapperDist + IValueMapperDist, ISingleCanSaveOnnx { internal const string LoaderSignature = "PriorPredictor"; private static VersionInfo GetVersionInfo() @@ -346,6 +347,7 @@ private static VersionInfo GetVersionInfo() private readonly float _prob; private readonly float _raw; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; /// /// Instantiates a model that returns the prior probability of the positive class in the training set. @@ -397,6 +399,38 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write(_prob); } + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string labelColumn) + { + Host.CheckValue(ctx, nameof(ctx)); + Host.Check(Utils.Size(outputs) >= 3); + + string scoreVarName = outputs[1]; + string probVarName = outputs[2]; + var prob = ctx.AddInitializer(_prob, "probability"); + var score = ctx.AddInitializer(_raw, "score"); + + var xorOutput = ctx.AddIntermediateVariable(null, "XorOutput", true); + string opType = "Xor"; + ctx.CreateNode(opType, new[] { labelColumn, labelColumn }, new[] { xorOutput }, ctx.GetNodeName(opType), ""); + + var notOutput = ctx.AddIntermediateVariable(null, "NotOutput", true); + opType = "Not"; + ctx.CreateNode(opType, xorOutput, notOutput, ctx.GetNodeName(opType), ""); + + var castOutput = ctx.AddIntermediateVariable(null, "CastOutput", true); + opType = "Cast"; + var node = ctx.CreateNode(opType, notOutput, castOutput, ctx.GetNodeName(opType), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); + node.AddAttribute("to", t); + + opType = "Mul"; + ctx.CreateNode(opType, new[] { castOutput, prob }, new[] { probVarName }, ctx.GetNodeName(opType), ""); + + opType = "Mul"; + ctx.CreateNode(opType, new[] { castOutput, score }, new[] { scoreVarName }, ctx.GetNodeName(opType), ""); + return true; + } + private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification; private readonly DataViewType _inputType; diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 0b1c32d2ae..12688e1392 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -270,6 +270,7 @@ public void BinaryClassificationTrainersOnnxConversionTest() mlContext.BinaryClassification.Trainers.FastTree(), mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), mlContext.BinaryClassification.Trainers.LinearSvm(), + mlContext.BinaryClassification.Trainers.Prior(), mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(), mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(), mlContext.BinaryClassification.Trainers.SgdCalibrated(), @@ -301,8 +302,8 @@ public void BinaryClassificationTrainersOnnxConversionTest() var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); - CompareSelectedScalarColumns(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); + CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores + CompareSelectedScalarColumns(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels } } Done();