-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Added categorical value support for PredictedLabel for the Image Classification Transfer Learning example. Addresses Issue #4169 #4228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
8d07401
d80e23d
7b239b0
33e71ae
2b1ab69
7254a7e
1a43cc7
d39f98b
8cf2b7c
c4182fb
d490383
01c77c9
380d8bf
2af03ca
2c05ba8
6356665
ed928c6
f7f8253
bd42f0a
f851791
643fb58
b6cfeda
cca4fb8
1c4c5dc
2301a96
26e2ae1
94c0423
109879b
6ff4c64
15c4654
0573ef3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,9 +46,13 @@ public static void Example() | |
| IDataView shuffledFullImagesDataset = mlContext.Data.ShuffleRows( | ||
| mlContext.Data.LoadFromEnumerable(images)); | ||
|
|
||
| shuffledFullImagesDataset = mlContext.Transforms.Conversion | ||
| .MapValueToKey("Label") | ||
| .Fit(shuffledFullImagesDataset) | ||
| var estimator = mlContext.Transforms.Conversion | ||
| .MapValueToKey("Label"); | ||
| var estimatorWithKeyType = estimator.Append( | ||
| mlContext.Transforms.Conversion.MapKeyToValue( | ||
| outputColumnName: "LabelAsKey", inputColumnName: "Label")); | ||
|
|
||
| shuffledFullImagesDataset = estimatorWithKeyType.Fit(shuffledFullImagesDataset) | ||
| .Transform(shuffledFullImagesDataset); | ||
|
|
||
| // Split the data 90:10 into train and test sets, train and evaluate. | ||
|
|
@@ -93,15 +97,16 @@ public static void Example() | |
| DataViewSchema schema; | ||
| using (var file = File.OpenRead("model.zip")) | ||
| loadedModel = mlContext.Model.Load(file, out schema); | ||
| // the schema in line 99 and the shuffledFullImagesDataset.Schema in line 93 don't have | ||
| // the same annotations. | ||
|
mstfbl marked this conversation as resolved.
Outdated
|
||
|
|
||
| EvaluateModel(mlContext, testDataset, loadedModel); | ||
|
|
||
| VBuffer<ReadOnlyMemory<char>> keys = default; | ||
| loadedModel.GetOutputSchema(schema)["Label"].GetKeyValues(ref keys); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't need to do all this if you convert predicted label column from Key to Value. It will output string type that will contain class names that corresponds to key indices. #Resolved |
||
|
|
||
| watch = System.Diagnostics.Stopwatch.StartNew(); | ||
| TrySinglePrediction(fullImagesetFolderPath, mlContext, loadedModel, | ||
| keys.DenseValues().ToArray()); | ||
| TrySinglePrediction(fullImagesetFolderPath, mlContext, loadedModel, keys.DenseValues().ToArray(), shuffledFullImagesDataset.Schema); | ||
|
|
||
| watch.Stop(); | ||
| elapsedMs = watch.ElapsedMilliseconds; | ||
|
|
@@ -119,8 +124,7 @@ public static void Example() | |
| } | ||
|
|
||
| private static void TrySinglePrediction(string imagesForPredictions, | ||
| MLContext mlContext, ITransformer trainedModel, | ||
| ReadOnlyMemory<char>[] originalLabels) | ||
| MLContext mlContext, ITransformer trainedModel, ReadOnlyMemory<char>[] originalLabels, DataViewSchema schema) | ||
| { | ||
| // Create prediction function to try one prediction | ||
| var predictionEngine = mlContext.Model | ||
|
|
@@ -135,6 +139,7 @@ private static void TrySinglePrediction(string imagesForPredictions, | |
| }; | ||
|
|
||
| var prediction = predictionEngine.Predict(imageToPredict); | ||
| var predictedLabelsKeyType = ((DataViewSchema.Column)schema.GetColumnOrNull("Label")).Annotations; | ||
| var index = prediction.PredictedLabel; | ||
|
mstfbl marked this conversation as resolved.
Outdated
|
||
|
|
||
| Console.WriteLine($"ImageFile : " + | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -878,7 +878,7 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() | |
| { | ||
| var info = new DataViewSchema.DetachedColumn[_parent._outputs.Length]; | ||
| info[0] = new DataViewSchema.DetachedColumn(_parent._outputs[0], new VectorDataViewType(NumberDataViewType.Single, _parent._classCount), null); | ||
| info[1] = new DataViewSchema.DetachedColumn(_parent._outputs[1], NumberDataViewType.UInt32, null); | ||
| info[1] = new DataViewSchema.DetachedColumn(_parent._outputs[1], new KeyDataViewType(typeof(uint), _parent._classCount), ((DataViewSchema.Column)InputSchema.GetColumnOrNull("Label")).Annotations); | ||
|
mstfbl marked this conversation as resolved.
Outdated
|
||
| return info; | ||
| } | ||
| } | ||
|
|
@@ -1166,7 +1166,7 @@ internal ImageClassificationEstimator(IHostEnvironment env, Options options, Dnn | |
| _options = options; | ||
| _dnnModel = dnnModel; | ||
| _tfInputTypes = new[] { TF_DataType.TF_STRING }; | ||
| _outputTypes = new[] { new VectorDataViewType(NumberDataViewType.Single), NumberDataViewType.UInt32.GetItemType() }; | ||
| _outputTypes = new DataViewType[] { new VectorDataViewType(NumberDataViewType.Single), new KeyDataViewType(typeof(uint), 5) }; | ||
|
mstfbl marked this conversation as resolved.
Outdated
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a good catch, thank you! Quick question, the Image Classification example ResnetV2101TransferLearningTrainTestSplit.cs runs perfectly well when I don't further define _outputTypes, i,e, delete line 69. Does this also mean than the estimator isn't using this field? #Resolved
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, I've found that the DataViewType field is needed for the pipeline in ResnetV2101TransferLearningTrainTestSplit.cs to fit properly. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you elaborate? The only place In reply to: 326329163 [](ancestors = 326329163)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was giving a different error when I removed _outputTypes, but as I now see that was not the bottleneck problem I was having while implementing this KeyType solution. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The field In reply to: 326229209 [](ancestors = 326229209)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I missed this, thank you for the catch! #Resolved |
||
| } | ||
|
|
||
| private static Options CreateArguments(DnnModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.