
Commit 6c57791

fix quant tests
1 parent d43e754 commit 6c57791

4 files changed: 25 additions & 6 deletions

src/qonnx/converters/keras.py

Lines changed: 6 additions & 1 deletion
@@ -2,6 +2,7 @@
 import tensorflow as tf
 import tf2onnx
 from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS
+from collections import OrderedDict

 from finn.core.modelwrapper import ModelWrapper
 from qonnx.util.cleanup import cleanup_model
@@ -16,6 +17,9 @@
     "QDepthwiseConv2DBatchnorm",
 ]

+# Skip remove_identity optimizer
+del tf2onnx.optimizer._optimizers['remove_identity']
+

 def add_value_info_for_constants(model: onnx.ModelProto):
     """
@@ -101,14 +105,15 @@ def iterate_model(model):


 def _strip_qkeras_model(model):
-    quantizers = {}
+    quantizers = OrderedDict()

     def extract_quantizers(layer):
         keras_cls_name, layer_cfg, layer_quantizers = extract_quantizers_from_layer(layer)
         if layer_quantizers:
             layer_quantizers = {
                 k: None if v == "None" else v for k, v in layer_quantizers.items()
             }  # Get rid of 'None' strings
+            layer_quantizers["input"] = layer.input.name
             quantizers[layer.name] = layer_quantizers

         layer_class = tf.keras.layers.__dict__.get(keras_cls_name, None)
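
Two things happen in this file: tf2onnx's remove_identity optimizer is disabled, presumably so that ONNX Identity nodes survive long enough for the new Identity handler in qkeras/onnx.py to attach quantization to them, and each layer's input tensor name is recorded next to its quantizers so those Identity nodes can be traced back to their Keras layer. A rough Python sketch of the resulting quantizer map (hypothetical layer, tensor and quantizer names; only the new "input" entry is taken from the patch):

from collections import OrderedDict

# Hypothetical contents of the map built by _strip_qkeras_model after this change.
quantizers = OrderedDict()
quantizers["qdense_0"] = {
    "kernel_quantizer": "quantized_bits(4, 2, 0, alpha=1)",
    "bias_quantizer": None,
    "input": "qdense_0_input:0",  # layer.input.name, used later to match Identity nodes
}
print(list(quantizers.keys()))  # OrderedDict keeps the layer order stable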

src/qonnx/converters/qkeras/onnx.py

Lines changed: 17 additions & 4 deletions
@@ -11,21 +11,29 @@ def get_qkeras_onnx_handlers(all_quantizers):
         "MatMul": (dense_handler, ["MatMul", all_quantizers]),
         "BiasAdd": (bias_handler, ["BiasAdd", all_quantizers]),
         "Relu": (relu_handler, ["Relu", all_quantizers]),
+        "Identity": (identity_handler, ["Identity", all_quantizers]),
     }


-def _extract_node_name(onnx_name, keras_names):
+def _extract_node_name(onnx_node, keras_quantizers):
+    onnx_name = onnx_node.name
+    keras_names = keras_quantizers.keys()
     for keras_name in keras_names:
         match = "/" + keras_name + "/"
         if match in onnx_name:
             return keras_name
+        elif "Identity" in onnx_name:
+            onnx_input = onnx_node.input[0]
+            keras_input = keras_quantizers[keras_name]["input"]
+            if keras_input in onnx_input:
+                return keras_name

     return None


 def qlayer_handler(ctx, node, name, args):
     all_quantizers = args[0]
-    keras_name = _extract_node_name(name, all_quantizers.keys())
+    keras_name = _extract_node_name(node, all_quantizers)
     if not keras_name:
         return  # Not found in quantizers, nothing to do
     quantizers = all_quantizers[keras_name]
@@ -79,7 +87,7 @@ def qlayer_handler(ctx, node, name, args):

 def qact_handler(ctx, node, name, args):
     all_quantizers = args[0]
-    keras_name = _extract_node_name(name, all_quantizers.keys())
+    keras_name = _extract_node_name(node, all_quantizers)
     if not keras_name:
         return  # Not found in quantizers, nothing to do
     quantizers = all_quantizers[keras_name]
@@ -119,7 +127,7 @@ def bias_handler(ctx, node, name, args):
     BiasAdd.version_1(ctx, node)

     all_quantizers = args[0]
-    keras_name = _extract_node_name(name, all_quantizers.keys())
+    keras_name = _extract_node_name(node, all_quantizers)
     if not keras_name:
         return  # Not found in quantizers, nothing to do
     quantizers = all_quantizers[keras_name]
@@ -140,3 +148,8 @@ def bias_handler(ctx, node, name, args):
 def relu_handler(ctx, node, name, args):
     DirectOp.version_1(ctx, node)
     qact_handler(ctx, node, name, args)
+
+
+def identity_handler(ctx, node, name, args):
+    DirectOp.version_1(ctx, node)
+    qact_handler(ctx, node, name, args)
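
The handler table and _extract_node_name now cover Identity nodes, whose ONNX names do not contain the Keras layer name; they are matched through the layer's recorded input tensor instead. A self-contained sketch of that lookup with made-up node and quantizer names (FakeNode only stands in for a tf2onnx graph node, it is not part of the patch):

class FakeNode:
    # Stand-in for a tf2onnx node: just a name and a list of input tensor names.
    def __init__(self, name, inputs):
        self.name = name
        self.input = inputs

def extract_node_name(onnx_node, keras_quantizers):
    # Same lookup as the patched _extract_node_name above.
    onnx_name = onnx_node.name
    for keras_name in keras_quantizers.keys():
        if "/" + keras_name + "/" in onnx_name:
            return keras_name  # regular nodes carry the Keras layer name
        elif "Identity" in onnx_name:
            # Identity nodes do not, so fall back to the recorded input tensor.
            if keras_quantizers[keras_name]["input"] in onnx_node.input[0]:
                return keras_name
    return None

quantizers = {"qact_0": {"input": "qdense_0/BiasAdd:0"}}  # hypothetical names
node = FakeNode("Identity_2", ["qdense_0/BiasAdd:0"])
print(extract_node_name(node, quantizers))  # qact_0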

src/qonnx/converters/qkeras/quantizers.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ def convert_ternary(tensor, quantizer):
     ternary = qkeras.ternary()
     t = ternary.default_threshold
     assert t == 0.5, "ternary - only threshold 0.5 is supported"
+    # note that if assertions fail, Quant node is not inserted, but model is still converted; this seems to be unexpected behavior
     scale = 1.0
     zero_point = 0
     bit_width = 2
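
The added comment documents a quirk rather than changing behavior: if the assertion fires, no Quant node is inserted but the conversion still completes, so the failure is easy to miss. For context, a small sketch of what the surrounding code emits for a ternary quantizer, using the scale, zero point and bit width visible above (Quant dequantization assumed to be scale * (q - zero_point)):

# Quant-node parameters for a ternary quantizer, taken from the code above.
scale, zero_point, bit_width = 1.0, 0, 2
levels = [-1, 0, 1]  # values a ternary quantizer produces
print(f"{bit_width}-bit levels:", [scale * (q - zero_point) for q in levels])  # [-1.0, 0.0, 1.0]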

tests/test_keras_convert.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
     quantized_bits(4, 2, 0, alpha=1),
     quantized_bits(2, 2, 1, alpha=1),
     quantized_bits(2, 1, 1, alpha=1),
-    ternary(alpha=1),
+    ternary(alpha=1, threshold=0.5),
     binary(alpha=1),
 ]
 act_quantizers_ids = list(range(len(act_quantizers)))
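
The test change is the counterpart to the assertion in quantizers.py: the ternary activation quantizer now pins the threshold explicitly instead of relying on the default. A minimal sketch of the pinned quantizer (the surrounding pytest parametrization is assumed):

from qkeras import ternary

# Pinning the threshold keeps the test quantizer within what convert_ternary
# supports ("ternary - only threshold 0.5 is supported").
act_quantizer = ternary(alpha=1, threshold=0.5)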
