diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 913db98559..36d6ab45dc 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -43,6 +43,12 @@ internal abstract class OnnxContext /// Whether the column is mapped in this context public abstract bool ContainsColumn(string colName); + /// + /// Check the required OpSet version satisfies our requirement + /// + /// + public abstract void CheckOpSetVersion(int thisTransformerMinumumOpSetVersion, string registerTransformerName); + /// /// Stops tracking a column. /// diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 9265e60424..d1948c0cb4 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -1745,6 +1745,9 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu _host.CheckValue(scoreProbablityColumnNames, nameof(scoreProbablityColumnNames)); _host.Check(Utils.Size(scoreProbablityColumnNames) == 2); + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, "PlattCalibrator"); + // The Affine operator is no longer supported in the v11 opset. // So we have to decompose it using Mul and Add string opType = "Mul"; diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs index ba70941461..e4712a97d8 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs @@ -902,6 +902,9 @@ public void SaveAsOnnx(OnnxContext ctx) Host.CheckValue(ctx, nameof(ctx)); Contracts.Assert(CanSaveOnnx(ctx)); + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + for (int iinfo = 0; iinfo < _columns.Length; ++iinfo) { var colInfo = _parent._columns[iinfo]; diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs index a8b9b8dc18..c79a9b5bbe 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs @@ -237,6 +237,9 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() public void SaveAsOnnx(OnnxContext ctx) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + var opType = "Identity"; foreach (var column in _columns) diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs index e65eba6734..359fad24b1 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs @@ -732,6 +732,9 @@ IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView ne public void SaveAsOnnx(OnnxContext ctx) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + var outputToInputMap = _mapper.OutputToInputMap; for(int i = 0; i < outputToInputMap.Length; i++) { diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 5bf272b50d..a5e7fe8876 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -1353,6 +1353,9 @@ private void AddMetaKeyValues(int i, DataViewSchema.Annotations.Builder builder) private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, string dstVariable) { + const int minimumOpSetVersion = 11; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + string castOutput; string isGreaterThanZeroOutput = ""; OnnxNode castNode; diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index 5ffc82b1a1..47705d021a 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -500,6 +500,9 @@ public override JToken SavePfa(BoundPfaContext ctx, JToken srcToken) public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + string opType; // Onnx expects the input keys to be int64s. But the input data can come from an ML.NET node that diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index cf04031325..f7a3fcdbc7 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -689,6 +689,9 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToke private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + var dim = info.TypeSrc.GetValueCount(); string opType = "Cast"; diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 0d3ee95241..ace25ae86a 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -822,6 +822,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColumnOptions info, stri Contracts.Assert(_parent.Columns[iinfo] == info); Contracts.Assert(CanSaveOnnx(ctx)); + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + int valueCount = info.InputType.GetValueCount(); if (valueCount == 0) return false; diff --git a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs index 0087e6546e..0dbef23e73 100644 --- a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs @@ -895,6 +895,9 @@ public void SaveAsOnnx(OnnxContext ctx) public bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + string opType; var slots = _slotDropper[iinfo].GetPreservedSlots(); // vector column is not suppressed diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 0ab4e0520a..4af5ea8344 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -497,6 +497,9 @@ public void SaveAsOnnx(OnnxContext ctx) private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + var opType = "Cast"; var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), ""); var t = _parent._columns[iinfo].OutputKind.ToInternalDataKind().ToType(); diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 297b475159..7947ddaaf4 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -812,6 +812,9 @@ private void CastInputToFloat(OnnxContext ctx, out OnnxNode node, out long[] private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + OnnxNode node; long[] termIds; string opType = "LabelEncoder"; diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 766e7eb3d1..b7f33239fe 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -3032,6 +3032,9 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, Host.CheckValue(ctx, nameof(ctx)); Host.Check(Utils.Size(outputNames) >= 1); + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, "TreeEnsembleModelParameters"); + //Nodes. var nodesTreeids = new List(); var nodesIds = new List(); diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 4cedfa1d0b..68383d8202 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -534,6 +534,9 @@ internal static FastTreeTweedieModelParameters Create(IHostEnvironment env, Mode bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + // Mapping score to prediction var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true); base.SaveAsOnnx(ctx, new[] { fastTreeOutput }, featureColumn); diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 2957a14937..5186f46813 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -213,6 +213,9 @@ internal static FastForestRegressionModelParameters Create(IHostEnvironment env, bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + // Mapping score to prediction var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true); var numTrees = ctx.AddInitializer((float)TrainedEnsemble.NumTrees, "NumTrees"); diff --git a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs index 2f5c220fe9..849caae820 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs @@ -316,6 +316,9 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string // v // L [l] <--- ArgMin <--- Y [l, k] + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + // Allocate C, which is a constant tensor in prediction phase var shapeC = new long[] { _centroids.Length, _centroids[0].Length }; var tensorC = new List(); diff --git a/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs b/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs index 751fb03b02..fdeffa1a8a 100644 --- a/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs +++ b/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs @@ -641,6 +641,9 @@ public void SaveAsOnnx(OnnxContext ctx) private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + var model = _parent._models[iinfo]; int dimension = _srcTypes[iinfo].GetValueCount(); Host.Assert(model.Length == dimension * dimension); diff --git a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs index 98866529f4..5e5bbd669d 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs @@ -14,6 +14,8 @@ namespace Microsoft.ML.Model.OnnxConverter /// internal sealed class OnnxContextImpl : OnnxContext { + private const int CurrentOpSetVersion = 12; + private const int MinimumOpSetVersion = 9; private readonly List _nodes; private readonly List _inputs; // The map from IDataView column names to variable names. @@ -32,9 +34,10 @@ internal sealed class OnnxContextImpl : OnnxContext private readonly string _producerVersion; private readonly long _modelVersion; private readonly OnnxVersion _onnxVersion; + private readonly int _opSetVersion; public OnnxContextImpl(IHostEnvironment env, string name, string producerName, - string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion) + string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion, int opSetVersion = CurrentOpSetVersion) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(OnnxContext)); @@ -55,6 +58,9 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName, _modelVersion = modelVersion; _domain = domain; _onnxVersion = onnxVersion; + _opSetVersion = opSetVersion <= CurrentOpSetVersion ? + opSetVersion >= MinimumOpSetVersion ? opSetVersion : throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is lower than the minimum required OpSet version {MinimumOpSetVersion}") : + throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is higher than the current most updated OpSet version {CurrentOpSetVersion}"); } public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName); @@ -127,6 +133,12 @@ public override string GetNodeName(string prefix) return GetUniqueName(prefix, _nodeNames.Contains); } + public override void CheckOpSetVersion(int thisTransformerMinumumOpSetVersion, string registerTransformerName) + { + if (_opSetVersion < thisTransformerMinumumOpSetVersion) + throw _host.ExceptParam(nameof(thisTransformerMinumumOpSetVersion), $"Requested OpSet version {_opSetVersion} is lower than {registerTransformerName}'s minimum OpSet version requirement: {thisTransformerMinumumOpSetVersion}"); + } + /// /// Adds a node to the node list of the graph. /// @@ -409,7 +421,7 @@ public override string AddInitializer(IEnumerable values, bool isUint64, /// Makes the ONNX model based on the context. /// public OnnxCSharpToProtoWrapper.ModelProto MakeModel() - => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues, _initializers); + => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _opSetVersion, _inputs, _outputs, _intermediateValues, _initializers); /// /// Return either "Experimental" or "Stable". The string "Experimental" indicates that some experimental features which are diff --git a/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs b/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs index d81d5e64ee..659829abe4 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs @@ -7,12 +7,23 @@ using Google.Protobuf; using Microsoft.ML.Data; using Microsoft.ML.Model.OnnxConverter; +using Microsoft.ML.Runtime; using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper; namespace Microsoft.ML { public static class OnnxExportExtensions { + private static ModelProto ConvertToOnnxProtobufCore(IHostEnvironment env, OnnxContextImpl ctx, ITransformer transform, IDataView inputData) + { + var outputData = transform.Transform(inputData); + LinkedList transforms = null; + using (var ch = env.Start("ONNX conversion")) + { + SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms); + return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null); + } + } /// /// Convert the specified to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object. @@ -26,13 +37,23 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat { var env = catalog.GetEnvironment(); var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable); - var outputData = transform.Transform(inputData); - LinkedList transforms = null; - using (var ch = env.Start("ONNX conversion")) - { - SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms); - return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null); - } + return ConvertToOnnxProtobufCore(env, ctx, transform, inputData); + } + + /// + /// Convert the specified to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object. + /// + /// The class that attached to. + /// The that will be converted into ONNX format. + /// The input of the specified transform. + /// The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12 + /// An ONNX model equivalent to the converted ML.NET model. + [BestFriend] + internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion) + { + var env = catalog.GetEnvironment(); + var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable, opSetVersion); + return ConvertToOnnxProtobufCore(env, ctx, transform, inputData); } /// @@ -45,5 +66,17 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat /// An ONNX model equivalent to the converted ML.NET model. public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, Stream stream) => ConvertToOnnxProtobuf(catalog, transform, inputData).WriteTo(stream); + + /// + /// Convert the specified to ONNX format and writes to a stream. + /// + /// The class that attached to. + /// The that will be converted into ONNX format. + /// The input of the specified transform. + /// The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12 + /// The stream to write the protobuf model to. + /// An ONNX model equivalent to the converted ML.NET model. + public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion, Stream stream) => + ConvertToOnnxProtobuf(catalog, transform, inputData, opSetVersion).WriteTo(stream); } } diff --git a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs index afdd20c7bf..af5529c2e0 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs @@ -285,7 +285,7 @@ public ModelArgs(string name, TensorProto.Types.DataType dataType, List di } public static ModelProto MakeModel(List nodes, string producerName, string name, - string domain, string producerVersion, long modelVersion, List inputs, + string domain, string producerVersion, long modelVersion, int opSetVersion, List inputs, List outputs, List intermediateValues, List initializers) { Contracts.CheckValue(nodes, nameof(nodes)); @@ -305,7 +305,7 @@ public static ModelProto MakeModel(List nodes, string producerName, s model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion; model.ModelVersion = modelVersion; model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 2 }); - model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 11 }); + model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = opSetVersion }); model.Graph = new GraphProto(); var graph = model.Graph; graph.Node.Add(nodes); diff --git a/src/Microsoft.ML.PCA/PcaTransformer.cs b/src/Microsoft.ML.PCA/PcaTransformer.cs index 55774feff1..c90e16ee48 100644 --- a/src/Microsoft.ML.PCA/PcaTransformer.cs +++ b/src/Microsoft.ML.PCA/PcaTransformer.cs @@ -625,6 +625,9 @@ private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, { Host.CheckValue(ctx, nameof(ctx)); + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + TransformInfo transformInfo = _parent._transformInfos[iinfo]; // When the transformer is loaded from a model file, diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs index 8f1745b2a0..e59096bc88 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs @@ -137,6 +137,10 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputs, str { Host.CheckValue(ctx, nameof(ctx)); Host.Check(Utils.Size(outputs) >= 1); + + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, "LinearModel"); + string opType = "LinearRegressor"; string scoreVarName = (Utils.Size(outputs) >= 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType)); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 2c141c7520..2cf18dc757 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -978,6 +978,10 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, JToken input) private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); + + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, "MultiClassLogisticRegression"); + Host.Assert(outputs[0] == DefaultColumnNames.PredictedLabel); Host.Assert(outputs[1] == DefaultColumnNames.Score); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs index 963d612ec7..1dfef5ad5d 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs @@ -421,6 +421,9 @@ ValueMapper IValueMapper.GetMapper() /// bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, "MulticlassNaiveBayes"); + float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length]; float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length]; diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 7161a953a4..bfbcde8086 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -660,6 +660,8 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false); string opType = "Concat"; @@ -790,6 +792,8 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { Contracts.Assert(outputNames.Length >= 2); + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); string opType; var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, true); @@ -910,6 +914,8 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { Contracts.Assert(outputNames.Length >= 2); + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs index b3af360868..d7ab83d94d 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs @@ -405,6 +405,9 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string lab Host.CheckValue(ctx, nameof(ctx)); Host.Check(Utils.Size(outputs) >= 3); + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + string scoreVarName = outputs[1]; string probVarName = outputs[2]; var prob = ctx.AddInitializer(_prob, "probability"); diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs index 10f55de931..9d94ba334c 100644 --- a/src/Microsoft.ML.Transforms/GcnTransform.cs +++ b/src/Microsoft.ML.Transforms/GcnTransform.cs @@ -627,6 +627,9 @@ public void SaveAsOnnx(OnnxContext ctx) private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + string opType; if ((_norms[iinfo] != LpNormNormalizingEstimatorBase.NormFunction.StandardDeviation) && (_ensureZeroMeans[iinfo] == false)) diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index 54cec41759..a0dc85bc97 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -462,6 +462,9 @@ public void SaveAsOnnx(OnnxContext ctx) private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoadName); + var inputType = _infos[iinfo].InputType; Type rawType = (inputType is VectorDataViewType vectorType) ? vectorType.ItemType.RawType : inputType.RawType; diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 76578f4f2d..585eb1dc14 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -867,6 +867,9 @@ public void SaveAsOnnx(OnnxContext ctx) private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoadName); + Type rawType; var type = _infos[iinfo].TypeSrc; if (type is VectorDataViewType vectorType) diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index 98b100c15f..9abb68a365 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -538,6 +538,9 @@ public void SaveAsOnnx(OnnxContext ctx) private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, DataViewType columnType) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + Type type = columnType.RawType; int size; diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index 6540deebd3..d3703d165a 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -770,6 +770,9 @@ public void SaveAsOnnx(OnnxContext ctx) private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + var transformInfo = _parent._transformInfos[iinfo]; // TfIdfVectorizer accepts strings, int32 and int64 tensors. diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs index 0185d79951..80d8f290f5 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs @@ -233,6 +233,9 @@ public void SaveAsOnnx(OnnxContext ctx) private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 10; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + // StringNormalizer only takes input of shapes [C] or [1,C], // so the input is squeezed to support inferred shapes ( e.g. [-1,C] ). var opType = "Squeeze"; diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs index 027cfb7f0b..61a42220f2 100644 --- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs +++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs @@ -231,6 +231,9 @@ public void SaveAsOnnx(OnnxContext ctx) private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + string opType = "Tokenizer"; DataViewType dataViewType; if (_isSourceVector[iinfo]) diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index 6333e637f7..61dd71f59c 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -414,6 +414,9 @@ private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV // | // P [j * 3] + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + long[] axes = new long[] { 0 }; // Allocate D, a constant tensor representing word embedding weights. var shapeD = new long[] { _parent._currentVocab.GetNumWords() + 3, _parent._currentVocab.Dimension }; diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt index 4358fbaad9..a47b5da8ad 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt @@ -932,7 +932,7 @@ "version": "2" }, { - "version": "11" + "version": "12" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt index 1c8cc52265..6ddba6bb1b 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt @@ -580,7 +580,7 @@ "version": "2" }, { - "version": "11" + "version": "12" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt index 8806bc45e3..50b0aa349e 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt @@ -324,7 +324,7 @@ "version": "2" }, { - "version": "11" + "version": "12" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt index 9091a61338..ba18e15d79 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt @@ -1181,7 +1181,7 @@ "version": "2" }, { - "version": "11" + "version": "12" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt index a1324ac03d..1cafc8ca1a 100644 --- a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt +++ b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt @@ -502,7 +502,7 @@ "version": "2" }, { - "version": "11" + "version": "12" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt index 50a604c84c..1c66067502 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt @@ -761,7 +761,7 @@ "version": "2" }, { - "version": "11" + "version": "12" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt index 8806bc45e3..50b0aa349e 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt @@ -324,7 +324,7 @@ "version": "2" }, { - "version": "11" + "version": "12" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt b/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt index 917bcc9862..0c3bc4d075 100644 --- a/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt +++ b/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt @@ -193,7 +193,7 @@ "version": "2" }, { - "version": "11" + "version": "12" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt b/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt index 76ca512604..b705c2e885 100644 --- a/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt +++ b/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt @@ -1152,7 +1152,7 @@ "version": "2" }, { - "version": "11" + "version": "12" } ] } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index b46e7ba3fa..67ceb95c24 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -905,8 +905,12 @@ public void OnnxTypeConversionTest(DataKind fromKind, DataKind toKind) Done(); } - [Fact] - public void PcaOnnxConversionTest() + [Theory] + [InlineData(9)] + [InlineData(10)] + [InlineData(11)] + [InlineData(12)] + public void PcaOnnxConversionTest(int customOpSetVersion) { var dataSource = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); @@ -920,11 +924,24 @@ public void PcaOnnxConversionTest() foreach (var zeroMean in zeroMeans) { var pipeline = ML.Transforms.ProjectToPrincipalComponents("pca", "features", rank: 5, seed: 1, ensureZeroMean: zeroMean); + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView, customOpSetVersion); + var onnxFileName = "pca.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); - TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("pca") }); - } + SaveOnnxModel(onnxModel, onnxModelPath, null); + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedColumns("pca", "pca", transformedData, onnxResult); + } + } Done(); } @@ -1689,6 +1706,70 @@ public void NonDefaultColNamesMultiClassificationOnnxConversionTest() } Done(); } + + [Fact] + public void OneHotHashEncodingOnnxConversionWithCustomOpSetVersionTest() + { + var mlContext = new MLContext(); + string dataPath = GetDataPath("breast-cancer.txt"); + + var dataView = ML.Data.LoadFromTextFile(dataPath); + var pipe = ML.Transforms.Categorical.OneHotHashEncoding(new[]{ + new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing:false), + }); + var model = pipe.Fit(dataView); + var transformedData = model.Transform(dataView); + + try + { + var onnxModelPath = GetOutputPath("onnxmodel_custom_opset_version_test.onnx"); + using (FileStream stream = new FileStream(onnxModelPath, FileMode.Create)) + mlContext.Model.ConvertToOnnx(model, dataView, 9, stream); + Assert.True(false); + } + catch (System.Exception ex) + { + Assert.Contains("Requested OpSet version 9 is lower than HashTransform's minimum OpSet version requirement: 11", ex.Message); + return; + } + + try + { + var onnxModelPath = GetOutputPath("onnxmodel_custom_opset_version_test.onnx"); + using (FileStream stream = new FileStream(onnxModelPath, FileMode.Create)) + mlContext.Model.ConvertToOnnx(model, dataView, 13, stream); + Assert.True(false); + } + catch (System.Exception ex) + { + Assert.Contains("Requested OpSet version 13 is higher than the current most updated OpSet version 12", ex.Message); + return; + } + + try + { + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView, 9); + Assert.True(false); + } + catch (System.Exception ex) + { + Assert.Contains("Requested OpSet version 9 is lower than HashTransform's minimum OpSet version requirement: 11", ex.Message); + return; + } + + try + { + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView, 13); + Assert.True(false); + } + catch (System.Exception ex) + { + Assert.Contains("Requested OpSet version 13 is higher than the current most updated OpSet version 12", ex.Message); + return; + } + + Done(); + } [Theory] [CombinatorialData]