Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ internal abstract class OnnxNode
public abstract void AddAttribute(string argName, string[] value);
public abstract void AddAttribute(string argName, IEnumerable<string> value);
public abstract void AddAttribute(string argName, IEnumerable<bool> value);
public abstract void AddAttribute(string argName, Type t);
}
}
19 changes: 7 additions & 12 deletions src/Microsoft.ML.Data/Transforms/TypeConverting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,13 @@ internal static bool GetNewType(IExceptionContext ectx, DataViewType srcType, In
return true;
}

private sealed class Mapper : OneToOneMapperBase, ICanSaveOnnx
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly TypeConvertingTransformer _parent;
private readonly DataViewType[] _types;
private readonly int[] _srcCols;

public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental;
public bool CanSaveOnnx(OnnxContext ctx) => true;

public Mapper(TypeConvertingTransformer parent, DataViewSchema inputSchema)
: base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
Expand Down Expand Up @@ -497,22 +497,17 @@ public void SaveAsOnnx(OnnxContext ctx)

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
var opType = "CSharp";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
node.AddAttribute("type", LoaderSignature);
node.AddAttribute("to", (byte)_parent._columns[iinfo].OutputKind);
if (_parent._columns[iinfo].OutputKeyCount != null)
{
var key = (KeyDataViewType)_types[iinfo].GetItemType();
node.AddAttribute("max", key.Count);
}
var opType = "Cast";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
var t = _parent._columns[iinfo].OutputKind.ToInternalDataKind().ToType();

@ganik ganik Aug 29, 2019

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

var t [](start = 16, length = 6)

Need to handle key type #Resolved

node.AddAttribute("to", t);
return true;
}
}
}

/// <summary>
/// Estimator for <see cref="KeyToVectorMappingTransformer"/>. Converts the underlying input column type to a new type.
/// Estimator for <see cref="TypeConvertingTransformer"/>. Converts the underlying input column type to a new type.
/// The input and output column types need to be compatible.
/// <see cref="PrimitiveDataViewType"/>
/// </summary>
Expand Down
8 changes: 7 additions & 1 deletion src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,17 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
var addNodeY = ctx.CreateNode("Add", new[] { nameZ, nameC2 }, new[] { nameY }, ctx.GetNodeName("Add"), "");

// Compute the most-matched cluster index, L
var nameL = outputNames[0];
var nameL = "ArgMinInt64";
var predictNodeL = ctx.CreateNode("ArgMin", nameY, nameL, ctx.GetNodeName("ArgMin"), "");
predictNodeL.AddAttribute("axis", 1);
predictNodeL.AddAttribute("keepdims", 1);

// ArgMin outputs an Int64. But ML.NET's KMeans trainer outputs a UINT32.
// Cast the output here to UInt32 to make them compatible
var predictedNode = ctx.CreateNode("Cast", nameL, outputNames[0], ctx.GetNodeName("Cast"), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
predictedNode.AddAttribute("to", t);

return true;
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.ML.OnnxConverter/OnnxNodeImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,7 @@ public override void AddAttribute(string argName, string value)
=> OnnxUtils.NodeAddAttributes(_node, argName, value);
public override void AddAttribute(string argName, bool value)
=> OnnxUtils.NodeAddAttributes(_node, argName, value);
public override void AddAttribute(string argName, Type value)
=> OnnxUtils.NodeAddAttributes(_node, argName, value);
}
}
79 changes: 48 additions & 31 deletions src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ private static AttributeProto MakeAttribute(string key)
return attribute;
}

private static AttributeProto MakeAttribute(string key, TensorProto.Types.DataType value)
{
AttributeProto attribute = MakeAttribute(key);
attribute.Type = AttributeProto.Types.AttributeType.Int;
attribute.I = (int)value;
return attribute;
}

private static AttributeProto MakeAttribute(string key, double value)
{
AttributeProto attribute = MakeAttribute(key);
Expand Down Expand Up @@ -211,6 +219,45 @@ public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable

public static void NodeAddAttributes(NodeProto node, string argName, bool value)
=> node.Attribute.Add(MakeAttribute(argName, value));
public static void NodeAddAttributes(NodeProto node, string argName, Type value)
=> node.Attribute.Add(MakeAttribute(argName, ConvertToTensorProtoType(value)));

private static TensorProto.Types.DataType ConvertToTensorProtoType(Type rawType)
{
var dataType = TensorProto.Types.DataType.Undefined;

if (rawType == typeof(bool))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(ReadOnlyMemory<char>))
dataType = TensorProto.Types.DataType.String;
else if (rawType == typeof(sbyte))
dataType = TensorProto.Types.DataType.Int8;
else if (rawType == typeof(byte))
dataType = TensorProto.Types.DataType.Uint8;
else if (rawType == typeof(short))
dataType = TensorProto.Types.DataType.Int16;
else if (rawType == typeof(ushort))
dataType = TensorProto.Types.DataType.Uint16;
else if (rawType == typeof(int))
dataType = TensorProto.Types.DataType.Int32;
else if (rawType == typeof(uint))
dataType = TensorProto.Types.DataType.Uint32;
else if (rawType == typeof(long))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(ulong))
dataType = TensorProto.Types.DataType.Uint64;
else if (rawType == typeof(float))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(double))
dataType = TensorProto.Types.DataType.Double;
else
{
string msg = "Unsupported type: " + rawType.ToString();
Contracts.Check(false, msg);
}

return dataType;
}

private static ByteString StringToByteString(ReadOnlyMemory<char> str) => ByteString.CopyFrom(Encoding.UTF8.GetBytes(str.ToString()));
private static IEnumerable<ByteString> StringToByteString(IEnumerable<ReadOnlyMemory<char>> str)
Expand Down Expand Up @@ -295,42 +342,12 @@ public static ModelArgs GetModelArgs(DataViewType type, string colName,
Contracts.CheckValue(type, nameof(type));
Contracts.CheckNonEmpty(colName, nameof(colName));

TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
Type rawType;
if (type is VectorDataViewType vectorType)
rawType = vectorType.ItemType.RawType;
else
rawType = type.RawType;

if (rawType == typeof(bool))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(ReadOnlyMemory<char>))
dataType = TensorProto.Types.DataType.String;
else if (rawType == typeof(sbyte))
dataType = TensorProto.Types.DataType.Int8;
else if (rawType == typeof(byte))
dataType = TensorProto.Types.DataType.Uint8;
else if (rawType == typeof(short))
dataType = TensorProto.Types.DataType.Int16;
else if (rawType == typeof(ushort))
dataType = TensorProto.Types.DataType.Uint16;
else if (rawType == typeof(int))
dataType = TensorProto.Types.DataType.Int32;
else if (rawType == typeof(uint))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(long))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(ulong))
dataType = TensorProto.Types.DataType.Uint64;
else if (rawType == typeof(float))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(double))
dataType = TensorProto.Types.DataType.Double;
else
{
string msg = "Unsupported type: " + type.ToString();
Contracts.Check(false, msg);
}
var dataType = ConvertToTensorProtoType(rawType);

string name = colName;
List<long> dimsLocal = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@
"name": "F20",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@
"name": "F20",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down
21 changes: 19 additions & 2 deletions test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
"Score"
],
"output": [
"PredictedLabel"
"ArgMinInt64"
],
"name": "ArgMin",
"opType": "ArgMin",
Expand All @@ -126,6 +126,23 @@
}
]
},
{
"input": [
"ArgMinInt64"
],
"output": [
"PredictedLabel"
],
"name": "Cast",
"opType": "Cast",
"attribute": [
{
"name": "to",
"i": "12",
"type": "INT"
}
]
},
{
"input": [
"Features0"
Expand Down Expand Up @@ -272,7 +289,7 @@
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@
"name": "Label1",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down Expand Up @@ -348,7 +348,7 @@
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down Expand Up @@ -404,7 +404,7 @@
"name": "Label0",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down
Loading