diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 2bfa94ffd8..4ddf832782 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -187,9 +187,6 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx) Host.Assert(Bindable is IBindableCanSaveOnnx); Host.Assert(Bindings.InfoCount >= 2); - if (!ctx.ContainsColumn(DefaultColumnNames.Features)) - return; - base.SaveAsOnnxCore(ctx); int delta = Bindings.DerivedColumnCount; diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 9f2482961f..c14637e745 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1959,6 +1959,144 @@ public void SelectColumnsOnnxTest() Done(); } + private class BreastCancerMulticlassExampleNonDefaultColNames + { + [LoadColumn(1)] + public string Label; + + [LoadColumn(2, 9), VectorType(8)] + public float[] MyFeatureVector; + } + + private class BreastCancerBinaryClassificationNonDefaultColNames + { + [LoadColumn(0)] + public bool Label; + + [LoadColumn(2, 9), VectorType(8)] + public float[] MyFeatureVector; + } + + [Fact] + public void NonDefaultColNamesBinaryClassificationOnnxConversionTest() + { + var mlContext = new MLContext(seed: 1); + string dataPath = GetDataPath("breast-cancer.txt"); + // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). + var dataView = mlContext.Data.LoadFromTextFile(dataPath, separatorChar: '\t', hasHeader: true); + List> estimators = new List>() + { + mlContext.BinaryClassification.Trainers.AveragedPerceptron("Label", "MyFeatureVector"), + mlContext.BinaryClassification.Trainers.FastForest("Label", "MyFeatureVector"), + mlContext.BinaryClassification.Trainers.FastTree("Label", "MyFeatureVector"), + mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression("Label", "MyFeatureVector"), + mlContext.BinaryClassification.Trainers.LinearSvm("Label", "MyFeatureVector"), + mlContext.BinaryClassification.Trainers.Prior(), + mlContext.BinaryClassification.Trainers.SdcaLogisticRegression("Label", "MyFeatureVector"), + mlContext.BinaryClassification.Trainers.SdcaNonCalibrated("Label", "MyFeatureVector"), + mlContext.BinaryClassification.Trainers.SgdCalibrated("Label", "MyFeatureVector"), + mlContext.BinaryClassification.Trainers.SgdNonCalibrated("Label", "MyFeatureVector"), + mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression("Label", "MyFeatureVector"), + }; + if (Environment.Is64BitProcess) + { + estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm("Label", "MyFeatureVector")); + } + + var initialPipeline = mlContext.Transforms.ReplaceMissingValues("MyFeatureVector"). + Append(mlContext.Transforms.NormalizeMinMax("MyFeatureVector")); + foreach (var estimator in estimators) + { + var pipeline = initialPipeline.Append(estimator); + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + + var onnxFileName = $"{estimator.ToString()}.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + SaveOnnxModel(onnxModel, onnxModelPath, null); + + // Compare model scores produced by ML.NET and ONNX's runtime. + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedColumns("Score", "Score", transformedData, onnxResult, 3); //compare scores + CompareSelectedColumns("PredictedLabel", "PredictedLabel", transformedData, onnxResult); //compare predicted labels + } + } + Done(); + } + + [Fact] + public void NonDefaultColNamesMultiClassificationOnnxConversionTest() + { + var mlContext = new MLContext(seed: 1); + + string dataPath = GetDataPath("breast-cancer.txt"); + var dataView = mlContext.Data.LoadFromTextFile(dataPath, separatorChar: '\t', hasHeader: true); + + List> estimators = new List>() + { + mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy("Label", "MyFeatureVector"), + mlContext.MulticlassClassification.Trainers.NaiveBayes("Label", "MyFeatureVector"), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.AveragedPerceptron("Label", "MyFeatureVector")), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.AveragedPerceptron("Label", "MyFeatureVector"), useProbabilities:false), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression("Label", "MyFeatureVector")), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression("Label", "MyFeatureVector"), useProbabilities:false), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.LinearSvm("Label", "MyFeatureVector")), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.LinearSvm("Label", "MyFeatureVector"), useProbabilities:false), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.FastForest("Label", "MyFeatureVector")), + mlContext.MulticlassClassification.Trainers.OneVersusAll( + mlContext.BinaryClassification.Trainers.FastForest("Label", "MyFeatureVector"), useProbabilities:false), + mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy("Label", "MyFeatureVector"), + mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated("Label", "MyFeatureVector") + }; + + if (Environment.Is64BitProcess) + { + estimators.Add(mlContext.MulticlassClassification.Trainers.LightGbm("Label", "MyFeatureVector")); + } + + var initialPipeline = mlContext.Transforms.ReplaceMissingValues("MyFeatureVector") + .Append(mlContext.Transforms.NormalizeMinMax("MyFeatureVector")) + .Append(mlContext.Transforms.Conversion.MapValueToKey("Label")); + + foreach (var estimator in estimators) + { + var pipeline = initialPipeline.Append(estimator); + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + var onnxFileName = $"{estimator.ToString()}.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + + SaveOnnxModel(onnxModel, onnxModelPath, null); + + // Compare results produced by ML.NET and ONNX's runtime. + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedColumns("PredictedLabel", "PredictedLabel", transformedData, onnxResult); + CompareSelectedColumns("Score", "Score", transformedData, onnxResult, 4); + } + } + Done(); + } + private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6) { var leftColumn = left.Schema[leftColumnName];