Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,20 @@ public sealed class Arguments : DataCommand.ArgumentsBase
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)]
public bool? LoadPredictor;

/// <summary>
/// Entry point API can save either <see cref="TransformModel"/> or <see cref="PredictorModel"/>.
/// <see cref="Model"/> is used when the saved model is typed to <see cref="TransformModel"/>.
/// </summary>
[Argument(ArgumentType.Required, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)]
public TransformModel Model;

/// <summary>
/// Entry point API can save either <see cref="TransformModel"/> or <see cref="PredictorModel"/>.
/// <see cref="PredictiveModel"/> is used when the saved model is typed to <see cref="PredictorModel"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Predictor model that needs to be converted to ONNX format.", SortOrder = 12)]
public PredictorModel PredictiveModel;

[Argument(ArgumentType.AtMostOnce, HelpText = "The targeted ONNX version. It can be either \"Stable\" or \"Experimental\". If \"Experimental\" is used, produced model can contain components that is not officially supported in ONNX standard.", SortOrder = 11)]
public OnnxVersion OnnxVersion;
}
Expand All @@ -72,6 +83,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase
private readonly HashSet<string> _inputsToDrop;
private readonly HashSet<string> _outputsToDrop;
private readonly TransformModel _model;
private readonly PredictorModel _predictiveModel;
private const string ProducerName = "ML.NET";
private const long ModelVersion = 0;

Expand All @@ -96,7 +108,13 @@ public SaveOnnxCommand(IHostEnvironment env, Arguments args)
_inputsToDrop = CreateDropMap(args.InputsToDropArray ?? args.InputsToDrop?.Split(','));
_outputsToDrop = CreateDropMap(args.OutputsToDropArray ?? args.OutputsToDrop?.Split(','));
_domain = args.Domain;

if (args.Model != null && args.PredictiveModel != null)
throw env.Except(nameof(args.Model) + " and " + nameof(args.PredictiveModel) +
" cannot be specified at the same time when calling ONNX converter. Please check the content of " + nameof(args) + ".");

_model = args.Model;
_predictiveModel = args.PredictiveModel;
}

private static HashSet<string> CreateDropMap(string[] toDrop)
Expand Down Expand Up @@ -198,7 +216,7 @@ private void Run(IChannel ch)
IDataView view;
RoleMappedSchema trainSchema = null;

if (_model == null)
if (_model == null && _predictiveModel == null)
{
if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
{
Expand All @@ -213,8 +231,16 @@ private void Run(IChannel ch)

view = loader;
}
else
else if (_model != null)
{
view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
}
else
{
view = _predictiveModel.TransformModel.Apply(Host, new EmptyDataView(Host, _predictiveModel.TransformModel.InputSchema));
rawPred = _predictiveModel.Predictor;
trainSchema = _predictiveModel.GetTrainingSchema(Host);
}

// Create the ONNX context for storing global information
var assembly = System.Reflection.Assembly.GetExecutingAssembly();
Expand Down
9 changes: 9 additions & 0 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -2293,6 +2293,15 @@
"SortOrder": 11.0,
"IsNullable": false,
"Default": "Stable"
},
{
"Name": "PredictiveModel",

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PredictiveModel [](start = 19, length = 15)

Can it be named PredictorModel ?

"Type": "PredictorModel",
"Desc": "Predictor model that needs to be converted to ONNX format.",
"Required": false,
"SortOrder": 12.0,
"IsNullable": false,
"Default": null
}
],
"Outputs": []
Expand Down
129 changes: 127 additions & 2 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Text.RegularExpressions;
using Google.Protobuf;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.RunTests;
using Microsoft.ML.Runtime;
Expand Down Expand Up @@ -186,7 +187,7 @@ void CommandLineOnnxConversionTest()
string modelPath = GetOutputPath("ModelWithLessIO.zip");
var trainingPathArgs = $"data={dataPath} out={modelPath}";
var trainingArgs = " loader=text{col=Label:BL:0 col=F1:R4:1-8 col=F2:TX:9} xf=Cat{col=F2} xf=Concat{col=Features:F1,F2} tr=ft{numberOfThreads=1 numberOfLeaves=8 numberOfTrees=3} seed=1";
Assert.Equal(0, Maml.Main(new[] { "train " + trainingPathArgs + trainingArgs}));
Assert.Equal(0, Maml.Main(new[] { "train " + trainingPathArgs + trainingArgs }));

var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "BinaryClassification", "BreastCancer");
var onnxTextName = "ModelWithLessIO.txt";
Expand Down Expand Up @@ -403,6 +404,130 @@ public void MulticlassLogisticRegressionOnnxConversionTest()
Done();
}

[Fact]
public void LoadingPredictorModelAndOnnxConversionTest()
{
string dataPath = GetDataPath("iris.txt");
string modelPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.bin";
string onnxPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.onnx";
string onnxJsonPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.onnx.json";

string inputGraph = string.Format(@"
{{
'Inputs': {{
'inputFile': '{0}'
}},
'Nodes': [
{{
'Name': 'Data.TextLoader',
'Inputs':
{{
'InputFile': '$inputFile',
'Arguments':
{{
'UseThreads': true,
'HeaderFile': null,
'MaxRows': null,
'AllowQuoting': true,
'AllowSparse': true,
'InputSize': null,
'TrimWhitespace': false,
'HasHeader': false,
'Column':
[
{{'Name':'Sepal_Width','Type':null,'Source':[{{'Min':2,'Max':2,'AutoEnd':false,'VariableEnd':false,'AllOther':false,'ForceVector':false}}],'KeyCount':null}},
{{'Name':'Petal_Length','Type':null,'Source':[{{'Min':3,'Max':3,'AutoEnd':false,'VariableEnd':false,'AllOther':false,'ForceVector':false}}],'KeyCount':null}},
{{'Name':'Petal_Width','Type':null,'Source':[{{'Min':4,'Max':4,'AutoEnd':false,'VariableEnd':false,'AllOther':false,'ForceVector':false}}],'KeyCount':null}},
]
}}
}},
'Outputs':
{{
'Data': '$optional_data'
}}
}},
{{
'Inputs': {{
'Data': '$optional_data',
'Features': [
'Petal_Length',
'Petal_Width',
]
}},
'Name': 'Transforms.FeatureCombiner',
'Outputs': {{
'OutputData': '$output_data'
}}
}},
{{
'Inputs': {{
'FeatureColumnName': 'Features',
'LabelColumnName': 'Sepal_Width',
'TrainingData': '$output_data',
}},
'Name': 'Trainers.StochasticDualCoordinateAscentRegressor',
'Outputs': {{
'PredictorModel': '$output_model'

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is typed to PredictorModel, which cannot be loaded correctly before.

}}
}}
],
'Outputs': {{
'output_model': '{1}'
}}
}}", dataPath.Replace("\\", "\\\\"), modelPath.Replace("\\", "\\\\"));

// Write entry point graph into file so that it can be invoke by graph runner below.
var jsonPath = DeleteOutputPath("graph.json");
File.WriteAllLines(jsonPath, new[] { inputGraph });

// Execute the saved entry point graph to produce a predictive model.
var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath };
var cmd = new ExecuteGraphCommand(Env, args);
cmd.Run();

// Make entry point graph to conduct ONNX conversion.
inputGraph = string.Format(@"
{{
'Inputs': {{
'model': '{0}'
}},
'Nodes': [
{{
'Inputs': {{
'Domain': 'com.microsoft.models',
'Json': '{1}',
'Model': '$model',
'Onnx': '{2}',
'OnnxVersion': 'Experimental'
}},
'Name': 'Models.OnnxConverter',
'Outputs': {{}}
}}
],
'Outputs': {{}}
}}
", modelPath.Replace("\\", "\\\\"), onnxJsonPath.Replace("\\", "\\\\"), onnxPath.Replace("\\", "\\\\"));

// Write entry point graph for ONNX conversion into file so that it can be invoke by graph runner below.
jsonPath = DeleteOutputPath("graph.json");
File.WriteAllLines(jsonPath, new[] { inputGraph });

// Onnx converter's assembly is not loaded by default, so we need to register it before calling it.
Env.ComponentCatalog.RegisterAssembly(typeof(OnnxExportExtensions).Assembly);

// Execute the saved entry point graph to convert the saved model.
args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath };
cmd = new ExecuteGraphCommand(Env, args);
cmd.Run();

File.Delete(modelPath);
File.Delete(onnxPath);
File.Delete(onnxJsonPath);

Done();
}


[Fact]
public void RemoveVariablesInPipelineTest()
{
Expand Down Expand Up @@ -451,7 +576,7 @@ public void RemoveVariablesInPipelineTest()

private class SmallSentimentExample
{
[LoadColumn(0,3), VectorType(4)]
[LoadColumn(0, 3), VectorType(4)]
public string[] Tokens;
}

Expand Down