diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs index 37fd961258..a4b5ae30e3 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingWithInMemoryCustomType.cs @@ -84,7 +84,7 @@ public AlienTypeAttributeAttribute(int raceId) public override void Register() { DataViewTypeManager.Register(new DataViewAlienBodyType(RaceId), - typeof(AlienBody), new[] { this }); + typeof(AlienBody), this); } public override bool Equals(DataViewTypeAttribute other) diff --git a/src/Microsoft.ML.Data/Data/DataViewTypeManager.cs b/src/Microsoft.ML.Data/Data/DataViewTypeManager.cs index fda18633ff..b390c9167a 100644 --- a/src/Microsoft.ML.Data/Data/DataViewTypeManager.cs +++ b/src/Microsoft.ML.Data/Data/DataViewTypeManager.cs @@ -53,10 +53,25 @@ public static class DataViewTypeManager /// internal static DataViewType GetDataViewType(Type type, IEnumerable typeAttributes = null) { + //Filter attributes as we only care about DataViewTypeAttribute + DataViewTypeAttribute typeAttr = null; + if(typeAttributes != null) + { + typeAttributes = typeAttributes.Where(attr => attr.GetType().IsSubclassOf(typeof(DataViewTypeAttribute))); + if (typeAttributes.Count() > 1) + { + throw Contracts.ExceptParam(nameof(type), "Type {0} cannot be marked with multiple attributes, {1}, derived from {2}.", + type.Name, typeAttributes, typeof(DataViewTypeAttribute)); + } + else if (typeAttributes.Count() == 1) + { + typeAttr = typeAttributes.First() as DataViewTypeAttribute; + } + } lock (_lock) { // Compute the ID of type with extra attributes. - var rawType = new TypeWithAttributes(type, typeAttributes); + var rawType = new TypeWithAttributes(type, typeAttr); // Get the DataViewType's ID which typeID is mapped into. if (!_rawTypeToDataViewTypeMap.TryGetValue(rawType, out DataViewType dataViewType)) @@ -73,10 +88,25 @@ internal static DataViewType GetDataViewType(Type type, IEnumerable t /// internal static bool Knows(Type type, IEnumerable typeAttributes = null) { + //Filter attributes as we only care about DataViewTypeAttribute + DataViewTypeAttribute typeAttr = null; + if(typeAttributes != null) + { + typeAttributes = typeAttributes.Where(attr => attr.GetType().IsSubclassOf(typeof(DataViewTypeAttribute))); + if (typeAttributes.Count() > 1) + { + throw Contracts.ExceptParam(nameof(type), "Type {0} cannot be marked with multiple attributes, {1}, derived from {2}.", + type.Name, typeAttributes, typeof(DataViewTypeAttribute)); + } + else if (typeAttributes.Count() == 1) + { + typeAttr = typeAttributes.First() as DataViewTypeAttribute; + } + } lock (_lock) { // Compute the ID of type with extra attributes. - var rawType = new TypeWithAttributes(type, typeAttributes); + var rawType = new TypeWithAttributes(type, typeAttr); // Check if this ID has been associated with a DataViewType. // Note that the dictionary below contains (rawType, dataViewType) pairs (key type is TypeWithAttributes, and value type is DataViewType). @@ -111,7 +141,39 @@ internal static bool Knows(DataViewType dataViewType) /// Native type in C#. /// The corresponding type of in ML.NET's type system. /// The s attached to . - public static void Register(DataViewType dataViewType, Type type, IEnumerable typeAttributes = null) + [Obsolete("This API is depricated, please use the new form of Register which takes in a single DataViewTypeAttribute instead.", false)] + public static void Register(DataViewType dataViewType, Type type, IEnumerable typeAttributes) + { + DataViewTypeAttribute typeAttr = null; + if (typeAttributes != null) + { + if (typeAttributes.Count() > 1) + { + throw Contracts.ExceptParam(nameof(type), $"Type {type} has too many attributes."); + } + else if (typeAttributes.Count() == 1) + { + var attr = typeAttributes.First(); + if (!attr.GetType().IsSubclassOf(typeof(DataViewTypeAttribute))) + { + throw Contracts.ExceptParam(nameof(type), $"Type {type} has an attribute that is not of DataViewTypeAttribute."); + } + else + { + typeAttr = attr as DataViewTypeAttribute; + } + } + } + Register(dataViewType, type, typeAttr); + } + /// + /// This function tells that should be representation of data in in + /// ML.NET's type system. The registered must be a standard C# object's type. + /// + /// Native type in C#. + /// The corresponding type of in ML.NET's type system. + /// The attached to . + public static void Register(DataViewType dataViewType, Type type, DataViewTypeAttribute typeAttribute = null) { lock (_lock) { @@ -119,7 +181,7 @@ public static void Register(DataViewType dataViewType, Type type, IEnumerable - /// An instance of represents an unique key of its and . + /// An instance of represents an unique key of its and . /// private class TypeWithAttributes { @@ -162,16 +224,16 @@ private class TypeWithAttributes public Type TargetType { get; } /// - /// The underlying type's attributes. Together with , uniquely defines + /// The underlying type's attributes. Together with , uniquely defines /// a key when using as the key type in . Note that the /// uniqueness is determined by and below. /// - private IEnumerable _associatedAttributes; + private DataViewTypeAttribute _associatedAttribute; - public TypeWithAttributes(Type type, IEnumerable attributes) + public TypeWithAttributes(Type type, DataViewTypeAttribute attribute) { TargetType = type; - _associatedAttributes = attributes; + _associatedAttribute = attribute; } public override bool Equals(object obj) @@ -183,22 +245,15 @@ public override bool Equals(object obj) // Flag of having the attribute configurations. var sameAttributeConfig = true; - if (_associatedAttributes == null && other._associatedAttributes == null) + if (_associatedAttribute == null && other._associatedAttribute == null) sameAttributeConfig = true; - else if (_associatedAttributes == null && other._associatedAttributes != null) + else if (_associatedAttribute == null && other._associatedAttribute != null) sameAttributeConfig = false; - else if (_associatedAttributes != null && other._associatedAttributes == null) - sameAttributeConfig = false; - else if (_associatedAttributes.Count() != other._associatedAttributes.Count()) + else if (_associatedAttribute != null && other._associatedAttribute == null) sameAttributeConfig = false; else { - var zipped = _associatedAttributes.Zip(other._associatedAttributes, (attr, otherAttr) => (attr, otherAttr)); - foreach (var (attr, otherAttr) in zipped) - { - if (!attr.Equals(otherAttr)) - sameAttributeConfig = false; - } + sameAttributeConfig = _associatedAttribute.Equals(other._associatedAttribute); } return sameType && sameAttributeConfig; @@ -213,12 +268,14 @@ public override bool Equals(object obj) /// public override int GetHashCode() { - if (_associatedAttributes == null) + if (_associatedAttribute == null) return TargetType.GetHashCode(); var code = TargetType.GetHashCode(); - foreach (var attr in _associatedAttributes) - code = Hashing.CombineHash(code, attr.GetHashCode()); + if (_associatedAttribute != null) + { + code = Hashing.CombineHash(code, _associatedAttribute.GetHashCode()); + } return code; } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs index 6a30084f81..a9ac823b2c 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageType.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs @@ -64,7 +64,7 @@ public override int GetHashCode() public override void Register() { - DataViewTypeManager.Register(new ImageDataViewType(Height, Width), typeof(Bitmap), new[] { this }); + DataViewTypeManager.Register(new ImageDataViewType(Height, Width), typeof(Bitmap), this ); } } diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs b/src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs index 028fd3f8df..a7ae3cbfd8 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs @@ -24,7 +24,7 @@ public sealed class OnnxMapType : StructuredDataViewType /// Value type of the associated ONNX map. public OnnxMapType(Type keyType, Type valueType) : base(typeof(IDictionary<,>).MakeGenericType(keyType, valueType)) { - DataViewTypeManager.Register(this, RawType, new[] { new OnnxMapTypeAttribute(keyType, valueType) }); + DataViewTypeManager.Register(this, RawType, new OnnxMapTypeAttribute(keyType, valueType)); } public override bool Equals(DataViewType other) @@ -95,7 +95,7 @@ public override void Register() { var enumerableType = typeof(IDictionary<,>); var type = enumerableType.MakeGenericType(_keyType, _valueType); - DataViewTypeManager.Register(new OnnxMapType(_keyType, _valueType), type, new[] { this }); + DataViewTypeManager.Register(new OnnxMapType(_keyType, _valueType), type, this); } } } diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs b/src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs index acfca70e47..c8b310b45c 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs @@ -29,7 +29,7 @@ private static Type MakeNativeType(Type elementType) /// The element type of a sequence. public OnnxSequenceType(Type elementType) : base(MakeNativeType(elementType)) { - DataViewTypeManager.Register(this, RawType, new[] { new OnnxSequenceTypeAttribute(elementType) }); + DataViewTypeManager.Register(this, RawType, new OnnxSequenceTypeAttribute(elementType)); } public override bool Equals(DataViewType other) @@ -96,7 +96,7 @@ public override void Register() { var enumerableType = typeof(IEnumerable<>); var type = enumerableType.MakeGenericType(_elemType); - DataViewTypeManager.Register(new OnnxSequenceType(_elemType), type, new[] { this }); + DataViewTypeManager.Register(new OnnxSequenceType(_elemType), type, this); } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index b51d8952d5..b0660bfe08 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -58,7 +58,7 @@ public AlienTypeAttributeAttribute(int raceId) /// public override void Register() { - DataViewTypeManager.Register(new DataViewAlienBodyType(RaceId), typeof(AlienBody), new[] { this }); + DataViewTypeManager.Register(new DataViewAlienBodyType(RaceId), typeof(AlienBody), this); } public override bool Equals(DataViewTypeAttribute other) @@ -243,7 +243,7 @@ public void TestTypeManager() { // "a" has been registered with AlienBody without any attribute, so the user can't // register "a" again with AlienBody plus the attribute "c." - DataViewTypeManager.Register(a, typeof(AlienBody), new[] { c }); + DataViewTypeManager.Register(a, typeof(AlienBody), c); } catch { @@ -268,14 +268,30 @@ public void TestTypeManager() // Register a type with attribute. var e = new DataViewAlienBodyType(7788); var f = new AlienTypeAttributeAttribute(8877); - DataViewTypeManager.Register(e, typeof(AlienBody), new[] { f }); + DataViewTypeManager.Register(e, typeof(AlienBody), f); Assert.True(DataViewTypeManager.Knows(e)); Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { f })); - Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { f })); // "e" is associated with typeof(AlienBody) with "f," so the call below should return true. Assert.Equal(e, DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { f })); // "a" is associated with typeof(AlienBody) without any attribute, so the call below should return false. Assert.NotEqual(a, DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { f })); } + + [Fact] + public void GetTypeWithAdditionalDataViewTypeAttributes() + { + var a = new DataViewAlienBodyType(7788); + var b = new AlienTypeAttributeAttribute(8877); + var c = new ColumnNameAttribute("foo"); + var d = new AlienTypeAttributeAttribute(8876); + + + DataViewTypeManager.Register(a, typeof(AlienBody), b); + Assert.True(DataViewTypeManager.Knows(a)); + Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new Attribute[] { b, c })); + // "a" is associated with typeof(AlienBody) with "b," so the call below should return true. + Assert.Equal(a, DataViewTypeManager.GetDataViewType(typeof(AlienBody), new Attribute[] { b, c })); + Assert.Throws(() => DataViewTypeManager.Knows(typeof(AlienBody), new Attribute[] { b, d })); + } } } diff --git a/test/Microsoft.ML.Tests/OnnxSequenceTypeWithAttributesTest.cs b/test/Microsoft.ML.Tests/OnnxSequenceTypeWithAttributesTest.cs new file mode 100644 index 0000000000..e8090aef81 --- /dev/null +++ b/test/Microsoft.ML.Tests/OnnxSequenceTypeWithAttributesTest.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Drawing; +using Microsoft.ML.Data; +using Microsoft.ML.RunTests; +using Microsoft.ML.Transforms.Image; +using Microsoft.ML.Transforms.Onnx; +using Xunit; +using Xunit.Abstractions; +using System.Linq; +using System.IO; +using Microsoft.ML.TestFramework.Attributes; + +namespace Microsoft.ML.Tests +{ + public class OnnxSequenceTypeWithAttributesTest : BaseTestBaseline + { + public class OutputObj + { + [ColumnName("output")] + [OnnxSequenceType(typeof(IDictionary))] + public IEnumerable> Output; + } + public class FloatInput + { + [ColumnName("input")] + [VectorType(3)] + public float[] Input { get; set; } + } + + public OnnxSequenceTypeWithAttributesTest(ITestOutputHelper output) : base(output) + { + } + public static PredictionEngine LoadModel(string onnxModelFilePath) + { + var ctx = new MLContext(); + var dataView = ctx.Data.LoadFromEnumerable(new List()); + + var pipeline = ctx.Transforms.ApplyOnnxModel( + modelFile: onnxModelFilePath, + outputColumnNames: new[] { "output" }, inputColumnNames: new[] { "input" }); + + var model = pipeline.Fit(dataView); + return ctx.Model.CreatePredictionEngine(model); + } + + [OnnxFact] + public void OnnxSequenceTypeWithColumnNameAttributeTest() + { + var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapString.onnx"); + var predictor = LoadModel(modelFile); + + FloatInput input = new FloatInput() { Input = new float[] { 1.0f, 2.0f, 3.0f } }; + var output = predictor.Predict(input); + var onnx_out = output.Output.FirstOrDefault(); + Assert.True(onnx_out.Count == 3, "Output missing data."); + var keys = new List(onnx_out.Keys); + for(var i =0; i < onnx_out.Count; ++i) + { + Assert.Equal(onnx_out[keys[i]], input.Input[i]); + } + + } + } +}