@@ -588,10 +588,16 @@ def do_transformation(self):
             min_filter_node = None
             # The Min and Max of non-const weight node are from QuantizeV2's output, not valid nodes.
             # Add check here for excluding this case.
-            if ":2" not in new_node.input[6]:
-                max_filter_node = self.graph_info[new_node.input[6]].node
-            if ":1" not in new_node.input[5]:
-                min_filter_node = self.graph_info[new_node.input[5]].node
+            if len(attr_fused_ops) == 0:  # single matmul case
+                if ":2" not in new_node.input[5]:
+                    max_filter_node = self.graph_info[new_node.input[5]].node
+                if ":1" not in new_node.input[4]:
+                    min_filter_node = self.graph_info[new_node.input[4]].node
+            else:
+                if ":2" not in new_node.input[6]:
+                    max_filter_node = self.graph_info[new_node.input[6]].node
+                if ":1" not in new_node.input[5]:
+                    min_filter_node = self.graph_info[new_node.input[5]].node
             last_node = self.graph_info[new_node.input[0]].node
             is_min_first = bool(quantized_node.attr['input_quant_mode'].s == b'MIN_FIRST')
             weight_node = self.graph_info[new_node.input[1]].node
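The hunk above fixes which inputs hold the weight tensor's min/max for a quantized MatMul: with no fused ops there is no bias input, so the ranges sit one slot earlier. A minimal pure-Python sketch of that index arithmetic, assuming the input layouts the patch implies (indices 4/5 for a standalone matmul, 5/6 once a fused bias input is present); the helper names are illustrative, not from the codebase:

    # Assumed layouts (inferred from the patch, not checked against the op registry):
    #   standalone: a, b, min_a, max_a, min_b, max_b         -> weight range at 4/5
    #   fused:      a, b, bias, min_a, max_a, min_b, max_b   -> weight range at 5/6
    def weight_range_indices(attr_fused_ops):
        """Return (min_idx, max_idx) of the weight's min/max inputs."""
        return (4, 5) if len(attr_fused_ops) == 0 else (5, 6)

    def is_const_range(input_name, port):
        # A min/max that is really QuantizeV2's output appears as "node:1"/"node:2";
        # a genuine Const range node has no output-port suffix.
        return f":{port}" not in input_name

    assert weight_range_indices([]) == (4, 5)
    assert is_const_range("w_min", 1) and not is_const_range("w_quant:1", 1)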
@@ -81,6 +81,8 @@ def do_transformation(self):
         self.g.graph = copy.deepcopy(self.model)
         self.graph_info = self.g.parse_graph()
 
+        self.g.get_frame_info()
+
         # insert QDQ pattern for op's input
         for op_name in quantizable_op_names:
             if self._ignore_insert_qdq_pattern(op_name):
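get_frame_info() is invoked once, right after parsing, so the weight rewrites below can tell which while-loop frame, if any, a node belongs to. The diff does not show its implementation; the following is a hypothetical sketch of such a pass, mapping each node to its nearest Enter ancestor, with nodes given as a plain dict rather than the project's graph structures:

    # Hypothetical frame-info pass (the real get_frame_info()/parent_frame_details
    # may differ): map every node to the Enter node feeding its while-loop frame.
    def collect_parent_frames(nodes):
        """nodes: dict name -> (op, [input names]); returns name -> Enter name or None."""
        parent = {}
        def resolve(name):
            if name in parent:
                return parent[name]
            parent[name] = None                            # break cycles in while loops
            op, inputs = nodes[name]
            if op == 'Enter':
                parent[name] = name                        # an Enter opens its own frame
                return name
            found = None
            for inp in inputs:
                inp = inp.lstrip('^').rsplit(':', 1)[0]    # drop control/port markers
                if inp in nodes:
                    found = resolve(inp) or found
            parent[name] = found
            return found
        for n in nodes:
            resolve(n)
        return parent

    demo = {'w': ('Const', []), 'w_enter': ('Enter', ['w']),
            'x': ('Placeholder', []), 'mul': ('Mul', ['w_enter', 'x'])}
    assert collect_parent_frames(demo)['mul'] == 'w_enter'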
@@ -115,20 +117,16 @@ def do_transformation(self):
 
             computational_node = self.graph_info[computational_node_name].node
             weight_name = computational_node.input[1]
-            weight_node = self.graph_info[weight_name].node
-            enter_node = None
+            if re.search(r"\w+:\d+", weight_name):
+                weight_node = self.graph_info[weight_name.rsplit(':', 1)[0]].node
+            else:
+                weight_node = self.graph_info[weight_name].node
             if weight_node.op == 'Enter':
                 if self.itex_mode:
                     parent_node = self.graph_info[Helper.node_name_from_input(weight_node.input[0])].node
                     if not parent_node.op == 'Const':
                         continue
-                    else:
-                        enter_node = weight_node
-                        weight_node = parent_node
+                    weight_node = parent_node
                 else:
                     continue
 
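The new lookup tolerates weight inputs that reference a specific output port. A GraphDef input may read "node_name:N", while graph_info is keyed by bare node names, so the suffix must be stripped first. The same normalization as a small standalone helper:

    import re

    # A tensor reference "producer:N" names output port N of node "producer";
    # graph_info lookups need the bare producer name.
    def producer_node_name(tensor_name):
        if re.search(r"\w+:\d+", tensor_name):
            return tensor_name.rsplit(':', 1)[0]
        return tensor_name

    assert producer_node_name("read_variable:0") == "read_variable"
    assert producer_node_name("weights") == "weights"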
@@ -139,10 +137,10 @@ def do_transformation(self):
             else:
                 per_channel = False
                 weight_bit = 7
 
             self._insert_qdq_pattern_for_weight_node(computational_node,
                                                      weight_node,
-                                                     enter_node,
+                                                     weight_name,
                                                      min_max_values,
                                                      per_channel,
                                                      weight_bit,
@@ -414,7 +412,7 @@ def _insert_qdq_pattern_for_each_input(self, op_name, namespace_prefix,
     def _insert_qdq_pattern_for_weight_node(self,
                                             computational_node,
                                             weight_node,
-                                            enter_node,
+                                            weight_name,
                                             min_max_values,
                                             per_channel,
                                             weight_bit=7.0,
@@ -504,41 +502,27 @@ def _insert_qdq_pattern_for_weight_node(self,
         max_node = Helper.create_constant_node(max_name, max_value,
                                                dtypes.float32, device="cpu")
         if "BatchMatMul" in host_op_type and "BatchMatMul" not in weight_node.op:
-            min_node.input.append("^" + weight_node.name)
-            max_node.input.append("^" + weight_node.name)
+            min_node.input.append("^" + weight_name)
+            max_node.input.append("^" + weight_name)
 
-        quant_const_enter_node = None
         min_enter_node = None
         max_enter_node = None
-        if enter_node:
-            quant_const_enter_node = Helper.create_node('Enter', \
-                qint8_const_name + '_enter', [weight_node.name])
-            Helper.set_attr_string(quant_const_enter_node,
-                'frame_name', enter_node.attr['frame_name'].s)
-            Helper.set_attr_dtype(quant_const_enter_node, 'T', dtypes.float32)
-            Helper.set_attr_bool(quant_const_enter_node, 'is_constant', True)
-            Helper.set_attr_int(quant_const_enter_node, \
-                'parallel_iterations', enter_node.attr['parallel_iterations'].i)
+        if insert_reshape:
+            reshape_dims_4to3_name = qint8_const_name + "_reshape_dims_4to3_"
+            reshape_dims_4to3_node = Helper.create_constant_node(
+                reshape_dims_4to3_name, shape_convert, dtypes.int32)
+            reshape_4to3_name = qint8_const_name + "_reshape_4to3_"
+            reshape_4to3_node = Helper.create_node("Reshape", reshape_4to3_name,
+                                                   [weight_node.name, reshape_dims_4to3_name])
+            reshape_4to3_node.attr["T"].CopyFrom(
+                attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
             quant_node = Helper.create_node(
                 "QuantizeV2", qint8_const_name + '_quant',
-                [quant_const_enter_node.name, min_name, max_name])
+                [reshape_4to3_name, min_name, max_name])
         else:
-            if insert_reshape:
-                reshape_dims_4to3_name = qint8_const_name + "_reshape_dims_4to3_"
-                reshape_dims_4to3_node = Helper.create_constant_node(
-                    reshape_dims_4to3_name, shape_convert, dtypes.int32)
-                reshape_4to3_name = qint8_const_name + "_reshape_4to3_"
-                reshape_4to3_node = Helper.create_node("Reshape", reshape_4to3_name,
-                                                       [weight_node.name, reshape_dims_4to3_name])
-                reshape_4to3_node.attr["T"].CopyFrom(
-                    attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum))
-                quant_node = Helper.create_node(
-                    "QuantizeV2", qint8_const_name + '_quant',
-                    [reshape_4to3_name, min_name, max_name])
-            else:
-                quant_node = Helper.create_node(
-                    "QuantizeV2", qint8_const_name + '_quant',
-                    [weight_node.name, min_name, max_name])
+            quant_node = Helper.create_node(
+                "QuantizeV2", qint8_const_name + '_quant',
+                [weight_node.name, min_name, max_name])
 
         dequant_node = Helper.create_node(
             "Dequantize", base_name + '_dequant',
@@ -549,10 +533,10 @@ def _insert_qdq_pattern_for_weight_node(self,
         Helper.set_attr_dtype(dequant_node, "T", dtypes.qint8)
         Helper.set_attr_string(dequant_node, "mode", b"SCALED")
         if per_channel:
-            if host_op_type == 'Conv2D' or host_op_type == 'Conv2DBackpropInput':
+            if host_op_type in ('Conv2D', 'Conv2DBackpropInput'):
                 Helper.set_attr_int(quant_node, 'axis', 3)
                 Helper.set_attr_int(dequant_node, 'axis', 3)
-            elif host_op_type == 'Conv3D' or host_op_type == 'Conv3DBackpropInputV2':
+            elif host_op_type in ('Conv3D', 'Conv3DBackpropInputV2'):
                 Helper.set_attr_int(quant_node, 'axis', 4)
                 Helper.set_attr_int(dequant_node, 'axis', 4)
             elif host_op_type == 'MatMul':
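The tuple-membership rewrite is behavior-preserving. The axis values themselves follow each op's filter layout: Conv2D weights are HWIO, so the output channel is axis 3; Conv3D weights are DHWIO, so it is axis 4. A table-driven equivalent, leaving out the MatMul branch that this hunk cuts off:

    # Output-channel axis per host op, per the filter layouts above.
    PER_CHANNEL_AXIS = {
        'Conv2D': 3,
        'Conv2DBackpropInput': 3,
        'Conv3D': 4,
        'Conv3DBackpropInputV2': 4,
    }

    def per_channel_axis(host_op_type):
        return PER_CHANNEL_AXIS.get(host_op_type)   # None for unlisted ops

    assert per_channel_axis('Conv3D') == 4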
@@ -584,33 +568,32 @@ def _insert_qdq_pattern_for_weight_node(self,
             self.g_weight.add_node(reshape_3to4_node, dequant_node.name, [computational_node.name])
             computational_node.input[1] = reshape_3to4_node.name
         else:
-            if enter_node:
+            if weight_node.name in self.g.parent_frame_details and self.g.parent_frame_details[weight_node.name]:
                 min_enter_node = Helper.create_node('Enter', min_name + '_enter', [min_name])
-                Helper.set_attr_string(min_enter_node,
-                    'frame_name', enter_node.attr['frame_name'].s)
+                Helper.set_attr_string(min_enter_node, 'frame_name',
+                    self.g.parent_frame_details[weight_node.name].attr['frame_name'].s)
                 Helper.set_attr_dtype(min_enter_node, 'T', dtypes.float32)
                 Helper.set_attr_bool(min_enter_node, 'is_constant', True)
                 Helper.set_attr_int(min_enter_node, 'parallel_iterations', \
-                    enter_node.attr['parallel_iterations'].i)
+                    self.g.parent_frame_details[weight_node.name].attr['parallel_iterations'].i)
 
                 max_enter_node = Helper.create_node('Enter', max_name + '_enter', [max_name])
-                Helper.set_attr_string(max_enter_node,
-                    'frame_name', enter_node.attr['frame_name'].s)
+                Helper.set_attr_string(max_enter_node, 'frame_name',
+                    self.g.parent_frame_details[weight_node.name].attr['frame_name'].s)
                 Helper.set_attr_dtype(max_enter_node, 'T', dtypes.float32)
                 Helper.set_attr_bool(max_enter_node, 'is_constant', True)
                 Helper.set_attr_int(max_enter_node, 'parallel_iterations',\
-                    enter_node.attr['parallel_iterations'].i)
+                    self.g.parent_frame_details[weight_node.name].attr['parallel_iterations'].i)
 
-                self.g_weight.add_node(quant_const_enter_node, weight_node.name, [quant_node.name])
-                self.g_weight.add_node(quant_node, quant_const_enter_node.name, [])
+                self.g_weight.add_node(quant_node, weight_name, [])
                 self.g_weight.add_node(min_node, None, [min_enter_node.name])
                 self.g_weight.add_node(max_node, None, [max_enter_node.name])
                 self.g_weight.add_node(min_enter_node, min_node.name, [quant_node.name])
                 self.g_weight.add_node(max_enter_node, max_node.name, [quant_node.name])
                 self.g_weight.add_node(dequant_node, quant_node.name, [computational_node.name])
                 computational_node.input[1] = dequant_node.name
             else:
-                self.g_weight.add_node(quant_node, weight_node.name, [])
+                self.g_weight.add_node(quant_node, weight_name, [])
                 self.g_weight.add_node(min_node, None, [quant_node.name])
                 self.g_weight.add_node(max_node, None, [quant_node.name])
                 self.g_weight.add_node(dequant_node, quant_node.name, [computational_node.name])
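With frame details collected up front, the rewriter no longer threads an enter_node argument through: when the weight has a parent frame, the fresh min/max consts are wrapped in Enter nodes so they are valid inside that while-loop frame, and the quantized weight now hangs off weight_name directly. A sketch of such a constant-carrying Enter node built with raw TensorFlow protobufs instead of the project's Helper wrappers; the frame name and iteration count are placeholder values:

    from tensorflow.core.framework import node_def_pb2, types_pb2

    def make_const_enter(name, input_name, frame_name, parallel_iterations=10):
        # An Enter op forwards a tensor into a while-loop frame; is_constant=True
        # lets every iteration reuse the same value.
        node = node_def_pb2.NodeDef(name=name, op="Enter", input=[input_name])
        node.attr["T"].type = types_pb2.DT_FLOAT
        node.attr["frame_name"].s = frame_name              # bytes, e.g. b"while/ctx"
        node.attr["is_constant"].b = True
        node.attr["parallel_iterations"].i = parallel_iterations
        return node

    min_enter = make_const_enter("w_min_enter", "w_min", b"while/ctx")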