Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/KeyToValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string ds
long[] values = Array.ConvertAll<TValue, long>(_values.GetValues().ToArray(), item => Convert.ToInt64(item));
node.AddAttribute("values_int64s", values);
}
else if (TypeOutput == NumberDataViewType.Single)
else if (TypeOutput == NumberDataViewType.Double || TypeOutput == NumberDataViewType.Single)
{
float[] values = Array.ConvertAll<TValue, float>(_values.GetValues().ToArray(), item => Convert.ToSingle(item));
node.AddAttribute("values_floats", values);
Expand Down
23 changes: 22 additions & 1 deletion src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
OnnxNode node;
long[] termIds;
string opType = "LabelEncoder";
OnnxNode castNode;
var labelEncoderOutput = ctx.AddIntermediateVariable(_types[iinfo], "LabelEncoderOutput", true);

if (info.TypeSrc.GetItemType().Equals(TextDataViewType.Instance))
Expand All @@ -800,6 +801,26 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
var terms = GetTermsAndIds<float>(iinfo, out termIds);
node.AddAttribute("keys_floats", terms);
}
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Double))
{
var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true);
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), "");
var t2 = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
castNode.AddAttribute("to", t2);
node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
var terms = GetTermsAndIds<double>(iinfo, out termIds);
node.AddAttribute("keys_floats", terms);
}
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Int64))
Comment thread
kere-nel marked this conversation as resolved.
{
var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true);
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType();
castNode.AddAttribute("to", t);
node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
var terms = GetTermsAndIds<long>(iinfo, out termIds);
node.AddAttribute("keys_strings", terms.Select(item => item.ToString()));
}
else
Comment thread
kere-nel marked this conversation as resolved.
{
// LabelEncoder-2 in ORT v1 only supports the following mappings
Expand All @@ -822,7 +843,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
InternalDataKindExtensions.TryGetDataKind(_parent._unboundMaps[iinfo].OutputType.RawType, out dataKind);

opType = "Cast";
var castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
castNode.AddAttribute("to", dataKind.ToType());

return true;
Expand Down