-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Adding KeyToValueTransformers before finishing exporting to ONNX #4841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
1a826dc
009a84f
f11c197
3a972f5
3e432a1
11b84d0
a786d9f
cfa264d
86fb14a
afec6d6
3e5b2bd
28e447d
8864d8e
d382373
ff9650e
953ad7b
9c63f1a
d70843f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| using Microsoft.ML.Internal.Utilities; | ||
| using Microsoft.ML.Model.OnnxConverter; | ||
| using Microsoft.ML.Runtime; | ||
| using Microsoft.ML.Transforms; | ||
| using Newtonsoft.Json; | ||
| using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper; | ||
|
|
||
|
|
@@ -241,6 +242,21 @@ private static void AddSlotNames(OnnxContextImpl ctx, DataViewSchema.Column colu | |
| ctx.AddOutputVariable(NumberDataViewType.Int64, labelEncoderOutput); | ||
| } | ||
|
|
||
| // Checks if a column has KeyValues Annotations of any type, | ||
| // So to know if it is safe to use KeyToValue Transformer on it. | ||
| private bool HasKeyValues(DataViewSchema.Column column) | ||
| { | ||
| if (column.Type.GetItemType() is KeyDataViewType keyType) | ||
| { | ||
| var metaColumn = column.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues); | ||
| return metaColumn != null && | ||
| metaColumn.Value.Type is VectorDataViewType vectorType && | ||
| keyType.Count == (ulong)vectorType.Size; | ||
| } | ||
|
|
||
| return false; | ||
| } | ||
|
|
||
| private void Run(IChannel ch) | ||
| { | ||
| ILegacyDataLoader loader = null; | ||
|
|
@@ -308,6 +324,20 @@ private void Run(IChannel ch) | |
| Host.Assert(scorePipe.Source == end); | ||
| end = scorePipe; | ||
| transforms.AddLast(scoreOnnx); | ||
|
|
||
| if(rawPred.PredictionKind == PredictionKind.BinaryClassification || rawPred.PredictionKind == PredictionKind.MulticlassClassification) | ||
| { | ||
| // Check if the PredictedLabel Column is a KeyDataViewType and has KeyValue Annotations. | ||
| // If it does, add a KeyToValueMappingTransformer, to enable NimbusML to get the values back | ||
| // when using an ONNX model, as described in https://github.com/dotnet/machinelearning/pull/4841 | ||
| var predictedLabelColumn = scorePipe.Schema.GetColumnOrNull(DefaultColumnNames.PredictedLabel); | ||
| if (predictedLabelColumn.HasValue && HasKeyValues(predictedLabelColumn.Value)) | ||
| { | ||
| var outputData = new KeyToValueMappingTransformer(Host, DefaultColumnNames.PredictedLabel).Transform(scorePipe); | ||
| end = outputData; | ||
| transforms.AddLast(outputData as ITransformCanSaveOnnx); | ||
| } | ||
| } | ||
| } | ||
| else | ||
| { | ||
|
|
@@ -322,6 +352,30 @@ private void Run(IChannel ch) | |
| nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present."); | ||
| } | ||
|
|
||
| // Convert back to values the KeyDataViewType "pass-through" columns | ||
| // (i.e those that remained untouched by the model). This is done to enable NimbusML to get these values | ||
| // as described in https://github.com/dotnet/machinelearning/pull/4841 | ||
|
|
||
| var inputKeyDataViewTypeColumnsNames = new HashSet<string>(); | ||
| foreach (var col in source.Schema) | ||
| if (col.IsHidden == false && HasKeyValues(col)) | ||
| inputKeyDataViewTypeColumnsNames.Add(col.Name); | ||
|
|
||
| var passThroughColumnNames = new HashSet<string>(); | ||
| var onlyDistinctColumnNames = end.Schema.Select(c => c.Name).Distinct(); // only check column names that appear once in the output schema | ||
| foreach (var col in end.Schema) | ||
| if (col.IsHidden == false && onlyDistinctColumnNames.Contains(col.Name) && HasKeyValues(col)) | ||
| passThroughColumnNames.Add(col.Name); | ||
|
|
||
| passThroughColumnNames.IntersectWith(inputKeyDataViewTypeColumnsNames); // Only count those columns that were in the input of the pipeline | ||
|
|
||
| foreach (var name in passThroughColumnNames) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Here is another edge case: If the input has a categorical column, say There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar issue with Concat: say the input has two categorical columns In reply to: 383580868 [](ancestors = 383580868)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agree, that is a bit strange side-effect. Perhaps in next iterations we can get more details on how AutoML prefers this. Maybe we wont need to add KeyToValue for passthrough columns at all and let them surface as keys. For now we can document this. |
||
| { | ||
| var outputData = new KeyToValueMappingTransformer(Host, name).Transform(end); | ||
| end = outputData; | ||
| transforms.AddLast(end as ITransformCanSaveOnnx); | ||
| } | ||
|
|
||
| var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop); | ||
|
|
||
| using (var file = Host.CreateOutputFile(_outputModelPath)) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.