
Commit dc7125b

[Hexagon] Propagate QNN Concat Quantization Params to Inputs (#15258)
* [Hexagon] Propagate qnn.concat quantization params to inputs, eliminating redundant requantization when possible, and lower it to a plain concat
* Fix pylint issue
* Add relay IR snippet before and after transformation
* Better test file description comment
1 parent 3c23865 commit dc7125b
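
The rewrite leans on a standard property of affine quantization: requantizing a value twice (input params to some intermediate params, then to the concat's output params) gives, up to rounding, the same result as requantizing once straight to the output params, so each input can target the concat's scale/zero-point directly and qnn.concatenate collapses to a plain concatenate. Below is a minimal scalar sketch of that identity using the constants from the Relay IR in the new class docstring further down. It is illustrative only and not part of this commit; TVM's qnn.requantize actually uses fixed-point arithmetic and configurable rounding.

import numpy as np


def requantize(q, in_scale, in_zp, out_scale, out_zp, dtype=np.uint8):
    """Toy scalar model of requantization: re-express real = in_scale * (q - in_zp)
    under a new (out_scale, out_zp)."""
    real = in_scale * (q.astype(np.int64) - in_zp)
    q_new = np.round(real / out_scale) + out_zp
    info = np.iinfo(dtype)
    return np.clip(q_new, info.min, info.max).astype(dtype)


q2 = np.array([0, 37, 128, 255], dtype=np.uint8)  # sample values playing the role of %q2

# Two hops, as before the rewrite: (0.000109401, 0) -> (0.00345, 0) -> (0.0486874, 0)
two_hops = requantize(requantize(q2, 0.000109401, 0, 0.00345, 0), 0.00345, 0, 0.0486874, 0)
# One hop, as after the rewrite: (0.000109401, 0) -> (0.0486874, 0)
one_hop = requantize(q2, 0.000109401, 0, 0.0486874, 0)

assert np.array_equal(two_hops, one_hop)  # identical here; in general equal up to rounding
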

2 files changed (+203, -3 lines)

python/tvm/contrib/hexagon/transform.py

Lines changed: 102 additions & 3 deletions
@@ -21,8 +21,16 @@
 
 import tvm
 from tvm import relay
-from tvm.relay.dataflow_pattern import DFPatternCallback, rewrite, wildcard
-from tvm.relay.dataflow_pattern import is_constant, is_op, is_tuple
+from tvm.relay.dataflow_pattern import (
+    DFPatternCallback,
+    is_constant,
+    is_op,
+    is_tuple,
+    rewrite,
+    wildcard,
+)
+from tvm.relay.expr import Call
+
 from ..._ffi.registry import register_func
 
 ### VTCM
@@ -43,7 +51,6 @@ def mem_info_vtcm():
 
 
 def lower_vtcm_(get_alloc, get_free, def_align, func, mod, ctx):  # pylint: disable=unused-argument
-
     """Generic VTCM allocation
 
     Parameters
@@ -311,3 +318,95 @@ def remove_empty_pad(mod):
     """Remove the empty pad operator."""
     mod["main"] = rewrite(remove_empty_pad_callback(), mod["main"])
     return mod
+
+
+class simplify_qnn_concat_in_func(DFPatternCallback):
+    """
+    Propagate qnn.concat's quantization params to its inputs,
+    and try to avoid redundant requantization while doing so.
+
+    Replace
+    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
+        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
+      %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
+      %1 = qnn.requantize(%q2, 0.000109401f, 0, 0.00345f, 0, axis=1, out_dtype="uint8");
+      %2 = (%0, %1, %q3);
+      %3 = (0.0425042f, 0.00345f, 0.0486874f);
+      %4 = (0, 0, 0);
+      qnn.concatenate(%2, %3, %4, 0.0486874f, 0, axis=1)
+    }
+
+    with
+
+    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
+        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
+      %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
+      %1 = qnn.requantize(%0, 0.0425042f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
+      %2 = qnn.requantize(%q2, 0.000109401f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
+      %3 = (%1, %2, %q3);
+      concatenate(%3, axis=1)
+    }
+    """
+
+    def __init__(self):
+        super(simplify_qnn_concat_in_func, self).__init__()
+        self.qvals = wildcard()
+        self.scales = wildcard()
+        self.zps = wildcard()
+        self.out_scale = wildcard()
+        self.out_zp = wildcard()
+        self.pattern = is_op("qnn.concatenate")(
+            self.qvals, self.scales, self.zps, self.out_scale, self.out_zp
+        )
+
+    def callback(self, pre, post, node_map):
+        in_qvals = node_map[self.qvals][0]
+        in_scales = node_map[self.scales][0]
+        in_zps = node_map[self.zps][0]
+        new_qvals = []
+        for i in range(len(in_qvals)):
+            new_requant_args = []
+            # TODO Generalize for all qnn ops
+            if isinstance(in_qvals[i], Call) and (in_qvals[i].op.name == "qnn.requantize"):
+                # propagate scale/zp of qnn.concat to this requantize op
+                for j in range(3):
+                    new_requant_args.append(in_qvals[i].args[j])
+                new_requant_args += [node_map[self.out_scale][0], node_map[self.out_zp][0]]
+                new_qvals.append(relay.qnn.op.requantize(*new_requant_args, **(in_qvals[i].attrs)))
+            else:
+                # simply create a new requantize op if there is a change in quantization params
+                # if not, just retain the old qval
+                if (in_scales[i] == node_map[self.out_scale][0]) and (
+                    in_zps[i] == node_map[self.out_zp][0]
+                ):
+                    new_qvals.append(in_qvals[i])
+                else:
+                    new_requant_args += [
+                        in_qvals[i],
+                        in_scales[i],
+                        in_zps[i],
+                        node_map[self.out_scale][0],
+                        node_map[self.out_zp][0],
+                    ]
+                    new_qvals.append(
+                        relay.qnn.op.requantize(
+                            *new_requant_args,
+                            axis=post.attrs["axis"],
+                            out_dtype=post.checked_type.dtype,
+                        )
+                    )
+
+        new_op = relay.op.concatenate(
+            new_qvals,
+            node_map[self.pattern][0].attrs["axis"],
+        )
+        return new_op
+
+
+# Right now context is ignored
+@tvm.transform.module_pass(opt_level=1)
+def simplify_qnn_concat(mod, _=None):
+    for global_var in mod.functions.keys():
+        mod[global_var] = rewrite(simplify_qnn_concat_in_func(), mod[global_var])
+    return mod
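
A usage sketch of the new pass follows; it is not part of the commit, the input shapes and quantization params are made up, and it mirrors what the new test file below does: build a module whose qnn.concatenate inputs are requantize ops, run type inference, then apply simplify_qnn_concat.

import tvm
from tvm import relay
from tvm.contrib.hexagon.transform import simplify_qnn_concat

# Two uint8 inputs, each already followed by a requantize, feeding a
# qnn.concatenate whose output uses its own scale/zero-point (0.05 / 0).
a = relay.var("a", shape=(1, 4), dtype="uint8")
b = relay.var("b", shape=(1, 4), dtype="uint8")
zero = relay.const(0, "int32")
ra = relay.qnn.op.requantize(
    a, relay.const(0.010, "float32"), zero, relay.const(0.020, "float32"), zero, out_dtype="uint8"
)
rb = relay.qnn.op.requantize(
    b, relay.const(0.015, "float32"), zero, relay.const(0.030, "float32"), zero, out_dtype="uint8"
)
out = relay.qnn.op.concatenate(
    relay.expr.Tuple([ra, rb]),
    relay.expr.Tuple([relay.const(0.020, "float32"), relay.const(0.030, "float32")]),
    relay.expr.Tuple([zero, zero]),
    relay.const(0.050, "float32"),
    zero,
    axis=1,
)
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))

mod = tvm.relay.transform.InferType()(mod)  # run type inference first, as the test below does
mod = simplify_qnn_concat(mod)
# Both requantize ops now target scale 0.05 / zero-point 0, and the
# qnn.concatenate has been replaced by a plain concatenate.
print(mod["main"])
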
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-wildcard-import, invalid-name
+
+"""
+Test hexagon relay transform - qnn.concat optimization
+"""
+import tvm
+from tvm import relay, testing
+from tvm.contrib.hexagon.transform import simplify_qnn_concat
+
+
+def get_test_module():
+    """Creates a test relay module and returns it."""
+    q1 = relay.var("q1", shape=(1, 64, 35, 35), dtype="uint8")
+    q2 = relay.var("q2", shape=(1, 64, 35, 35), dtype="uint8")
+    q3 = relay.var("q3", shape=(1, 32, 35, 35), dtype="uint8")
+    s2 = relay.const(0.000109401, dtype="float32")
+    s3 = relay.const(0.0486874, dtype="float32")
+    s4 = relay.const(0.0425042, dtype="float32")
+    s5 = relay.const(0.00345, dtype="float32")
+    z1 = relay.const(0, dtype="int32")
+    r1 = relay.op.nn.max_pool2d(
+        q1,
+        pool_size=[3, 3],
+        strides=[1, 1],
+        padding=[1, 1],
+        dilation=[1, 1],
+        ceil_mode=False,
+        layout="NHWC",
+    )
+    r2 = relay.qnn.op.requantize(q2, s2, z1, s5, z1, axis=1, out_dtype="uint8")
+    q_tuple = relay.expr.Tuple([r1, r2, q3])
+    s_tuple = relay.expr.Tuple([s4, s5, s3])
+    z_tuple = relay.expr.Tuple([z1, z1, z1])
+    graph = relay.qnn.op.concatenate(q_tuple, s_tuple, z_tuple, s3, z1, axis=1)
+
+    func = relay.Function(relay.analysis.free_vars(graph), graph)
+    mod = tvm.IRModule.from_expr(func)
+    return mod
+
+
+def get_expected_output_module():
+    """Returns manually created expected output module."""
+    out_q1 = relay.var("q1", shape=(1, 64, 35, 35), dtype="uint8")
+    out_q2 = relay.var("q2", shape=(1, 64, 35, 35), dtype="uint8")
+    out_q3 = relay.var("q3", shape=(1, 32, 35, 35), dtype="uint8")
+    out_s2 = relay.const(0.000109401, dtype="float32")
+    out_s3 = relay.const(0.0486874, dtype="float32")
+    out_s4 = relay.const(0.0425042, dtype="float32")
+    out_z1 = relay.const(0, dtype="int32")
+    nn_max_pool = relay.op.nn.max_pool2d(
+        out_q1,
+        pool_size=[3, 3],
+        strides=[1, 1],
+        padding=[1, 1],
+        dilation=[1, 1],
+        ceil_mode=False,
+        layout="NHWC",
+    )
+    out_r1 = relay.qnn.op.requantize(
+        nn_max_pool, out_s4, out_z1, out_s3, out_z1, axis=1, out_dtype="uint8"
+    )
+    out_r2 = relay.qnn.op.requantize(
+        out_q2, out_s2, out_z1, out_s3, out_z1, axis=1, out_dtype="uint8"
+    )
+    out_q_tuple = relay.expr.Tuple([out_r1, out_r2, out_q3])
+    out_graph = relay.op.concatenate(out_q_tuple, axis=1)
+
+    out_func = relay.Function(relay.analysis.free_vars(out_graph), out_graph)
+    out_mod = tvm.IRModule.from_expr(out_func)
+    return out_mod
+
+
+def test_simplify_qnn_concat():
+    mod = get_test_module()
+    mod = tvm.relay.transform.InferType()(mod)
+    mod = simplify_qnn_concat(mod)
+
+    out_mod = get_expected_output_module()
+    out_mod = tvm.relay.transform.InferType()(out_mod)
+
+    assert tvm.ir.structural_equal(mod["main"], out_mod["main"])
+
+
+if __name__ == "__main__":
+    testing.main()
