diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
index 913db98559..36d6ab45dc 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
+++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
@@ -43,6 +43,12 @@ internal abstract class OnnxContext
/// Whether the column is mapped in this context
public abstract bool ContainsColumn(string colName);
+ ///
+ /// Check the required OpSet version satisfies our requirement
+ ///
+ ///
+ public abstract void CheckOpSetVersion(int thisTransformerMinumumOpSetVersion, string registerTransformerName);
+
///
/// Stops tracking a column.
///
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index 9265e60424..d1948c0cb4 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -1745,6 +1745,9 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu
_host.CheckValue(scoreProbablityColumnNames, nameof(scoreProbablityColumnNames));
_host.Check(Utils.Size(scoreProbablityColumnNames) == 2);
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, "PlattCalibrator");
+
// The Affine operator is no longer supported in the v11 opset.
// So we have to decompose it using Mul and Add
string opType = "Mul";
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs
index ba70941461..e4712a97d8 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs
@@ -902,6 +902,9 @@ public void SaveAsOnnx(OnnxContext ctx)
Host.CheckValue(ctx, nameof(ctx));
Contracts.Assert(CanSaveOnnx(ctx));
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
for (int iinfo = 0; iinfo < _columns.Length; ++iinfo)
{
var colInfo = _parent._columns[iinfo];
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
index a8b9b8dc18..c79a9b5bbe 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
@@ -237,6 +237,9 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
public void SaveAsOnnx(OnnxContext ctx)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
var opType = "Identity";
foreach (var column in _columns)
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
index e65eba6734..359fad24b1 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
@@ -732,6 +732,9 @@ IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView ne
public void SaveAsOnnx(OnnxContext ctx)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
var outputToInputMap = _mapper.OutputToInputMap;
for(int i = 0; i < outputToInputMap.Length; i++)
{
diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs
index 5bf272b50d..a5e7fe8876 100644
--- a/src/Microsoft.ML.Data/Transforms/Hashing.cs
+++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs
@@ -1353,6 +1353,9 @@ private void AddMetaKeyValues(int i, DataViewSchema.Annotations.Builder builder)
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, string dstVariable)
{
+ const int minimumOpSetVersion = 11;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
string castOutput;
string isGreaterThanZeroOutput = "";
OnnxNode castNode;
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
index 5ffc82b1a1..47705d021a 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
@@ -500,6 +500,9 @@ public override JToken SavePfa(BoundPfaContext ctx, JToken srcToken)
public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
string opType;
// Onnx expects the input keys to be int64s. But the input data can come from an ML.NET node that
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs
index cf04031325..f7a3fcdbc7 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs
@@ -689,6 +689,9 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToke
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
var dim = info.TypeSrc.GetValueCount();
string opType = "Cast";
diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs
index 0d3ee95241..ace25ae86a 100644
--- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs
+++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs
@@ -822,6 +822,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColumnOptions info, stri
Contracts.Assert(_parent.Columns[iinfo] == info);
Contracts.Assert(CanSaveOnnx(ctx));
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
int valueCount = info.InputType.GetValueCount();
if (valueCount == 0)
return false;
diff --git a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs
index 0087e6546e..0dbef23e73 100644
--- a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs
@@ -895,6 +895,9 @@ public void SaveAsOnnx(OnnxContext ctx)
public bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
string opType;
var slots = _slotDropper[iinfo].GetPreservedSlots();
// vector column is not suppressed
diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
index 0ab4e0520a..4af5ea8344 100644
--- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
+++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
@@ -497,6 +497,9 @@ public void SaveAsOnnx(OnnxContext ctx)
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
var opType = "Cast";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
var t = _parent._columns[iinfo].OutputKind.ToInternalDataKind().ToType();
diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
index 297b475159..7947ddaaf4 100644
--- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
@@ -812,6 +812,9 @@ private void CastInputToFloat(OnnxContext ctx, out OnnxNode node, out long[]
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
OnnxNode node;
long[] termIds;
string opType = "LabelEncoder";
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index 766e7eb3d1..b7f33239fe 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -3032,6 +3032,9 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames,
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputNames) >= 1);
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, "TreeEnsembleModelParameters");
+
//Nodes.
var nodesTreeids = new List();
var nodesIds = new List();
diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
index 4cedfa1d0b..68383d8202 100644
--- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
@@ -534,6 +534,9 @@ internal static FastTreeTweedieModelParameters Create(IHostEnvironment env, Mode
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
// Mapping score to prediction
var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
base.SaveAsOnnx(ctx, new[] { fastTreeOutput }, featureColumn);
diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
index 2957a14937..5186f46813 100644
--- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
@@ -213,6 +213,9 @@ internal static FastForestRegressionModelParameters Create(IHostEnvironment env,
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
// Mapping score to prediction
var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
var numTrees = ctx.AddInitializer((float)TrainedEnsemble.NumTrees, "NumTrees");
diff --git a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
index 2f5c220fe9..849caae820 100644
--- a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
+++ b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
@@ -316,6 +316,9 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
// v
// L [l] <--- ArgMin <--- Y [l, k]
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
// Allocate C, which is a constant tensor in prediction phase
var shapeC = new long[] { _centroids.Length, _centroids[0].Length };
var tensorC = new List();
diff --git a/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs b/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs
index 751fb03b02..fdeffa1a8a 100644
--- a/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs
+++ b/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs
@@ -641,6 +641,9 @@ public void SaveAsOnnx(OnnxContext ctx)
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
var model = _parent._models[iinfo];
int dimension = _srcTypes[iinfo].GetValueCount();
Host.Assert(model.Length == dimension * dimension);
diff --git a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
index 98866529f4..5e5bbd669d 100644
--- a/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
+++ b/src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
@@ -14,6 +14,8 @@ namespace Microsoft.ML.Model.OnnxConverter
///
internal sealed class OnnxContextImpl : OnnxContext
{
+ private const int CurrentOpSetVersion = 12;
+ private const int MinimumOpSetVersion = 9;
private readonly List _nodes;
private readonly List _inputs;
// The map from IDataView column names to variable names.
@@ -32,9 +34,10 @@ internal sealed class OnnxContextImpl : OnnxContext
private readonly string _producerVersion;
private readonly long _modelVersion;
private readonly OnnxVersion _onnxVersion;
+ private readonly int _opSetVersion;
public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
- string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion)
+ string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion, int opSetVersion = CurrentOpSetVersion)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(OnnxContext));
@@ -55,6 +58,9 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
_modelVersion = modelVersion;
_domain = domain;
_onnxVersion = onnxVersion;
+ _opSetVersion = opSetVersion <= CurrentOpSetVersion ?
+ opSetVersion >= MinimumOpSetVersion ? opSetVersion : throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is lower than the minimum required OpSet version {MinimumOpSetVersion}") :
+ throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is higher than the current most updated OpSet version {CurrentOpSetVersion}");
}
public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);
@@ -127,6 +133,12 @@ public override string GetNodeName(string prefix)
return GetUniqueName(prefix, _nodeNames.Contains);
}
+ public override void CheckOpSetVersion(int thisTransformerMinumumOpSetVersion, string registerTransformerName)
+ {
+ if (_opSetVersion < thisTransformerMinumumOpSetVersion)
+ throw _host.ExceptParam(nameof(thisTransformerMinumumOpSetVersion), $"Requested OpSet version {_opSetVersion} is lower than {registerTransformerName}'s minimum OpSet version requirement: {thisTransformerMinumumOpSetVersion}");
+ }
+
///
/// Adds a node to the node list of the graph.
///
@@ -409,7 +421,7 @@ public override string AddInitializer(IEnumerable values, bool isUint64,
/// Makes the ONNX model based on the context.
///
public OnnxCSharpToProtoWrapper.ModelProto MakeModel()
- => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues, _initializers);
+ => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _opSetVersion, _inputs, _outputs, _intermediateValues, _initializers);
///
/// Return either "Experimental" or "Stable". The string "Experimental" indicates that some experimental features which are
diff --git a/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs b/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs
index d81d5e64ee..659829abe4 100644
--- a/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs
+++ b/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs
@@ -7,12 +7,23 @@
using Google.Protobuf;
using Microsoft.ML.Data;
using Microsoft.ML.Model.OnnxConverter;
+using Microsoft.ML.Runtime;
using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper;
namespace Microsoft.ML
{
public static class OnnxExportExtensions
{
+ private static ModelProto ConvertToOnnxProtobufCore(IHostEnvironment env, OnnxContextImpl ctx, ITransformer transform, IDataView inputData)
+ {
+ var outputData = transform.Transform(inputData);
+ LinkedList transforms = null;
+ using (var ch = env.Start("ONNX conversion"))
+ {
+ SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms);
+ return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null);
+ }
+ }
///
/// Convert the specified to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
@@ -26,13 +37,23 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat
{
var env = catalog.GetEnvironment();
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable);
- var outputData = transform.Transform(inputData);
- LinkedList transforms = null;
- using (var ch = env.Start("ONNX conversion"))
- {
- SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms);
- return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null);
- }
+ return ConvertToOnnxProtobufCore(env, ctx, transform, inputData);
+ }
+
+ ///
+ /// Convert the specified to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
+ ///
+ /// The class that attached to.
+ /// The that will be converted into ONNX format.
+ /// The input of the specified transform.
+ /// The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12
+ /// An ONNX model equivalent to the converted ML.NET model.
+ [BestFriend]
+ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion)
+ {
+ var env = catalog.GetEnvironment();
+ var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable, opSetVersion);
+ return ConvertToOnnxProtobufCore(env, ctx, transform, inputData);
}
///
@@ -45,5 +66,17 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat
/// An ONNX model equivalent to the converted ML.NET model.
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, Stream stream) =>
ConvertToOnnxProtobuf(catalog, transform, inputData).WriteTo(stream);
+
+ ///
+ /// Convert the specified to ONNX format and writes to a stream.
+ ///
+ /// The class that attached to.
+ /// The that will be converted into ONNX format.
+ /// The input of the specified transform.
+ /// The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12
+ /// The stream to write the protobuf model to.
+ /// An ONNX model equivalent to the converted ML.NET model.
+ public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion, Stream stream) =>
+ ConvertToOnnxProtobuf(catalog, transform, inputData, opSetVersion).WriteTo(stream);
}
}
diff --git a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
index afdd20c7bf..af5529c2e0 100644
--- a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
+++ b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
@@ -285,7 +285,7 @@ public ModelArgs(string name, TensorProto.Types.DataType dataType, List di
}
public static ModelProto MakeModel(List nodes, string producerName, string name,
- string domain, string producerVersion, long modelVersion, List inputs,
+ string domain, string producerVersion, long modelVersion, int opSetVersion, List inputs,
List outputs, List intermediateValues, List initializers)
{
Contracts.CheckValue(nodes, nameof(nodes));
@@ -305,7 +305,7 @@ public static ModelProto MakeModel(List nodes, string producerName, s
model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion;
model.ModelVersion = modelVersion;
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 2 });
- model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 11 });
+ model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = opSetVersion });
model.Graph = new GraphProto();
var graph = model.Graph;
graph.Node.Add(nodes);
diff --git a/src/Microsoft.ML.PCA/PcaTransformer.cs b/src/Microsoft.ML.PCA/PcaTransformer.cs
index 55774feff1..c90e16ee48 100644
--- a/src/Microsoft.ML.PCA/PcaTransformer.cs
+++ b/src/Microsoft.ML.PCA/PcaTransformer.cs
@@ -625,6 +625,9 @@ private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName,
{
Host.CheckValue(ctx, nameof(ctx));
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
TransformInfo transformInfo = _parent._transformInfos[iinfo];
// When the transformer is loaded from a model file,
diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs
index 8f1745b2a0..e59096bc88 100644
--- a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs
+++ b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs
@@ -137,6 +137,10 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputs, str
{
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputs) >= 1);
+
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, "LinearModel");
+
string opType = "LinearRegressor";
string scoreVarName = (Utils.Size(outputs) >= 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns
var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType));
diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs
index 2c141c7520..2cf18dc757 100644
--- a/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs
+++ b/src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs
@@ -978,6 +978,10 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, JToken input)
private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureColumn)
{
Host.CheckValue(ctx, nameof(ctx));
+
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, "MultiClassLogisticRegression");
+
Host.Assert(outputs[0] == DefaultColumnNames.PredictedLabel);
Host.Assert(outputs[1] == DefaultColumnNames.Score);
diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
index 963d612ec7..1dfef5ad5d 100644
--- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs
@@ -421,6 +421,9 @@ ValueMapper IValueMapper.GetMapper()
///
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, "MulticlassNaiveBayes");
+
float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length];
float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];
diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs
index 7161a953a4..bfbcde8086 100644
--- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs
+++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs
@@ -660,6 +660,8 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false);
string opType = "Concat";
@@ -790,6 +792,8 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
Contracts.Assert(outputNames.Length >= 2);
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
string opType;
var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, true);
@@ -910,6 +914,8 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
Contracts.Assert(outputNames.Length >= 2);
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false);
diff --git a/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs
index b3af360868..d7ab83d94d 100644
--- a/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs
+++ b/src/Microsoft.ML.StandardTrainers/Standard/Simple/SimpleTrainers.cs
@@ -405,6 +405,9 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string lab
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputs) >= 3);
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
string scoreVarName = outputs[1];
string probVarName = outputs[2];
var prob = ctx.AddInitializer(_prob, "probability");
diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs
index 10f55de931..9d94ba334c 100644
--- a/src/Microsoft.ML.Transforms/GcnTransform.cs
+++ b/src/Microsoft.ML.Transforms/GcnTransform.cs
@@ -627,6 +627,9 @@ public void SaveAsOnnx(OnnxContext ctx)
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
string opType;
if ((_norms[iinfo] != LpNormNormalizingEstimatorBase.NormFunction.StandardDeviation) && (_ensureZeroMeans[iinfo] == false))
diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
index 54cec41759..a0dc85bc97 100644
--- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
@@ -462,6 +462,9 @@ public void SaveAsOnnx(OnnxContext ctx)
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoadName);
+
var inputType = _infos[iinfo].InputType;
Type rawType = (inputType is VectorDataViewType vectorType) ? vectorType.ItemType.RawType : inputType.RawType;
diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs
index 76578f4f2d..585eb1dc14 100644
--- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs
@@ -867,6 +867,9 @@ public void SaveAsOnnx(OnnxContext ctx)
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoadName);
+
Type rawType;
var type = _infos[iinfo].TypeSrc;
if (type is VectorDataViewType vectorType)
diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs
index 98b100c15f..9abb68a365 100644
--- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs
+++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs
@@ -538,6 +538,9 @@ public void SaveAsOnnx(OnnxContext ctx)
private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, DataViewType columnType)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
Type type = columnType.RawType;
int size;
diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
index 6540deebd3..d3703d165a 100644
--- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
@@ -770,6 +770,9 @@ public void SaveAsOnnx(OnnxContext ctx)
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
var transformInfo = _parent._transformInfos[iinfo];
// TfIdfVectorizer accepts strings, int32 and int64 tensors.
diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
index 0185d79951..80d8f290f5 100644
--- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
@@ -233,6 +233,9 @@ public void SaveAsOnnx(OnnxContext ctx)
private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 10;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
// StringNormalizer only takes input of shapes [C] or [1,C],
// so the input is squeezed to support inferred shapes ( e.g. [-1,C] ).
var opType = "Squeeze";
diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
index 027cfb7f0b..61a42220f2 100644
--- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
+++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
@@ -231,6 +231,9 @@ public void SaveAsOnnx(OnnxContext ctx)
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
string opType = "Tokenizer";
DataViewType dataViewType;
if (_isSourceVector[iinfo])
diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
index 6333e637f7..61dd71f59c 100644
--- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
@@ -414,6 +414,9 @@ private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV
// |
// P [j * 3]
+ const int minimumOpSetVersion = 9;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
+
long[] axes = new long[] { 0 };
// Allocate D, a constant tensor representing word embedding weights.
var shapeD = new long[] { _parent._currentVocab.GetNumWords() + 3, _parent._currentVocab.Dimension };
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt
index 4358fbaad9..a47b5da8ad 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt
@@ -932,7 +932,7 @@
"version": "2"
},
{
- "version": "11"
+ "version": "12"
}
]
}
\ No newline at end of file
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt
index 1c8cc52265..6ddba6bb1b 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt
@@ -580,7 +580,7 @@
"version": "2"
},
{
- "version": "11"
+ "version": "12"
}
]
}
\ No newline at end of file
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt
index 8806bc45e3..50b0aa349e 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt
@@ -324,7 +324,7 @@
"version": "2"
},
{
- "version": "11"
+ "version": "12"
}
]
}
\ No newline at end of file
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt
index 9091a61338..ba18e15d79 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt
@@ -1181,7 +1181,7 @@
"version": "2"
},
{
- "version": "11"
+ "version": "12"
}
]
}
\ No newline at end of file
diff --git a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt
index a1324ac03d..1cafc8ca1a 100644
--- a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt
+++ b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt
@@ -502,7 +502,7 @@
"version": "2"
},
{
- "version": "11"
+ "version": "12"
}
]
}
\ No newline at end of file
diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt
index 50a604c84c..1c66067502 100644
--- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt
+++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt
@@ -761,7 +761,7 @@
"version": "2"
},
{
- "version": "11"
+ "version": "12"
}
]
}
\ No newline at end of file
diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt
index 8806bc45e3..50b0aa349e 100644
--- a/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt
+++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt
@@ -324,7 +324,7 @@
"version": "2"
},
{
- "version": "11"
+ "version": "12"
}
]
}
\ No newline at end of file
diff --git a/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt b/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt
index 917bcc9862..0c3bc4d075 100644
--- a/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt
+++ b/test/BaselineOutput/Common/Onnx/Transforms/IndicateMissingValues.txt
@@ -193,7 +193,7 @@
"version": "2"
},
{
- "version": "11"
+ "version": "12"
}
]
}
\ No newline at end of file
diff --git a/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt b/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt
index 76ca512604..b705c2e885 100644
--- a/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt
+++ b/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt
@@ -1152,7 +1152,7 @@
"version": "2"
},
{
- "version": "11"
+ "version": "12"
}
]
}
\ No newline at end of file
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index b46e7ba3fa..67ceb95c24 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -905,8 +905,12 @@ public void OnnxTypeConversionTest(DataKind fromKind, DataKind toKind)
Done();
}
- [Fact]
- public void PcaOnnxConversionTest()
+ [Theory]
+ [InlineData(9)]
+ [InlineData(10)]
+ [InlineData(11)]
+ [InlineData(12)]
+ public void PcaOnnxConversionTest(int customOpSetVersion)
{
var dataSource = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
@@ -920,11 +924,24 @@ public void PcaOnnxConversionTest()
foreach (var zeroMean in zeroMeans)
{
var pipeline = ML.Transforms.ProjectToPrincipalComponents("pca", "features", rank: 5, seed: 1, ensureZeroMean: zeroMean);
+ var model = pipeline.Fit(dataView);
+ var transformedData = model.Transform(dataView);
+ var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView, customOpSetVersion);
+
var onnxFileName = "pca.onnx";
+ var onnxModelPath = GetOutputPath(onnxFileName);
- TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("pca") });
- }
+ SaveOnnxModel(onnxModel, onnxModelPath, null);
+ if (IsOnnxRuntimeSupported())
+ {
+ // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
+ var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
+ var onnxTransformer = onnxEstimator.Fit(dataView);
+ var onnxResult = onnxTransformer.Transform(dataView);
+ CompareSelectedColumns("pca", "pca", transformedData, onnxResult);
+ }
+ }
Done();
}
@@ -1689,6 +1706,70 @@ public void NonDefaultColNamesMultiClassificationOnnxConversionTest()
}
Done();
}
+
+ [Fact]
+ public void OneHotHashEncodingOnnxConversionWithCustomOpSetVersionTest()
+ {
+ var mlContext = new MLContext();
+ string dataPath = GetDataPath("breast-cancer.txt");
+
+ var dataView = ML.Data.LoadFromTextFile(dataPath);
+ var pipe = ML.Transforms.Categorical.OneHotHashEncoding(new[]{
+ new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing:false),
+ });
+ var model = pipe.Fit(dataView);
+ var transformedData = model.Transform(dataView);
+
+ try
+ {
+ var onnxModelPath = GetOutputPath("onnxmodel_custom_opset_version_test.onnx");
+ using (FileStream stream = new FileStream(onnxModelPath, FileMode.Create))
+ mlContext.Model.ConvertToOnnx(model, dataView, 9, stream);
+ Assert.True(false);
+ }
+ catch (System.Exception ex)
+ {
+ Assert.Contains("Requested OpSet version 9 is lower than HashTransform's minimum OpSet version requirement: 11", ex.Message);
+ return;
+ }
+
+ try
+ {
+ var onnxModelPath = GetOutputPath("onnxmodel_custom_opset_version_test.onnx");
+ using (FileStream stream = new FileStream(onnxModelPath, FileMode.Create))
+ mlContext.Model.ConvertToOnnx(model, dataView, 13, stream);
+ Assert.True(false);
+ }
+ catch (System.Exception ex)
+ {
+ Assert.Contains("Requested OpSet version 13 is higher than the current most updated OpSet version 12", ex.Message);
+ return;
+ }
+
+ try
+ {
+ var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView, 9);
+ Assert.True(false);
+ }
+ catch (System.Exception ex)
+ {
+ Assert.Contains("Requested OpSet version 9 is lower than HashTransform's minimum OpSet version requirement: 11", ex.Message);
+ return;
+ }
+
+ try
+ {
+ var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView, 13);
+ Assert.True(false);
+ }
+ catch (System.Exception ex)
+ {
+ Assert.Contains("Requested OpSet version 13 is higher than the current most updated OpSet version 12", ex.Message);
+ return;
+ }
+
+ Done();
+ }
[Theory]
[CombinatorialData]