diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index c1863a32cf..166f14f3eb 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -510,20 +510,40 @@ public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string ds var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType(); castNode.AddAttribute("to", t); + var labelEncoderOutput = dstVariableName; + var labelEncoderInput = srcVariableName; + if (TypeOutput == NumberDataViewType.Double || TypeOutput == NumberDataViewType.Int64) + labelEncoderOutput = ctx.AddIntermediateVariable(TypeOutput, "CastNodeOutput", true); + opType = "LabelEncoder"; - var node = ctx.CreateNode(opType, castNodeOutput, dstVariableName, ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, castNodeOutput, labelEncoderOutput, ctx.GetNodeName(opType)); var keys = Array.ConvertAll(Enumerable.Range(1, _values.Length).ToArray(), item => Convert.ToInt64(item)); node.AddAttribute("keys_int64s", keys); if (TypeOutput == NumberDataViewType.Int64) { - long[] values = Array.ConvertAll(_values.GetValues().ToArray(), item => Convert.ToInt64(item)); - node.AddAttribute("values_int64s", values); + // LabelEncoder doesn't support mapping int64 -> int64, so values are converted to strings and later cast back to Int64s + string[] values = Array.ConvertAll(_values.GetValues().ToArray(), item => Convert.ToString(item)); + node.AddAttribute("values_strings", values); + opType = "Cast"; + castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), ""); + t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType(); + castNode.AddAttribute("to", t); + } + else if (TypeOutput == NumberDataViewType.Single) + { + float[] values = Array.ConvertAll(_values.GetValues().ToArray(), item => Convert.ToSingle(item)); + node.AddAttribute("values_floats", values); } - else if (TypeOutput == NumberDataViewType.Double || TypeOutput == NumberDataViewType.Single) + else if (TypeOutput == NumberDataViewType.Double) { + // LabelEncoder doesn't support double tensors, so values are converted to floats and later cast back to doubles float[] values = Array.ConvertAll(_values.GetValues().ToArray(), item => Convert.ToSingle(item)); node.AddAttribute("values_floats", values); + opType = "Cast"; + castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), ""); + t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Double).ToType(); + castNode.AddAttribute("to", t); } else if (TypeOutput == TextDataViewType.Instance) { diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 4f39d5af0b..f8cea9e8af 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -803,6 +803,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src } else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Double)) { + // LabelEncoder doesn't support double tensors, so values are cast to floats var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true); castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), ""); var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); @@ -813,6 +814,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src } else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Int64)) { + // LabelEncoder doesn't support mapping int64 -> int64, so values are cast to strings var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true); castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), ""); var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType(); diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 7fd42545fb..e965be2f6a 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1095,7 +1095,7 @@ public void IndicateMissingValuesOnnxConversionTest() [InlineData(DataKind.Int64)] [InlineData(DataKind.Double)] [InlineData(DataKind.String)] - public void ValueToKeyMappingOnnxConversionTest(DataKind valueType) + public void ValueToKeyandKeyToValueMappingOnnxConversionTest(DataKind valueType) { var mlContext = new MLContext(seed: 1); string filePath = GetDataPath("type-conversion.txt"); @@ -1106,7 +1106,8 @@ public void ValueToKeyMappingOnnxConversionTest(DataKind valueType) }; var dataView = mlContext.Data.LoadFromTextFile(filePath, columns); - var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Key", "Value"); + var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Key", "Value"). + Append(mlContext.Transforms.Conversion.MapKeyToValue("ValueOutput", "Key")); var model = pipeline.Fit(dataView); var mlnetResult = model.Transform(dataView); @@ -1123,9 +1124,9 @@ public void ValueToKeyMappingOnnxConversionTest(DataKind valueType) var onnxTransformer = onnxEstimator.Fit(dataView); var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedVectorColumns(model.ColumnPairs[0].outputColumnName, outputNames[1], mlnetResult, onnxResult); + CompareResults(mlnetResult.Schema[2].Name, outputNames[2], mlnetResult, onnxResult); //compare output values + CompareSelectedVectorColumns(mlnetResult.Schema[1].Name, outputNames[1], mlnetResult, onnxResult); //compare keys } - Done(); } @@ -1555,6 +1556,8 @@ private void CompareResults(string leftColumnName, string rightColumnName, IData CompareSelectedR4VectorColumns(leftColumnName, rightColumnName, left, right); else if (leftType == NumberDataViewType.Double) CompareSelectedVectorColumns(leftColumnName, rightColumnName, left, right); + else if (leftType == TextDataViewType.Instance) + CompareSelectedVectorColumns>(leftColumnName, rightColumnName, left, right); } private void CompareSelectedVectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right)