-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Added categorical value support for PredictedLabel for the Image Classification Transfer Learning example. Addresses Issue #4169 #4228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8d07401
d80e23d
7b239b0
33e71ae
2b1ab69
7254a7e
1a43cc7
d39f98b
8cf2b7c
c4182fb
d490383
01c77c9
380d8bf
2af03ca
2c05ba8
6356665
ed928c6
f7f8253
bd42f0a
f851791
643fb58
b6cfeda
cca4fb8
1c4c5dc
2301a96
26e2ae1
94c0423
109879b
6ff4c64
15c4654
0573ef3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,6 +64,7 @@ public sealed class ImageClassificationTransformer : RowToRowTransformerBase | |
| private Graph Graph => _session.graph; | ||
| private readonly string[] _inputs; | ||
| private readonly string[] _outputs; | ||
| private ReadOnlyMemory<char>[] _keyValueAnnotations; | ||
| private readonly string _labelColumnName; | ||
| private readonly string _finalModelPrefix; | ||
| private readonly Architecture _arch; | ||
|
|
@@ -105,11 +106,24 @@ private static ImageClassificationTransformer Create(IHostEnvironment env, Model | |
| // int: number of output columns | ||
| // for each output column | ||
| // int: id of output column name | ||
| // stream: tensorFlow model. | ||
| // string: value of label column name | ||
| // string: prefix of final model and checkpoint files/folder for storing graph files | ||
| // int: value of the utilized model architecture for transfer learning | ||
| // string: value of score column name | ||
| // string: value of predicted label column name | ||
| // float: value of learning rate | ||
| // int: number of prediction classes | ||
| // for each key value annotation column | ||
| // string: value of key value annotations | ||
| // string: name of prediction tensor | ||
| // string: name of softmax tensor | ||
| // string: name of JPEG data tensor | ||
| // string: name of resized image tensor | ||
| // stream (byte): tensorFlow model. | ||
|
|
||
| GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool addBatchDimensionInput, | ||
| out string labelColumn, out string checkpointName, out Architecture arch, out string scoreColumnName, | ||
| out string predictedColumnName, out float learningRate, out int classCount, out string predictionTensorName, out string softMaxTensorName, | ||
| out string predictedColumnName, out float learningRate, out int classCount, out string[] keyValueAnnotations, out string predictionTensorName, out string softMaxTensorName, | ||
| out string jpegDataTensorName, out string resizeTensorName); | ||
|
|
||
| byte[] modelBytes = null; | ||
|
|
@@ -119,7 +133,7 @@ private static ImageClassificationTransformer Create(IHostEnvironment env, Model | |
| return new ImageClassificationTransformer(env, DnnUtils.LoadTFSession(env, modelBytes), outputs, inputs, | ||
| null, addBatchDimensionInput, 1, labelColumn, checkpointName, arch, | ||
| scoreColumnName, predictedColumnName, learningRate, null, classCount, true, predictionTensorName, | ||
| softMaxTensorName, jpegDataTensorName, resizeTensorName); | ||
| softMaxTensorName, jpegDataTensorName, resizeTensorName, keyValueAnnotations); | ||
|
|
||
| } | ||
|
|
||
|
|
@@ -628,7 +642,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat | |
| private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, | ||
| out string[] outputs, out bool addBatchDimensionInput, | ||
| out string labelColumn, out string checkpointName, out Architecture arch, | ||
| out string scoreColumnName, out string predictedColumnName, out float learningRate, out int classCount, out string predictionTensorName, out string softMaxTensorName, | ||
| out string scoreColumnName, out string predictedColumnName, out float learningRate, out int classCount, out string[] keyValueAnnotations, out string predictionTensorName, out string softMaxTensorName, | ||
| out string jpegDataTensorName, out string resizeTensorName) | ||
| { | ||
| addBatchDimensionInput = ctx.Reader.ReadBoolByte(); | ||
|
|
@@ -652,6 +666,12 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out | |
| predictedColumnName = ctx.Reader.ReadString(); | ||
| learningRate = ctx.Reader.ReadFloat(); | ||
| classCount = ctx.Reader.ReadInt32(); | ||
|
|
||
| env.CheckDecode(classCount > 0); | ||
| keyValueAnnotations = new string[classCount]; | ||
| for (int j = 0; j < keyValueAnnotations.Length; j++) | ||
| keyValueAnnotations[j] = ctx.LoadNonEmptyString(); | ||
|
|
||
| predictionTensorName = ctx.Reader.ReadString(); | ||
| softMaxTensorName = ctx.Reader.ReadString(); | ||
| jpegDataTensorName = ctx.Reader.ReadString(); | ||
|
|
@@ -662,7 +682,7 @@ internal ImageClassificationTransformer(IHostEnvironment env, Session session, s | |
| string[] inputColumnNames, string modelLocation, | ||
| bool? addBatchDimensionInput, int batchSize, string labelColumnName, string finalModelPrefix, Architecture arch, | ||
| string scoreColumnName, string predictedLabelColumnName, float learningRate, DataViewSchema inputSchema, int? classCount = null, bool loadModel = false, | ||
| string predictionTensorName = null, string softMaxTensorName = null, string jpegDataTensorName = null, string resizeTensorName = null) | ||
| string predictionTensorName = null, string softMaxTensorName = null, string jpegDataTensorName = null, string resizeTensorName = null, string[] labelAnnotations = null) | ||
| : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageClassificationTransformer))) | ||
|
|
||
| { | ||
|
|
@@ -731,13 +751,55 @@ internal ImageClassificationTransformer(IHostEnvironment env, Session session, s | |
| (_evaluationStep, _) = AddEvaluationStep(_softMaxTensor, _labelTensor); | ||
| _softmaxTensorName = _softMaxTensor.name; | ||
| _predictionTensorName = _prediction.name; | ||
|
|
||
| // Add annotations as key values, if they exist. | ||
| VBuffer<ReadOnlyMemory<char>> keysVBuffer = default; | ||
| if (inputSchema[labelColumnName].HasKeyValues()) | ||
| { | ||
| inputSchema[labelColumnName].GetKeyValues(ref keysVBuffer); | ||
| _keyValueAnnotations = keysVBuffer.DenseValues().ToArray(); | ||
| } | ||
| else | ||
| { | ||
| _keyValueAnnotations = Enumerable.Range(0, _classCount).Select(x => x.ToString().AsMemory()).ToArray(); | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Enumerable.Range(0, _classCount).Select(x => x.ToString().AsMemory()).ToArray(); #Resolved |
||
| } | ||
| else | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I am assuming that this is meant for the case where you are loading the model from file. In that case, you need to use the key values that are deserialized from the model. In general, I think the way the constructors for this class work is kind of confusing - I think there should be one constructor that is given all the information it needs, either from the deserialization
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please open a separate issue for making two constructors, this suggestion is out of the scope of this change. In reply to: 328535147 [](ancestors = 328535147) |
||
| { | ||
| // Load annotations as key values, if they exist | ||
| if (labelAnnotations != null) | ||
| _keyValueAnnotations = labelAnnotations.Select(v => v.AsMemory()).ToArray(); | ||
| } | ||
| } | ||
|
|
||
| private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema); | ||
|
|
||
| private protected override void SaveModel(ModelSaveContext ctx) | ||
| { | ||
| // *** Binary format *** | ||
| // byte: indicator for frozen models | ||
| // byte: indicator for adding batch dimension in input | ||
| // int: number of input columns | ||
| // for each input column | ||
| // int: id of int column name | ||
| // int: number of output columns | ||
| // for each output column | ||
| // int: id of output column name | ||
| // string: value of label column name | ||
| // string: prefix of final model and checkpoint files/folder for storing graph files | ||
| // int: value of the utilized model architecture for transfer learning | ||
| // string: value of score column name | ||
| // string: value of predicted label column name | ||
| // float: value of learning rate | ||
| // int: number of prediction classes | ||
| // for each key value annotation column | ||
| // string: value of key value annotations | ||
| // string: name of prediction tensor | ||
| // string: name of softmax tensor | ||
| // string: name of JPEG data tensor | ||
| // string: name of resized image tensor | ||
| // stream (byte): tensorFlow model. | ||
|
|
||
| Host.AssertValue(ctx); | ||
| ctx.CheckAtModel(); | ||
| ctx.SetVersionInfo(GetVersionInfo()); | ||
|
|
@@ -760,6 +822,12 @@ private protected override void SaveModel(ModelSaveContext ctx) | |
| ctx.Writer.Write(_predictedLabelColumnName); | ||
| ctx.Writer.Write(_learningRate); | ||
| ctx.Writer.Write(_classCount); | ||
|
|
||
| Host.AssertNonEmpty(_keyValueAnnotations); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
You cannot assert this, since it is possible that `_keyValueAnnotations` is null.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adjusted code with latest push so that _keyValueAnnotations will not be null, thanks to Zeeshan for the suggestion #Resolved |
||
| Host.Assert(_keyValueAnnotations.Length == _classCount); | ||
| for (int j = 0; j < _classCount; j++) | ||
mstfbl marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just a minor FYI - In C#, when you index into an array, the index is bounds-checked at runtime. The runtime checks if the index is non-negative, and less than the length of the array. If it is outside those bounds, an exception is thrown. While these bounds checks aren't very expensive, they can add up. However, the JIT compiler can skip the bounds check when the loop has this shape:
int[] myArray = ...;
for (int i = 0; i < myArray.Length; i++)
{
int element = myArray[i];
}
It sees this pattern, and knows that `i` can never be outside the bounds of `myArray`, so the bounds check on `myArray[i]` is not needed. The takeaway here is, when you have 2 variables to pick from to use in your `for` loop condition, prefer the `Length` of the array you are indexing into.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, that makes sense, thank you Eric! It's pretty cool that the C# compiler is able to recognize the traversing of an array and modify the bound checks as it goes. #Resolved |
||
| ctx.SaveNonEmptyString(_keyValueAnnotations[j]); | ||
mstfbl marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it really guaranteed that _keyValueAnnotations is not empty? Refer to the lines of code below that you have written in the constructor: VBuffer<ReadOnlyMemory<char>> keysVBuffer = default;
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With your suggested changes below, it is now guaranteed that _keyValueAnnotations will not be empty. #Resolved |
||
|
|
||
| ctx.Writer.Write(_predictionTensorName); | ||
| ctx.Writer.Write(_softmaxTensorName); | ||
| ctx.Writer.Write(_jpegDataTensorName); | ||
|
|
@@ -845,6 +913,7 @@ public void UpdateCacheIfNeeded() | |
| var outputTensor = _runner.AddInput(processedTensor, 0).Run(); | ||
| outputTensor[0].ToArray<float>(ref _classProbability); | ||
| outputTensor[1].ToScalar<long>(ref _predictedLabel); | ||
| _predictedLabel += 1; | ||
| outputTensor[0].Dispose(); | ||
| outputTensor[1].Dispose(); | ||
| processedTensor.Dispose(); | ||
|
|
@@ -890,9 +959,18 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a | |
|
|
||
| protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() | ||
| { | ||
| var annotationBuilder = new DataViewSchema.Annotations.Builder(); | ||
| annotationBuilder.AddKeyValues(_parent._classCount, TextDataViewType.Instance, (ref VBuffer<ReadOnlyMemory<char>> dst) => | ||
| { | ||
| var editor = VBufferEditor.Create(ref dst, _parent._classCount); | ||
| for (int i = 0; i < _parent._classCount; i++) | ||
| editor.Values[i] = _parent._keyValueAnnotations[i]; | ||
| dst = editor.Commit(); | ||
| }); | ||
|
|
||
| var info = new DataViewSchema.DetachedColumn[_parent._outputs.Length]; | ||
| info[0] = new DataViewSchema.DetachedColumn(_parent._outputs[0], new VectorDataViewType(NumberDataViewType.Single, _parent._classCount), null); | ||
| info[1] = new DataViewSchema.DetachedColumn(_parent._outputs[1], NumberDataViewType.UInt32, null); | ||
| info[0] = new DataViewSchema.DetachedColumn(_parent._scoreColumnName, new VectorDataViewType(NumberDataViewType.Single, _parent._classCount), null); | ||
| info[1] = new DataViewSchema.DetachedColumn(_parent._predictedLabelColumnName, new KeyDataViewType(typeof(uint), _parent._classCount), annotationBuilder.ToAnnotations()); | ||
| return info; | ||
| } | ||
| } | ||
|
|
@@ -1288,7 +1366,6 @@ internal sealed class Options : TransformInputBase | |
| private readonly Options _options; | ||
| private readonly DnnModel _dnnModel; | ||
| private readonly TF_DataType[] _tfInputTypes; | ||
| private readonly DataViewType[] _outputTypes; | ||
| private ImageClassificationTransformer _transformer; | ||
|
|
||
| internal ImageClassificationEstimator(IHostEnvironment env, Options options, DnnModel dnnModel) | ||
|
|
@@ -1297,7 +1374,6 @@ internal ImageClassificationEstimator(IHostEnvironment env, Options options, Dnn | |
| _options = options; | ||
| _dnnModel = dnnModel; | ||
| _tfInputTypes = new[] { TF_DataType.TF_STRING }; | ||
| _outputTypes = new[] { new VectorDataViewType(NumberDataViewType.Single), NumberDataViewType.UInt32.GetItemType() }; | ||
| } | ||
|
|
||
| private static Options CreateArguments(DnnModel tensorFlowModel, string[] outputColumnNames, string[] inputColumnName, bool addBatchDimensionInput) | ||
|
|
@@ -1327,12 +1403,16 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) | |
| if (col.ItemType != expectedType) | ||
| throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString()); | ||
| } | ||
| for (var i = 0; i < _options.OutputColumns.Length; i++) | ||
| { | ||
| resultDic[_options.OutputColumns[i]] = new SchemaShape.Column(_options.OutputColumns[i], | ||
| _outputTypes[i].IsKnownSizeVector() ? SchemaShape.Column.VectorKind.Vector | ||
| : SchemaShape.Column.VectorKind.VariableVector, _outputTypes[i].GetItemType(), false); | ||
| } | ||
|
|
||
| resultDic[_options.OutputColumns[0]] = new SchemaShape.Column(_options.OutputColumns[0], | ||
| SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false); | ||
|
|
||
| var metadata = new List<SchemaShape.Column>(); | ||
| metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false)); | ||
|
|
||
| resultDic[_options.OutputColumns[1]] = new SchemaShape.Column(_options.OutputColumns[1], | ||
| SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true, new SchemaShape(metadata.ToArray())); | ||
|
|
||
| return new SchemaShape(resultDic.Values); | ||
| } | ||
|
|
||
|
|
@@ -1342,8 +1422,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) | |
| public ImageClassificationTransformer Fit(IDataView input) | ||
| { | ||
| _host.CheckValue(input, nameof(input)); | ||
| if (_transformer == null) | ||
| _transformer = new ImageClassificationTransformer(_host, _options, _dnnModel, input); | ||
| _transformer = new ImageClassificationTransformer(_host, _options, _dnnModel, input); | ||
mstfbl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| // Validate input schema. | ||
| _transformer.GetOutputSchema(input.Schema); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it does not have key values, then just add the numeric values (0 to classCount - 1) as strings so that there is something in there. #Resolved