Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ internal abstract class OnnxContext
/// <returns>Whether the column is mapped in this context</returns>
public abstract bool ContainsColumn(string colName);

/// <summary>
/// Check the required OpSet version satisfies our requirement
/// </summary>
/// <returns></returns>
public abstract void CheckOpSetVersion(int thisTransformerMinumumOpSetVersion, string registerTransformerName);

/// <summary>
/// Stops tracking a column.
/// </summary>
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,9 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu
_host.CheckValue(scoreProbablityColumnNames, nameof(scoreProbablityColumnNames));
_host.Check(Utils.Size(scoreProbablityColumnNames) == 2);

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "PlattCalibrator");

// The Affine operator is no longer supported in the v11 opset.
// So we have to decompose it using Mul and Add
string opType = "Mul";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,9 @@ public void SaveAsOnnx(OnnxContext ctx)
Host.CheckValue(ctx, nameof(ctx));
Contracts.Assert(CanSaveOnnx(ctx));

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

for (int iinfo = 0; iinfo < _columns.Length; ++iinfo)
{
var colInfo = _parent._columns[iinfo];
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()

public void SaveAsOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

var opType = "Identity";

foreach (var column in _columns)
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,9 @@ IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView ne

public void SaveAsOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

var outputToInputMap = _mapper.OutputToInputMap;
for(int i = 0; i < outputToInputMap.Length; i++)
{
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/Hashing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,9 @@ private void AddMetaKeyValues(int i, DataViewSchema.Annotations.Builder builder)

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, string dstVariable)
{
const int minimumOpSetVersion = 11;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

string castOutput;
string isGreaterThanZeroOutput = "";
OnnxNode castNode;
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/KeyToValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,9 @@ public override JToken SavePfa(BoundPfaContext ctx, JToken srcToken)

public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

string opType;

// Onnx expects the input keys to be int64s. But the input data can come from an ML.NET node that
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/KeyToVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,9 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToke

private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

var dim = info.TypeSrc.GetValueCount();

string opType = "Cast";
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/Normalizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColumnOptions info, stri
Contracts.Assert(_parent.Columns[iinfo] == info);
Contracts.Assert(CanSaveOnnx(ctx));

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

int valueCount = info.InputType.GetValueCount();
if (valueCount == 0)
return false;
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,9 @@ public void SaveAsOnnx(OnnxContext ctx)

public bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

string opType;
var slots = _slotDropper[iinfo].GetPreservedSlots();
// vector column is not suppressed
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/TypeConverting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,9 @@ public void SaveAsOnnx(OnnxContext ctx)

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

var opType = "Cast";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
var t = _parent._columns[iinfo].OutputKind.ToInternalDataKind().ToType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,9 @@ private void CastInputToFloat<T>(OnnxContext ctx, out OnnxNode node, out long[]

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

OnnxNode node;
long[] termIds;
string opType = "LabelEncoder";
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3032,6 +3032,9 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames,
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputNames) >= 1);

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "TreeEnsembleModelParameters");

//Nodes.
var nodesTreeids = new List<long>();
var nodesIds = new List<long>();
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.FastTree/FastTreeTweedie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,9 @@ internal static FastTreeTweedieModelParameters Create(IHostEnvironment env, Mode

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

// Mapping score to prediction
var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
base.SaveAsOnnx(ctx, new[] { fastTreeOutput }, featureColumn);
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ internal static FastForestRegressionModelParameters Create(IHostEnvironment env,

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

// Mapping score to prediction
var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
var numTrees = ctx.AddInitializer((float)TrainedEnsemble.NumTrees, "NumTrees");
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,9 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
// v
// L [l] <--- ArgMin <--- Y [l, k]

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

// Allocate C, which is a constant tensor in prediction phase
var shapeC = new long[] { _centroids.Length, _centroids[0].Length };
var tensorC = new List<float>();
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Mkl.Components/VectorWhitening.cs
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,9 @@ public void SaveAsOnnx(OnnxContext ctx)

private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

var model = _parent._models[iinfo];
int dimension = _srcTypes[iinfo].GetValueCount();
Host.Assert(model.Length == dimension * dimension);
Expand Down
16 changes: 14 additions & 2 deletions src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ namespace Microsoft.ML.Model.OnnxConverter
/// </summary>
internal sealed class OnnxContextImpl : OnnxContext
{
private const int CurrentOpSetVersion = 12;
private const int MinimumOpSetVersion = 9;
private readonly List<OnnxCSharpToProtoWrapper.NodeProto> _nodes;
private readonly List<OnnxUtils.ModelArgs> _inputs;
// The map from IDataView column names to variable names.
Expand All @@ -32,9 +34,10 @@ internal sealed class OnnxContextImpl : OnnxContext
private readonly string _producerVersion;
private readonly long _modelVersion;
private readonly OnnxVersion _onnxVersion;
private readonly int _opSetVersion;

public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion)
string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion, int opSetVersion = CurrentOpSetVersion)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(OnnxContext));
Expand All @@ -55,6 +58,9 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
_modelVersion = modelVersion;
_domain = domain;
_onnxVersion = onnxVersion;
_opSetVersion = opSetVersion <= CurrentOpSetVersion ?
opSetVersion >= MinimumOpSetVersion ? opSetVersion : throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is lower than the minimum required OpSet version {MinimumOpSetVersion}") :
throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is higher than the current most updated OpSet version {CurrentOpSetVersion}");
}

public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);
Expand Down Expand Up @@ -127,6 +133,12 @@ public override string GetNodeName(string prefix)
return GetUniqueName(prefix, _nodeNames.Contains);
}

public override void CheckOpSetVersion(int thisTransformerMinumumOpSetVersion, string registerTransformerName)
{
if (_opSetVersion < thisTransformerMinumumOpSetVersion)
throw _host.ExceptParam(nameof(thisTransformerMinumumOpSetVersion), $"Requested OpSet version {_opSetVersion} is lower than {registerTransformerName}'s minimum OpSet version requirement: {thisTransformerMinumumOpSetVersion}");
}

/// <summary>
/// Adds a node to the node list of the graph.
/// </summary>
Expand Down Expand Up @@ -409,7 +421,7 @@ public override string AddInitializer(IEnumerable<ulong> values, bool isUint64,
/// Makes the ONNX model based on the context.
/// </summary>
public OnnxCSharpToProtoWrapper.ModelProto MakeModel()
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues, _initializers);
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _opSetVersion, _inputs, _outputs, _intermediateValues, _initializers);

/// <summary>
/// Return either "Experimental" or "Stable". The string "Experimental" indicates that some experimental features which are
Expand Down
47 changes: 40 additions & 7 deletions src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,23 @@
using Google.Protobuf;
using Microsoft.ML.Data;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper;

namespace Microsoft.ML
{
public static class OnnxExportExtensions
{
private static ModelProto ConvertToOnnxProtobufCore(IHostEnvironment env, OnnxContextImpl ctx, ITransformer transform, IDataView inputData)
{
var outputData = transform.Transform(inputData);
LinkedList<ITransformCanSaveOnnx> transforms = null;
using (var ch = env.Start("ONNX conversion"))
{
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms);
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null);
}
}

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
Expand All @@ -26,13 +37,23 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat
{
var env = catalog.GetEnvironment();
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable);
var outputData = transform.Transform(inputData);
LinkedList<ITransformCanSaveOnnx> transforms = null;
using (var ch = env.Start("ONNX conversion"))
{
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms);
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null);
}
return ConvertToOnnxProtobufCore(env, ctx, transform, inputData);
}

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
/// </summary>
/// <param name="catalog">The class that <see cref="ConvertToOnnxProtobuf(ModelOperationsCatalog, ITransformer, IDataView, int)"/> attached to.</param>
/// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
/// <param name="inputData">The input of the specified transform.</param>
/// <param name="opSetVersion">The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12</param>
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
[BestFriend]
internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion)
{
var env = catalog.GetEnvironment();
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable, opSetVersion);
return ConvertToOnnxProtobufCore(env, ctx, transform, inputData);
}

/// <summary>
Expand All @@ -45,5 +66,17 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, Stream stream) =>
ConvertToOnnxProtobuf(catalog, transform, inputData).WriteTo(stream);

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format and writes to a stream.
/// </summary>
/// <param name="catalog">The class that <see cref="ConvertToOnnx(ModelOperationsCatalog, ITransformer, IDataView, int, Stream)"/> attached to.</param>
/// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
/// <param name="inputData">The input of the specified transform.</param>
/// <param name="opSetVersion">The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12</param>
/// <param name="stream">The stream to write the protobuf model to.</param>
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion, Stream stream) =>
ConvertToOnnxProtobuf(catalog, transform, inputData, opSetVersion).WriteTo(stream);
}
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ public ModelArgs(string name, TensorProto.Types.DataType dataType, List<long> di
}

public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name,
string domain, string producerVersion, long modelVersion, List<ModelArgs> inputs,
string domain, string producerVersion, long modelVersion, int opSetVersion, List<ModelArgs> inputs,
List<ModelArgs> outputs, List<ModelArgs> intermediateValues, List<TensorProto> initializers)
{
Contracts.CheckValue(nodes, nameof(nodes));
Expand All @@ -305,7 +305,7 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion;
model.ModelVersion = modelVersion;
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 2 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 11 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = opSetVersion });
model.Graph = new GraphProto();
var graph = model.Graph;
graph.Node.Add(nodes);
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.PCA/PcaTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@ private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName,
{
Host.CheckValue(ctx, nameof(ctx));

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

TransformInfo transformInfo = _parent._transformInfos[iinfo];

// When the transformer is loaded from a model file,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputs, str
{
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputs) >= 1);

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "LinearModel");

string opType = "LinearRegressor";
string scoreVarName = (Utils.Size(outputs) >= 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns
var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,10 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, JToken input)
private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureColumn)
{
Host.CheckValue(ctx, nameof(ctx));

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "MultiClassLogisticRegression");

Host.Assert(outputs[0] == DefaultColumnNames.PredictedLabel);
Host.Assert(outputs[1] == DefaultColumnNames.Score);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,9 @@ ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
/// </summary>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "MulticlassNaiveBayes");

float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length];
float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];

Expand Down
Loading