Skip to content

Commit

Permalink
Add support for complex64 data type in parseDtypeParam function (#8083
Browse files Browse the repository at this point in the history
)


* Add support for complex64 data type in `parseDtypeParam` function

* Add parse complex64 test to operation_mapper_test.ts
  • Loading branch information
Lutra-Fs authored Dec 4, 2023
1 parent b8a0023 commit d94193d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
3 changes: 3 additions & 0 deletions tfjs-converter/src/operations/operation_mapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,9 @@ export function parseDtypeParam(value: string|tensorflow.DataType): DataType {
return 'float32';
case tensorflow.DataType.DT_STRING:
return 'string';
case tensorflow.DataType.DT_COMPLEX64:
case tensorflow.DataType.DT_COMPLEX128:
return 'complex64';
default:
// Unknown dtype error will happen at runtime (instead of parse time),
// since these nodes might not be used by the actual subgraph execution.
Expand Down
20 changes: 17 additions & 3 deletions tfjs-converter/src/operations/operation_mapper_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ const SIMPLE_MODEL: tensorflow.IGraphDef = {
input: ['BiasAdd'],
attr: {DstT: {type: tensorflow.DataType.DT_HALF}}
},
{
name: 'Cast4',
op: 'Cast',
input: ['BiasAdd'],
attr: {DstT: {type: tensorflow.DataType.DT_COMPLEX64}}
}
],
library: {
function: [
Expand Down Expand Up @@ -310,7 +316,7 @@ describe('operationMapper without signature', () => {
it('should find the graph output nodes', () => {
expect(convertedGraph.outputs.map(node => node.name)).toEqual([
'Fill', 'Squeeze', 'Squeeze2', 'Split', 'LogicalNot',
'FusedBatchNorm', 'Cast2', 'Cast3'
'FusedBatchNorm', 'Cast2', 'Cast3', 'Cast4'
]);
});

Expand All @@ -324,7 +330,7 @@ describe('operationMapper without signature', () => {
expect(Object.keys(convertedGraph.nodes)).toEqual([
'image_placeholder', 'Const', 'Shape', 'Value', 'Fill', 'Conv2D',
'BiasAdd', 'Cast', 'Squeeze', 'Squeeze2', 'Split', 'LogicalNot',
'FusedBatchNorm', 'Cast2', 'Cast3'
'FusedBatchNorm', 'Cast2', 'Cast3', 'Cast4'
]);
});
});
Expand Down Expand Up @@ -447,6 +453,10 @@ describe('operationMapper without signature', () => {
expect(convertedGraph.nodes['Cast'].attrParams['dtype'].value)
.toEqual('int32');
});
it('should map params with complex64 dtype', () => {
expect(convertedGraph.nodes['Cast4'].attrParams['dtype'].value)
.toEqual('complex64');
});
});
});
});
Expand Down Expand Up @@ -486,7 +496,7 @@ describe('operationMapper with signature', () => {
expect(Object.keys(convertedGraph.nodes)).toEqual([
'image_placeholder', 'Const', 'Shape', 'Value', 'Fill', 'Conv2D',
'BiasAdd', 'Cast', 'Squeeze', 'Squeeze2', 'Split', 'LogicalNot',
'FusedBatchNorm', 'Cast2', 'Cast3'
'FusedBatchNorm', 'Cast2', 'Cast3', 'Cast4'
]);
});
});
Expand Down Expand Up @@ -552,6 +562,10 @@ describe('operationMapper with signature', () => {
expect(convertedGraph.nodes['Cast3'].attrParams['dtype'].value)
.toEqual('float32');
});
it('should map params with complex64 dtype', () => {
expect(convertedGraph.nodes['Cast4'].attrParams['dtype'].value)
.toEqual('complex64');
});
});
});
});

0 comments on commit d94193d

Please sign in to comment.