Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ internal abstract class OnnxNode
// Attribute overload taking a string array as the value for the attribute named argName.
public abstract void AddAttribute(string argName, string[] value);
// Attribute overload taking a sequence of strings as the value for the attribute named argName.
public abstract void AddAttribute(string argName, IEnumerable<string> value);
// Attribute overload taking a sequence of booleans as the value for the attribute named argName.
public abstract void AddAttribute(string argName, IEnumerable<bool> value);
// Attribute overload taking a .NET type as the value for the attribute named argName.
// NOTE(review): the concrete node implementation is expected to translate the type into
// an ONNX tensor element type before storing it; confirm against the implementing class.
public abstract void AddAttribute(string argName, Type t);
}
}
19 changes: 7 additions & 12 deletions src/Microsoft.ML.Data/Transforms/TypeConverting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,13 @@ internal static bool GetNewType(IExceptionContext ectx, DataViewType srcType, In
return true;
}

private sealed class Mapper : OneToOneMapperBase, ICanSaveOnnx
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly TypeConvertingTransformer _parent;
private readonly DataViewType[] _types;
private readonly int[] _srcCols;

public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental;
public bool CanSaveOnnx(OnnxContext ctx) => true;

public Mapper(TypeConvertingTransformer parent, DataViewSchema inputSchema)
: base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
Expand Down Expand Up @@ -497,22 +497,17 @@ public void SaveAsOnnx(OnnxContext ctx)

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
var opType = "CSharp";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
node.AddAttribute("type", LoaderSignature);
node.AddAttribute("to", (byte)_parent._columns[iinfo].OutputKind);
if (_parent._columns[iinfo].OutputKeyCount != null)
{
var key = (KeyDataViewType)_types[iinfo].GetItemType();
node.AddAttribute("max", key.Count);
}
var opType = "Cast";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
var t = _parent._columns[iinfo].OutputKind.ToInternalDataKind().ToType();

@ganik ganik Aug 29, 2019

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

var t [](start = 16, length = 6)

Need to handle key type #Resolved

node.AddAttribute("to", t);
return true;
}
}
}

/// <summary>
/// Estimator for <see cref="KeyToVectorMappingTransformer"/>. Converts the underlying input column type to a new type.
/// Estimator for <see cref="TypeConvertingTransformer"/>. Converts the underlying input column type to a new type.
/// The input and output column types need to be compatible.
/// <see cref="PrimitiveDataViewType"/>
/// </summary>
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.ML.OnnxConverter/OnnxNodeImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,7 @@ public override void AddAttribute(string argName, string value)
=> OnnxUtils.NodeAddAttributes(_node, argName, value);
/// <summary>
/// Attaches a boolean attribute named <paramref name="argName"/> to the wrapped node
/// by delegating to <c>OnnxUtils.NodeAddAttributes</c>.
/// </summary>
public override void AddAttribute(string argName, bool value)
{
    OnnxUtils.NodeAddAttributes(_node, argName, value);
}
/// <summary>
/// Attaches a type-valued attribute named <paramref name="argName"/> to the wrapped node
/// by delegating to <c>OnnxUtils.NodeAddAttributes</c>.
/// </summary>
public override void AddAttribute(string argName, Type value)
{
    OnnxUtils.NodeAddAttributes(_node, argName, value);
}
}
}
79 changes: 48 additions & 31 deletions src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ private static AttributeProto MakeAttribute(string key)
return attribute;
}

/// <summary>
/// Builds an attribute named <paramref name="key"/> of attribute type Int whose integer
/// payload is the numeric value of the given tensor data type.
/// </summary>
/// <param name="key">Name passed to the single-argument <c>MakeAttribute</c> overload.</param>
/// <param name="value">Tensor data type whose enum value is stored in the attribute.</param>
/// <returns>The populated attribute.</returns>
private static AttributeProto MakeAttribute(string key, TensorProto.Types.DataType value)
{
    var result = MakeAttribute(key);
    result.Type = AttributeProto.Types.AttributeType.Int;
    // The enum is stored by its underlying integer value.
    var encoded = (int)value;
    result.I = encoded;
    return result;
}

private static AttributeProto MakeAttribute(string key, double value)
{
AttributeProto attribute = MakeAttribute(key);
Expand Down Expand Up @@ -211,6 +219,45 @@ public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable

/// <summary>
/// Appends a boolean attribute named <paramref name="argName"/> to the attribute list of <paramref name="node"/>.
/// </summary>
public static void NodeAddAttributes(NodeProto node, string argName, bool value)
{
    var attributes = node.Attribute;
    attributes.Add(MakeAttribute(argName, value));
}
/// <summary>
/// Appends an attribute named <paramref name="argName"/> to <paramref name="node"/> whose value is
/// the tensor data type that <c>ConvertToTensorProtoType</c> returns for <paramref name="value"/>.
/// </summary>
public static void NodeAddAttributes(NodeProto node, string argName, Type value)
{
    // Read the attribute list first so the evaluation order matches a single chained expression.
    var attributes = node.Attribute;
    var tensorType = ConvertToTensorProtoType(value);
    attributes.Add(MakeAttribute(argName, tensorType));
}

/// <summary>
/// Maps a .NET type to the tensor data type used to represent it.
/// </summary>
/// <param name="rawType">The .NET type to translate.</param>
/// <returns>The tensor data type paired with <paramref name="rawType"/>.</returns>
private static TensorProto.Types.DataType ConvertToTensorProtoType(Type rawType)
{
    // NOTE(review): bool is reported as Float rather than a boolean tensor type; this mirrors the
    // existing mapping. Confirm it is what a Cast target should receive.
    if (rawType == typeof(bool))
        return TensorProto.Types.DataType.Float;
    if (rawType == typeof(ReadOnlyMemory<char>))
        return TensorProto.Types.DataType.String;
    if (rawType == typeof(sbyte))
        return TensorProto.Types.DataType.Int8;
    if (rawType == typeof(byte))
        return TensorProto.Types.DataType.Uint8;
    if (rawType == typeof(short))
        return TensorProto.Types.DataType.Int16;
    if (rawType == typeof(ushort))
        return TensorProto.Types.DataType.Uint16;
    if (rawType == typeof(int))
        return TensorProto.Types.DataType.Int32;
    if (rawType == typeof(uint))
        return TensorProto.Types.DataType.Uint32;
    if (rawType == typeof(long))
        return TensorProto.Types.DataType.Int64;
    if (rawType == typeof(ulong))
        return TensorProto.Types.DataType.Uint64;
    if (rawType == typeof(float))
        return TensorProto.Types.DataType.Float;
    if (rawType == typeof(double))
        return TensorProto.Types.DataType.Double;

    // No mapping matched: run the check with a false condition, passing the offending type in the message.
    Contracts.Check(false, "Unsupported type: " + rawType.ToString());
    return TensorProto.Types.DataType.Undefined;
}

private static ByteString StringToByteString(ReadOnlyMemory<char> str) => ByteString.CopyFrom(Encoding.UTF8.GetBytes(str.ToString()));
private static IEnumerable<ByteString> StringToByteString(IEnumerable<ReadOnlyMemory<char>> str)
Expand Down Expand Up @@ -295,42 +342,12 @@ public static ModelArgs GetModelArgs(DataViewType type, string colName,
Contracts.CheckValue(type, nameof(type));
Contracts.CheckNonEmpty(colName, nameof(colName));

TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
Type rawType;
if (type is VectorDataViewType vectorType)
rawType = vectorType.ItemType.RawType;
else
rawType = type.RawType;

if (rawType == typeof(bool))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(ReadOnlyMemory<char>))
dataType = TensorProto.Types.DataType.String;
else if (rawType == typeof(sbyte))
dataType = TensorProto.Types.DataType.Int8;
else if (rawType == typeof(byte))
dataType = TensorProto.Types.DataType.Uint8;
else if (rawType == typeof(short))
dataType = TensorProto.Types.DataType.Int16;
else if (rawType == typeof(ushort))
dataType = TensorProto.Types.DataType.Uint16;
else if (rawType == typeof(int))
dataType = TensorProto.Types.DataType.Int32;
else if (rawType == typeof(uint))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(long))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(ulong))
dataType = TensorProto.Types.DataType.Uint64;
else if (rawType == typeof(float))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(double))
dataType = TensorProto.Types.DataType.Double;
else
{
string msg = "Unsupported type: " + type.ToString();
Contracts.Check(false, msg);
}
var dataType = ConvertToTensorProtoType(rawType);

string name = colName;
List<long> dimsLocal = null;
Expand Down
155 changes: 155 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,83 @@ public void WordEmbeddingsTest()
Done();
}

/// <summary>
/// For every supported (source kind, target kind) pair, fits a type-converting pipeline, exports it to
/// ONNX, and, on 64-bit Windows, runs the exported model and compares its output with the ML.NET output.
/// </summary>
[Fact]
public void OnnxTypeConversionTest()
{
    var mlContext = new MLContext(seed: 1);
    string filePath = GetDataPath("type-conversion.txt");

    // These are the supported conversions
    // ML.NET does not allow any conversions between signed and unsigned numeric types
    // Onnx does not seem to support casting a string to any type
    // Though the onnx docs claim support for byte and sbyte,
    // CreateNamedOnnxValue in OnnxUtils.cs throws a NotImplementedException for those two
    DataKind[,] supportedConversions = new[,]
    {
        { DataKind.Int16, DataKind.Int16},
        { DataKind.Int16, DataKind.Int32},
        { DataKind.Int16, DataKind.Int64},
        { DataKind.Int16, DataKind.Single},
        { DataKind.Int16, DataKind.Double},
        { DataKind.UInt16, DataKind.UInt16},
        { DataKind.UInt16, DataKind.UInt32},
        { DataKind.UInt16, DataKind.UInt64},
        { DataKind.UInt16, DataKind.Single},
        { DataKind.UInt16, DataKind.Double},
        { DataKind.Int32, DataKind.Int16},
        { DataKind.Int32, DataKind.Int32},
        { DataKind.Int32, DataKind.Int64},
        { DataKind.Int32, DataKind.Single},
        { DataKind.Int32, DataKind.Double},
        { DataKind.Int64, DataKind.Int16},
        { DataKind.Int64, DataKind.Int32},
        { DataKind.Int64, DataKind.Int64},
        { DataKind.Int64, DataKind.Single},
        { DataKind.Int64, DataKind.Double},
        { DataKind.UInt64, DataKind.UInt16},
        { DataKind.UInt64, DataKind.UInt32},
        { DataKind.UInt64, DataKind.UInt64},
        { DataKind.UInt64, DataKind.Single},
        { DataKind.UInt64, DataKind.Double},
        { DataKind.Single, DataKind.Single},
        { DataKind.Single, DataKind.Double},
        { DataKind.Double, DataKind.Single},
        { DataKind.Double, DataKind.Double}
    };

    for (int i = 0; i < supportedConversions.GetLength(0); i++)
    {
        var fromKind = supportedConversions[i, 0];
        var toKind = supportedConversions[i, 1];

        // Load the data file as a single "Value" column of the source kind.
        TextLoader.Column[] columns = new []
        {
            new TextLoader.Column("Value", fromKind, 0, 0)
        };
        var dataView = mlContext.Data.LoadFromTextFile(filePath, columns);

        var pipeline = mlContext.Transforms.Conversion.ConvertType("ValueConverted", "Value", outputKind: toKind);
        var model = pipeline.Fit(dataView);
        var mlnetResult = model.Transform(dataView);

        var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
        var onnxFileName = "typeconversion.onnx";
        var onnxModelPath = GetOutputPath(onnxFileName);
        SaveOnnxModel(onnxModel, onnxModelPath, null);

        // The exported model is only executed and compared on 64-bit Windows.
        if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess)
        {
            string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
            string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
            var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
            var onnxTransformer = onnxEstimator.Fit(dataView);
            var onnxResult = onnxTransformer.Transform(dataView);

            CompareResults("ValueConverted", "ValueConverted0", mlnetResult, onnxResult);
        }
    }
}

private void CreateDummyExamplesToMakeComplierHappy()
{
var dummyExample = new BreastCancerFeatureVector() { Features = null };
Expand All @@ -609,6 +686,84 @@ private void CreateDummyExamplesToMakeComplierHappy()
var dummyExample3 = new SmallSentimentExample() { Tokens = null };
}

/// <summary>
/// Asserts that a column of <paramref name="left"/> and a column of <paramref name="right"/> share the same
/// item type, then compares their contents with the helper that matches that item type.
/// </summary>
/// <param name="leftColumnName">Name of the column to read from <paramref name="left"/>.</param>
/// <param name="rightColumnName">Name of the column to read from <paramref name="right"/>.</param>
/// <param name="left">Data view providing the expected values.</param>
/// <param name="right">Data view providing the actual values.</param>
private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
{
    var leftColumn = left.Schema[leftColumnName];
    var rightColumn = right.Schema[rightColumnName];
    var leftType = leftColumn.Type.GetItemType();
    var rightType = rightColumn.Type.GetItemType();
    Assert.Equal(leftType, rightType);

    if (leftType == NumberDataViewType.SByte)
        CompareSelectedVectorColumns<sbyte>(leftColumnName, rightColumnName, left, right);
    else if (leftType == NumberDataViewType.Byte)
        CompareSelectedVectorColumns<byte>(leftColumnName, rightColumnName, left, right);
    else if (leftType == NumberDataViewType.Int16)
        CompareSelectedVectorColumns<short>(leftColumnName, rightColumnName, left, right);
    else if (leftType == NumberDataViewType.UInt16)
        CompareSelectedVectorColumns<ushort>(leftColumnName, rightColumnName, left, right);
    else if (leftType == NumberDataViewType.Int32)
        CompareSelectedVectorColumns<int>(leftColumnName, rightColumnName, left, right);
    else if (leftType == NumberDataViewType.UInt32)
        CompareSelectedVectorColumns<uint>(leftColumnName, rightColumnName, left, right);
    else if (leftType == NumberDataViewType.Int64)
        CompareSelectedVectorColumns<long>(leftColumnName, rightColumnName, left, right);
    else if (leftType == NumberDataViewType.UInt64)
        CompareSelectedVectorColumns<ulong>(leftColumnName, rightColumnName, left, right);
    else if (leftType == NumberDataViewType.Single)
        CompareSelectedR4VectorColumns(leftColumnName, rightColumnName, left, right);
    else if (leftType == NumberDataViewType.Double)
        // Floating-point values go through the precision-based helper, like the Single branch above.
        CompareSelectedR8VectorColumns(leftColumnName, rightColumnName, left, right);
    else
        // Without this branch an unhandled item type would make the comparison pass without checking anything.
        Assert.True(false, "Unsupported column item type for comparison: " + leftType);
}

/// <summary>
/// Walks two data views row by row and asserts that the named vector columns hold equal values.
/// </summary>
/// <typeparam name="T">Item type of the vector columns being compared.</typeparam>
/// <param name="leftColumnName">Name of the column to read from <paramref name="left"/>.</param>
/// <param name="rightColumnName">Name of the column to read from <paramref name="right"/>.</param>
/// <param name="left">Data view providing the expected values.</param>
/// <param name="right">Data view providing the actual values.</param>
private void CompareSelectedVectorColumns<T>(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
{
    var leftColumn = left.Schema[leftColumnName];
    var rightColumn = right.Schema[rightColumnName];

    using (var expectedCursor = left.GetRowCursor(leftColumn))
    using (var actualCursor = right.GetRowCursor(rightColumn))
    {
        VBuffer<T> expected = default;
        VBuffer<T> actual = default;
        var expectedGetter = expectedCursor.GetGetter<VBuffer<T>>(leftColumn);
        var actualGetter = actualCursor.GetGetter<VBuffer<T>>(rightColumn);
        while (true)
        {
            // Advance both cursors every time so a difference in row count is detected
            // instead of the shorter view silently ending the comparison.
            bool hasExpected = expectedCursor.MoveNext();
            bool hasActual = actualCursor.MoveNext();
            Assert.Equal(hasExpected, hasActual);
            if (!hasExpected)
                break;

            expectedGetter(ref expected);
            actualGetter(ref actual);

            Assert.Equal(expected.Length, actual.Length);
            for (int i = 0; i < expected.Length; ++i)
                Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i));
        }
    }
}

/// <summary>
/// Walks two data views row by row and asserts that the named double-precision vector columns
/// hold values that are equal to the requested number of decimal places.
/// </summary>
/// <param name="leftColumnName">Name of the column to read from <paramref name="left"/>.</param>
/// <param name="rightColumnName">Name of the column to read from <paramref name="right"/>.</param>
/// <param name="left">Data view providing the expected values.</param>
/// <param name="right">Data view providing the actual values.</param>
/// <param name="precision">Number of decimal places passed to the equality assertion.</param>
private void CompareSelectedR8VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6)
{
    var leftColumn = left.Schema[leftColumnName];
    var rightColumn = right.Schema[rightColumnName];

    using (var expectedCursor = left.GetRowCursor(leftColumn))
    using (var actualCursor = right.GetRowCursor(rightColumn))
    {
        VBuffer<double> expected = default;
        VBuffer<double> actual = default;
        var expectedGetter = expectedCursor.GetGetter<VBuffer<double>>(leftColumn);
        var actualGetter = actualCursor.GetGetter<VBuffer<double>>(rightColumn);
        while (true)
        {
            // Advance both cursors every time so a difference in row count is detected
            // instead of the shorter view silently ending the comparison.
            bool hasExpected = expectedCursor.MoveNext();
            bool hasActual = actualCursor.MoveNext();
            Assert.Equal(hasExpected, hasActual);
            if (!hasExpected)
                break;

            expectedGetter(ref expected);
            actualGetter(ref actual);

            Assert.Equal(expected.Length, actual.Length);
            for (int i = 0; i < expected.Length; ++i)
                Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i), precision);
        }
    }
}

private void CompareSelectedR4VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6)
{
var leftColumn = left.Schema[leftColumnName];
Expand Down
1 change: 1 addition & 0 deletions test/data/type-conversion.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3