Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ internal abstract class OnnxNode
public abstract void AddAttribute(string argName, string[] value);
public abstract void AddAttribute(string argName, IEnumerable<string> value);
public abstract void AddAttribute(string argName, IEnumerable<bool> value);
public abstract void AddAttribute(string argName, Type t);
}
}
11 changes: 3 additions & 8 deletions src/Microsoft.ML.Data/Transforms/TypeConverting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -497,15 +497,10 @@ public void SaveAsOnnx(OnnxContext ctx)

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
var opType = "CSharp";
var opType = "Cast";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
node.AddAttribute("type", LoaderSignature);
node.AddAttribute("to", (byte)_parent._columns[iinfo].OutputKind);
if (_parent._columns[iinfo].OutputKeyCount != null)
{
var key = (KeyDataViewType)_types[iinfo].GetItemType();
node.AddAttribute("max", key.Count);
}
var t = _parent._columns[iinfo].OutputKind.ToInternalDataKind().ToType();

@ganik ganik Aug 29, 2019

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

var t [](start = 16, length = 6)

Need to handle key type #Resolved

node.AddAttribute("to", t);
return true;
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.ML.OnnxConverter/OnnxNodeImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,7 @@ public override void AddAttribute(string argName, string value)
=> OnnxUtils.NodeAddAttributes(_node, argName, value);
public override void AddAttribute(string argName, bool value)
=> OnnxUtils.NodeAddAttributes(_node, argName, value);
public override void AddAttribute(string argName, Type value)
=> OnnxUtils.NodeAddAttributes(_node, argName, value);
}
}
79 changes: 48 additions & 31 deletions src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ private static AttributeProto MakeAttribute(string key)
return attribute;
}

private static AttributeProto MakeAttribute(string key, TensorProto.Types.DataType value)
{
AttributeProto attribute = MakeAttribute(key);
attribute.Type = AttributeProto.Types.AttributeType.Int;
attribute.I = (int)value;
return attribute;
}

private static AttributeProto MakeAttribute(string key, double value)
{
AttributeProto attribute = MakeAttribute(key);
Expand Down Expand Up @@ -211,6 +219,45 @@ public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable

public static void NodeAddAttributes(NodeProto node, string argName, bool value)
=> node.Attribute.Add(MakeAttribute(argName, value));
public static void NodeAddAttributes(NodeProto node, string argName, Type value)
=> node.Attribute.Add(MakeAttribute(argName, ConvertToTensorProtoType(value)));

private static TensorProto.Types.DataType ConvertToTensorProtoType(Type rawType)
{
var dataType = TensorProto.Types.DataType.Undefined;

if (rawType == typeof(bool))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(ReadOnlyMemory<char>))
dataType = TensorProto.Types.DataType.String;
else if (rawType == typeof(sbyte))
dataType = TensorProto.Types.DataType.Int8;
else if (rawType == typeof(byte))
dataType = TensorProto.Types.DataType.Uint8;
else if (rawType == typeof(short))
dataType = TensorProto.Types.DataType.Int16;
else if (rawType == typeof(ushort))
dataType = TensorProto.Types.DataType.Uint16;
else if (rawType == typeof(int))
dataType = TensorProto.Types.DataType.Int32;
else if (rawType == typeof(uint))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(long))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(ulong))
dataType = TensorProto.Types.DataType.Uint64;
else if (rawType == typeof(float))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(double))
dataType = TensorProto.Types.DataType.Double;
else
{
string msg = "Unsupported type: " + rawType.ToString();
Contracts.Check(false, msg);
}

return dataType;
}

private static ByteString StringToByteString(ReadOnlyMemory<char> str) => ByteString.CopyFrom(Encoding.UTF8.GetBytes(str.ToString()));
private static IEnumerable<ByteString> StringToByteString(IEnumerable<ReadOnlyMemory<char>> str)
Expand Down Expand Up @@ -295,42 +342,12 @@ public static ModelArgs GetModelArgs(DataViewType type, string colName,
Contracts.CheckValue(type, nameof(type));
Contracts.CheckNonEmpty(colName, nameof(colName));

TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
Type rawType;
if (type is VectorDataViewType vectorType)
rawType = vectorType.ItemType.RawType;
else
rawType = type.RawType;

if (rawType == typeof(bool))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(ReadOnlyMemory<char>))
dataType = TensorProto.Types.DataType.String;
else if (rawType == typeof(sbyte))
dataType = TensorProto.Types.DataType.Int8;
else if (rawType == typeof(byte))
dataType = TensorProto.Types.DataType.Uint8;
else if (rawType == typeof(short))
dataType = TensorProto.Types.DataType.Int16;
else if (rawType == typeof(ushort))
dataType = TensorProto.Types.DataType.Uint16;
else if (rawType == typeof(int))
dataType = TensorProto.Types.DataType.Int32;
else if (rawType == typeof(uint))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(long))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(ulong))
dataType = TensorProto.Types.DataType.Uint64;
else if (rawType == typeof(float))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(double))
dataType = TensorProto.Types.DataType.Double;
else
{
string msg = "Unsupported type: " + type.ToString();
Contracts.Check(false, msg);
}
var dataType = ConvertToTensorProtoType(rawType);

string name = colName;
List<long> dimsLocal = null;
Expand Down