diff --git a/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs b/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs index d29464d5fc..bde5bcd373 100644 --- a/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs @@ -15,6 +15,7 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; using Newtonsoft.Json; using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper; @@ -241,6 +242,52 @@ private static void AddSlotNames(OnnxContextImpl ctx, DataViewSchema.Column colu ctx.AddOutputVariable(NumberDataViewType.Int64, labelEncoderOutput); } + // Checks if a column has KeyValues Annotations of any type, + // So to know if it is safe to use KeyToValue Transformer on it. + private bool HasKeyValues(DataViewSchema.Column column) + { + if (column.Type.GetItemType() is KeyDataViewType keyType) + { + var metaColumn = column.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues); + return metaColumn != null && + metaColumn.Value.Type is VectorDataViewType vectorType && + keyType.Count == (ulong)vectorType.Size; + } + + return false; + } + + // Get the names of the KeyDataViewType columns that aren't affected by the pipeline that is being exported to ONNX. + private HashSet GetPassThroughKeyDataViewTypeColumnsNames(IDataView source, IDataView end) + { + var inputKeyDataViewTypeColumnsNames = new HashSet(); + foreach (var col in source.Schema) + if (col.IsHidden == false && HasKeyValues(col)) + inputKeyDataViewTypeColumnsNames.Add(col.Name); + + var passThroughColumnNames = new HashSet(); + var outputColumnNames = new HashSet(); + foreach (var col in end.Schema) + { + if (outputColumnNames.Contains(col.Name)) + { + // "Pass through" column names appear only once in the output schema + passThroughColumnNames.Remove(col.Name); + } + else + { + // We are only interested in the KeyDataViewType outpus columns + if (col.IsHidden == false && HasKeyValues(col)) + passThroughColumnNames.Add(col.Name); + } + outputColumnNames.Add(col.Name); + } + + // Only count those columns that were in the input of the pipeline + passThroughColumnNames.IntersectWith(inputKeyDataViewTypeColumnsNames); + return passThroughColumnNames; + } + private void Run(IChannel ch) { ILegacyDataLoader loader = null; @@ -308,6 +355,20 @@ private void Run(IChannel ch) Host.Assert(scorePipe.Source == end); end = scorePipe; transforms.AddLast(scoreOnnx); + + if(rawPred.PredictionKind == PredictionKind.BinaryClassification || rawPred.PredictionKind == PredictionKind.MulticlassClassification) + { + // Check if the PredictedLabel Column is a KeyDataViewType and has KeyValue Annotations. + // If it does, add a KeyToValueMappingTransformer, to enable NimbusML to get the values back + // when using an ONNX model, as described in https://github.com/dotnet/machinelearning/pull/4841 + var predictedLabelColumn = scorePipe.Schema.GetColumnOrNull(DefaultColumnNames.PredictedLabel); + if (predictedLabelColumn.HasValue && HasKeyValues(predictedLabelColumn.Value)) + { + var outputData = new KeyToValueMappingTransformer(Host, DefaultColumnNames.PredictedLabel).Transform(scorePipe); + end = outputData; + transforms.AddLast(outputData as ITransformCanSaveOnnx); + } + } } else { @@ -322,6 +383,18 @@ private void Run(IChannel ch) nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present."); } + // Convert back to values the KeyDataViewType "pass-through" columns + // (i.e those that remained untouched by the model). This is done to enable NimbusML to get these values + // as described in https://github.com/dotnet/machinelearning/pull/4841 + + var passThroughColumnNames = GetPassThroughKeyDataViewTypeColumnsNames(source, end); + foreach (var name in passThroughColumnNames) + { + var outputData = new KeyToValueMappingTransformer(Host, name).Transform(end); + end = outputData; + transforms.AddLast(end as ITransformCanSaveOnnx); + } + var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop); using (var file = Host.CreateOutputFile(_outputModelPath))