@@ -11,21 +11,29 @@ def get_qkeras_onnx_handlers(all_quantizers):
1111 "MatMul" : (dense_handler , ["MatMul" , all_quantizers ]),
1212 "BiasAdd" : (bias_handler , ["BiasAdd" , all_quantizers ]),
1313 "Relu" : (relu_handler , ["Relu" , all_quantizers ]),
14+ "Identity" : (identity_handler , ["Identity" , all_quantizers ]),
1415 }
1516
1617
17- def _extract_node_name (onnx_name , keras_names ):
18+ def _extract_node_name (onnx_node , keras_quantizers ):
19+ onnx_name = onnx_node .name
20+ keras_names = keras_quantizers .keys ()
1821 for keras_name in keras_names :
1922 match = "/" + keras_name + "/"
2023 if match in onnx_name :
2124 return keras_name
25+ elif "Identity" in onnx_name :
26+ onnx_input = onnx_node .input [0 ]
27+ keras_input = keras_quantizers [keras_name ]["input" ]
28+ if keras_input in onnx_input :
29+ return keras_name
2230
2331 return None
2432
2533
2634def qlayer_handler (ctx , node , name , args ):
2735 all_quantizers = args [0 ]
28- keras_name = _extract_node_name (name , all_quantizers . keys () )
36+ keras_name = _extract_node_name (node , all_quantizers )
2937 if not keras_name :
3038 return # Not found in quantizers, nothing to do
3139 quantizers = all_quantizers [keras_name ]
@@ -79,7 +87,7 @@ def qlayer_handler(ctx, node, name, args):
7987
8088def qact_handler (ctx , node , name , args ):
8189 all_quantizers = args [0 ]
82- keras_name = _extract_node_name (name , all_quantizers . keys () )
90+ keras_name = _extract_node_name (node , all_quantizers )
8391 if not keras_name :
8492 return # Not found in quantizers, nothing to do
8593 quantizers = all_quantizers [keras_name ]
@@ -119,7 +127,7 @@ def bias_handler(ctx, node, name, args):
119127 BiasAdd .version_1 (ctx , node )
120128
121129 all_quantizers = args [0 ]
122- keras_name = _extract_node_name (name , all_quantizers . keys () )
130+ keras_name = _extract_node_name (node , all_quantizers )
123131 if not keras_name :
124132 return # Not found in quantizers, nothing to do
125133 quantizers = all_quantizers [keras_name ]
@@ -140,3 +148,8 @@ def bias_handler(ctx, node, name, args):
140148def relu_handler (ctx , node , name , args ):
141149 DirectOp .version_1 (ctx , node )
142150 qact_handler (ctx , node , name , args )
151+
152+
153+ def identity_handler (ctx , node , name , args ):
154+ DirectOp .version_1 (ctx , node )
155+ qact_handler (ctx , node , name , args )
0 commit comments