Skip to content
4 changes: 4 additions & 0 deletions src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ public override int GetHashCode()
/// </summary>
public override void Register()
{
// this happens when use OnnxSequenceType attribute without specify sequence type
if (_elemType == null)
throw new InvalidOperationException("Please specify sequence type when use OnnxSequenceType Attribute.");

var enumerableType = typeof(IEnumerable<>);
var type = enumerableType.MakeGenericType(_elemType);
DataViewTypeManager.Register(new OnnxSequenceType(_elemType), type, this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Drawing;
using Microsoft.ML.Data;
Expand All @@ -14,36 +15,54 @@
using System.IO;
using Microsoft.ML.TestFramework.Attributes;


namespace Microsoft.ML.Tests
{
public class OnnxSequenceTypeWithAttributesTest : BaseTestBaseline
public class OnnxSequenceTypeTest : BaseTestBaseline
{
public class OutputObj
{
[ColumnName("output")]
[OnnxSequenceType(typeof(IDictionary<string, float>))]
public IEnumerable<IDictionary<string, float>> Output;
}

public class ProblematicOutputObj
{

[ColumnName("output")]
// incorrect usage, should always specify sequence type when using OnnxSequenceType attribute
[OnnxSequenceType]
public IEnumerable<IDictionary<string, float>> Output;
}

public class FloatInput
{
[ColumnName("input")]
[VectorType(3)]
public float[] Input { get; set; }
}

public OnnxSequenceTypeWithAttributesTest(ITestOutputHelper output) : base(output)
public OnnxSequenceTypeTest(ITestOutputHelper output) : base(output)
{
}
public static PredictionEngine<FloatInput, OutputObj> LoadModel(string onnxModelFilePath)

private static OnnxTransformer PrepareModel(string onnxModelFilePath, MLContext ctx)
{
var ctx = new MLContext();
var dataView = ctx.Data.LoadFromEnumerable(new List<FloatInput>());

var pipeline = ctx.Transforms.ApplyOnnxModel(
modelFile: onnxModelFilePath,
outputColumnNames: new[] { "output" }, inputColumnNames: new[] { "input" });

var model = pipeline.Fit(dataView);
return model;
}

public static PredictionEngine<FloatInput, OutputObj> LoadModel(string onnxModelFilePath)
{
var ctx = new MLContext();
var model = PrepareModel(onnxModelFilePath, ctx);
return ctx.Model.CreatePredictionEngine<FloatInput, OutputObj>(model);
}

Expand All @@ -64,5 +83,21 @@ public void OnnxSequenceTypeWithColumnNameAttributeTest()
}

}

private static PredictionEngine<FloatInput, ProblematicOutputObj> CreatePredictorWithProblematicOutputObj()
{
var onnxModelFilePath = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapString.onnx");

var ctx = new MLContext();
var model = PrepareModel(onnxModelFilePath, ctx);
return ctx.Model.CreatePredictionEngine<FloatInput, ProblematicOutputObj>(model);
}

[OnnxFact]
public void OnnxSequenceTypeWithouSpecifySequenceTypeTest()
{
InvalidOperationException ex = Assert.Throws<InvalidOperationException>(() => CreatePredictorWithProblematicOutputObj());
Assert.Equal("Please specify sequence type when use OnnxSequenceType Attribute.", ex.Message);
}

@harishsk harishsk Oct 2, 2019

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to assert on the type of exception rather than the specific message? #Resolved

@frank-dong-ms-zz frank-dong-ms-zz Oct 2, 2019

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, the exception type is assert in previous line Assert.Throws, after change to use InvalidOperationException, it should look like Assert.Throws(....), so we can assert both on exception type and exception message.


In reply to: 330651709 [](ancestors = 330651709)

}
}