Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ private static VersionInfo GetVersionInfo()

public override DataViewSchema OutputSchema => _bindings.Schema;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than " + LoaderSignature + "'s minimum OpSet version requirement: " + minimumOpSetVersion);
return _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;
Comment thread
gh-yewang marked this conversation as resolved.
Outdated
Comment thread
gh-yewang marked this conversation as resolved.
Outdated
Comment thread
gh-yewang marked this conversation as resolved.
Outdated
}

bool ICanSavePfa.CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false;

Expand Down
6 changes: 6 additions & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ internal abstract class OnnxContext
/// <returns>Whether the column is mapped in this context</returns>
public abstract bool ContainsColumn(string colName);

/// <summary>
/// Get the OpSet version
/// </summary>
/// <returns>OpSet version </returns>
public abstract int GetOpSetVersion();

/// <summary>
/// Stops tracking a column.
/// </summary>
Expand Down
21 changes: 18 additions & 3 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,12 @@ internal abstract class ValueMapperCalibratedModelParametersBase<TSubModel, TCal

FeatureContributionCalculator ICalculateFeatureContribution.FeatureContributionCalculator => new FeatureContributionCalculator(this);

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than ValueMapperCalibratedModel's minimum OpSet version requirement: " + minimumOpSetVersion);
return (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
}

private protected ValueMapperCalibratedModelParametersBase(IHostEnvironment env, string name, TSubModel predictor, TCalibrator calibrator)
: base(env, name, predictor, calibrator)
Expand Down Expand Up @@ -747,7 +752,12 @@ private static VersionInfo GetVersionInfo()
/// </summary>
bool ICanSavePfa.CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than SchemaBindableCalibratedModel's minimum OpSet version requirement: " + minimumOpSetVersion);
return (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
}

internal SchemaBindableCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator)
: base(env, LoaderSignature, predictor, calibrator)
Expand Down Expand Up @@ -1649,7 +1659,12 @@ private static VersionInfo GetVersionInfo()
/// </summary>
public Double Offset { get; }
bool ICanSavePfa.CanSavePfa => true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than PlattCalibrator's minimum OpSet version requirement: " + minimumOpSetVersion);
return true;
}

/// <summary>
/// Initializes a new instance of <see cref="PlattCalibrator"/>.
Expand Down
7 changes: 6 additions & 1 deletion src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,12 @@ private sealed class Mapper<TCalibrator> : MapperBase, ISaveAsOnnx
private readonly int _scoreColIndex;
private CalibratorTransformer<TCalibrator> _parent;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _calibrator is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than CalibratorTransformer's minimum OpSet version requirement: " + minimumOpSetVersion);
return _calibrator is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;
}

internal Mapper(CalibratorTransformer<TCalibrator> parent, TCalibrator calibrator, DataViewSchema inputSchema) :
base(parent.Host, inputSchema, parent)
Expand Down
7 changes: 6 additions & 1 deletion src/Microsoft.ML.Data/Scorers/GenericScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,12 @@ private static VersionInfo GetVersionInfo()

bool ICanSavePfa.CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than " + LoadName + "'s minimum OpSet version requirement: " + minimumOpSetVersion);
return (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
}

/// <summary>
/// The <see cref="SignatureDataScorer"/> entry point for creating a <see cref="GenericScorer"/>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public sealed class Arguments : ScorerArgumentsBase
}

public const string LoaderSignature = "MultiClassScoreTrans";

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
Expand Down Expand Up @@ -88,7 +89,12 @@ private static readonly FuncInstanceMethodInfo1<LabelNameBindableMapper, object,

public VectorDataViewType Type => _type;
bool ICanSavePfa.CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than " + LoaderSignature + "'s minimum OpSet version requirement: " + minimumOpSetVersion);
return (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
}

private static VersionInfo GetVersionInfo()
{
Expand Down
7 changes: 6 additions & 1 deletion src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,12 @@ public override Func<int, bool> GetActiveMapperColumns(bool[] active)

bool ICanSavePfa.CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 11;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than PredictedLabelScorer's minimum OpSet version requirement: " + minimumOpSetVersion);
return (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
}

[BestFriend]
private protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ private static readonly FuncInstanceMethodInfo2<SchemaBindablePredictorWrapperBa

bool ICanSavePfa.CanSavePfa => (ValueMapper as ICanSavePfa)?.CanSavePfa == true;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (ValueMapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) {
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than SchemaBindablePredictor's minimum OpSet version requirement: " + minimumOpSetVersion);
return (ValueMapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; ;
}

public SchemaBindablePredictorWrapperBase(IPredictor predictor)
{
Expand Down
9 changes: 8 additions & 1 deletion src/Microsoft.ML.Data/Transforms/Hashing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1468,7 +1468,14 @@ void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
}
}

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 11;
if (ctx.GetOpSetVersion() < minimumOpSetVersion)
throw Contracts.ExceptParam(nameof(minimumOpSetVersion), $"OpSet version {ctx.GetOpSetVersion()} is older than {LoaderSignature}'s minimum OpSet version requirement: {minimumOpSetVersion}");
//Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than " + LoaderSignature + "'s minimum OpSet version requirement: " + minimumOpSetVersion);
return true;
}
}

private abstract class InvertHashHelper
Expand Down
7 changes: 6 additions & 1 deletion src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2808,7 +2808,12 @@ public abstract class TreeEnsembleModelParameters :

bool ICanSavePfa.CanSavePfa => true;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than FastTreeTraining's minimum OpSet version requirement: " + minimumOpSetVersion);
return true;
}

/// <summary>
/// Used to determine the contribution of each feature to the score of an example by <see cref="FeatureContributionCalculatingTransformer"/>.
Expand Down
7 changes: 6 additions & 1 deletion src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ private static VersionInfo GetVersionInfo()
DataViewType IValueMapper.InputType => _inputType;
DataViewType IValueMapper.OutputType => _outputType;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than " + LoaderSignature + "'s minimum OpSet version requirement: " + minimumOpSetVersion);
return true;
}

private readonly int _dimensionality;
private readonly int _k;
Expand Down
30 changes: 29 additions & 1 deletion src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ internal sealed class OnnxContextImpl : OnnxContext
private readonly string _producerVersion;
private readonly long _modelVersion;
private readonly OnnxVersion _onnxVersion;
private readonly OptionalOpSetVersion _optionalOpSetVersion;

public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion)
Expand All @@ -55,6 +56,28 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
_modelVersion = modelVersion;
_domain = domain;
_onnxVersion = onnxVersion;
_optionalOpSetVersion = new OptionalOpSetVersion();
}

internal class OptionalOpSetVersion
{
Comment thread
gh-yewang marked this conversation as resolved.
Outdated
private int _defaultOpSetVersion;
internal OptionalOpSetVersion()
{
_defaultOpSetVersion = 11;
}
internal int OpSetVersion
{
get { return _defaultOpSetVersion; }
set { _defaultOpSetVersion = value; }
}
}

public void ModifyOpSetVersion(int customOpSetVersion)
{
if (customOpSetVersion > _optionalOpSetVersion.OpSetVersion)
throw _host.ExceptParam(nameof(customOpSetVersion), $"User defined OpSet version is newer than the current most updated OpSet version '{_optionalOpSetVersion.OpSetVersion}'");
_optionalOpSetVersion.OpSetVersion = customOpSetVersion;
}

public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);
Expand Down Expand Up @@ -127,6 +150,11 @@ public override string GetNodeName(string prefix)
return GetUniqueName(prefix, _nodeNames.Contains);
}

public override int GetOpSetVersion()
{
return _optionalOpSetVersion.OpSetVersion;
}

/// <summary>
/// Adds a node to the node list of the graph.
/// </summary>
Expand Down Expand Up @@ -409,7 +437,7 @@ public override string AddInitializer(IEnumerable<ulong> values, bool isUint64,
/// Makes the ONNX model based on the context.
/// </summary>
public OnnxCSharpToProtoWrapper.ModelProto MakeModel()
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues, _initializers);
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _optionalOpSetVersion.OpSetVersion, _inputs, _outputs, _intermediateValues, _initializers);

/// <summary>
/// Return either "Experimental" or "Stable". The string "Experimental" indicates that some experimental features which are
Expand Down
48 changes: 41 additions & 7 deletions src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,23 @@
using Google.Protobuf;
using Microsoft.ML.Data;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper;

namespace Microsoft.ML
{
public static class OnnxExportExtensions
{
private static ModelProto ConvertToOnnxProtobufCore(IHostEnvironment env, OnnxContextImpl ctx, ITransformer transform, IDataView inputData)
{
var outputData = transform.Transform(inputData);
LinkedList<ITransformCanSaveOnnx> transforms = null;
using (var ch = env.Start("ONNX conversion"))
{
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms);
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null);
}
}

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
Expand All @@ -26,13 +37,24 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat
{
var env = catalog.GetEnvironment();
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable);
var outputData = transform.Transform(inputData);
LinkedList<ITransformCanSaveOnnx> transforms = null;
using (var ch = env.Start("ONNX conversion"))
{
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms);
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null);
}
return ConvertToOnnxProtobufCore(env, ctx, transform, inputData);
}

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
/// </summary>
/// <param name="catalog">The class that <see cref="ConvertToOnnxProtobufWithCustomOpSetVersion(ModelOperationsCatalog, ITransformer, IDataView, int)"/> attached to.</param>
/// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
/// <param name="inputData">The input of the specified transform.</param>
/// <param name="opSetVersion">Set custom value to OpSet version</param>
Comment thread
gh-yewang marked this conversation as resolved.
Outdated
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
Comment thread
gh-yewang marked this conversation as resolved.
[BestFriend]
internal static ModelProto ConvertToOnnxProtobufWithCustomOpSetVersion(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion)
{
var env = catalog.GetEnvironment();
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable);
ctx.ModifyOpSetVersion(opSetVersion);
return ConvertToOnnxProtobufCore(env, ctx, transform, inputData);
Comment thread
gh-yewang marked this conversation as resolved.
}

/// <summary>
Expand All @@ -45,5 +67,17 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, Stream stream) =>
ConvertToOnnxProtobuf(catalog, transform, inputData).WriteTo(stream);

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format and writes to a stream.
/// </summary>
/// <param name="catalog">The class that <see cref="ConvertToOnnxWithCustomOpSetVersion(ModelOperationsCatalog, ITransformer, IDataView, int, Stream)"/> attached to.</param>
/// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
/// <param name="inputData">The input of the specified transform.</param>
/// <param name="opSetVersion"></param>
/// <param name="stream">The stream to write the protobuf model to.</param>
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnxWithCustomOpSetVersion(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion, Stream stream) =>
ConvertToOnnxProtobufWithCustomOpSetVersion(catalog, transform, inputData, opSetVersion).WriteTo(stream);
Comment thread
gh-yewang marked this conversation as resolved.
Outdated
}
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ public ModelArgs(string name, TensorProto.Types.DataType dataType, List<long> di
}

public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name,
string domain, string producerVersion, long modelVersion, List<ModelArgs> inputs,
string domain, string producerVersion, long modelVersion, int opSetVersion, List<ModelArgs> inputs,
List<ModelArgs> outputs, List<ModelArgs> intermediateValues, List<TensorProto> initializers)
{
Contracts.CheckValue(nodes, nameof(nodes));
Expand All @@ -305,7 +305,7 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion;
model.ModelVersion = modelVersion;
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 2 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 11 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = opSetVersion });
model.Graph = new GraphProto();
var graph = model.Graph;
graph.Node.Add(nodes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@ public IEnumerator<float> GetEnumerator()

bool ICanSavePfa.CanSavePfa => true;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
Contracts.Assert(ctx.GetOpSetVersion() >= minimumOpSetVersion, "OpSet version " + ctx.GetOpSetVersion() + " is older than LinearModel's minimum OpSet version requirement: " + minimumOpSetVersion);
return true;
}

/// <summary>
/// Used to determine the contribution of each feature to the score of an example by <see cref="FeatureContributionCalculatingTransformer"/>.
Comment thread
gh-yewang marked this conversation as resolved.
Expand Down
Loading