Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions src/Microsoft.ML.Data/Transforms/KeyToValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -510,20 +510,40 @@ public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string ds
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType();
castNode.AddAttribute("to", t);

var labelEncoderOutput = dstVariableName;
var labelEncoderInput = srcVariableName;
if (TypeOutput == NumberDataViewType.Double || TypeOutput == NumberDataViewType.Int64)
labelEncoderOutput = ctx.AddIntermediateVariable(TypeOutput, "CastNodeOutput", true);

opType = "LabelEncoder";
var node = ctx.CreateNode(opType, castNodeOutput, dstVariableName, ctx.GetNodeName(opType));
var node = ctx.CreateNode(opType, castNodeOutput, labelEncoderOutput, ctx.GetNodeName(opType));
var keys = Array.ConvertAll<int, long>(Enumerable.Range(1, _values.Length).ToArray(), item => Convert.ToInt64(item));
node.AddAttribute("keys_int64s", keys);

if (TypeOutput == NumberDataViewType.Int64)
{
long[] values = Array.ConvertAll<TValue, long>(_values.GetValues().ToArray(), item => Convert.ToInt64(item));
node.AddAttribute("values_int64s", values);
// LabelEncoder doesn't support mapping int64 -> int64, so values are converted to strings and later cast back to Int64s
string[] values = Array.ConvertAll<TValue, string>(_values.GetValues().ToArray(), item => Convert.ToString(item));
node.AddAttribute("values_strings", values);
opType = "Cast";
castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType();
castNode.AddAttribute("to", t);
}
else if (TypeOutput == NumberDataViewType.Single)
{
float[] values = Array.ConvertAll<TValue, float>(_values.GetValues().ToArray(), item => Convert.ToSingle(item));
node.AddAttribute("values_floats", values);
}
else if (TypeOutput == NumberDataViewType.Double || TypeOutput == NumberDataViewType.Single)
else if (TypeOutput == NumberDataViewType.Double)
{
// LabelEncoder doesn't support double tensors, so values are converted to floats and later cast back to doubles
float[] values = Array.ConvertAll<TValue, float>(_values.GetValues().ToArray(), item => Convert.ToSingle(item));
node.AddAttribute("values_floats", values);
opType = "Cast";
castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Double).ToType();
castNode.AddAttribute("to", t);
}
else if (TypeOutput == TextDataViewType.Instance)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
}
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Double))
{
// LabelEncoder doesn't support double tensors, so values are cast to floats
var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true);
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
Expand All @@ -813,6 +814,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
}
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Int64))
{
// LabelEncoder doesn't support mapping int64 -> int64, so values are cast to strings
var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true);
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType();
Expand Down
11 changes: 7 additions & 4 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,7 @@ public void IndicateMissingValuesOnnxConversionTest()
[InlineData(DataKind.Int64)]
[InlineData(DataKind.Double)]
[InlineData(DataKind.String)]
public void ValueToKeyMappingOnnxConversionTest(DataKind valueType)
public void ValueToKeyandKeyToValueMappingOnnxConversionTest(DataKind valueType)
{
var mlContext = new MLContext(seed: 1);
string filePath = GetDataPath("type-conversion.txt");
Expand All @@ -1106,7 +1106,8 @@ public void ValueToKeyMappingOnnxConversionTest(DataKind valueType)
};
var dataView = mlContext.Data.LoadFromTextFile(filePath, columns);

var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Key", "Value");
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Key", "Value").
Append(mlContext.Transforms.Conversion.MapKeyToValue("ValueOutput", "Key"));
var model = pipeline.Fit(dataView);
var mlnetResult = model.Transform(dataView);

Expand All @@ -1123,9 +1124,9 @@ public void ValueToKeyMappingOnnxConversionTest(DataKind valueType)
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);

CompareSelectedVectorColumns<UInt32>(model.ColumnPairs[0].outputColumnName, outputNames[1], mlnetResult, onnxResult);
CompareResults(mlnetResult.Schema[2].Name, outputNames[2], mlnetResult, onnxResult); //compare output values
CompareSelectedVectorColumns<UInt32>(mlnetResult.Schema[1].Name, outputNames[1], mlnetResult, onnxResult); //compare keys
}

Done();
}

Expand Down Expand Up @@ -1555,6 +1556,8 @@ private void CompareResults(string leftColumnName, string rightColumnName, IData
CompareSelectedR4VectorColumns(leftColumnName, rightColumnName, left, right);
else if (leftType == NumberDataViewType.Double)
CompareSelectedVectorColumns<double>(leftColumnName, rightColumnName, left, right);
else if (leftType == TextDataViewType.Instance)
CompareSelectedVectorColumns<ReadOnlyMemory<char>>(leftColumnName, rightColumnName, left, right);
}

private void CompareSelectedVectorColumns<T>(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
Expand Down