diff --git a/python/tvm/contrib/hexagon/transform.py b/python/tvm/contrib/hexagon/transform.py
index 2e5e84342bef..664739dea580 100644
--- a/python/tvm/contrib/hexagon/transform.py
+++ b/python/tvm/contrib/hexagon/transform.py
@@ -21,8 +21,16 @@
 
 import tvm
 from tvm import relay
-from tvm.relay.dataflow_pattern import DFPatternCallback, rewrite, wildcard
-from tvm.relay.dataflow_pattern import is_constant, is_op, is_tuple
+from tvm.relay.dataflow_pattern import (
+    DFPatternCallback,
+    is_constant,
+    is_op,
+    is_tuple,
+    rewrite,
+    wildcard,
+)
+from tvm.relay.expr import Call
+
 from ..._ffi.registry import register_func
 
 ### VTCM
@@ -43,7 +51,6 @@ def mem_info_vtcm():
 
 def lower_vtcm_(get_alloc, get_free, def_align, func, mod, ctx):
     # pylint: disable=unused-argument
-
     """Generic VTCM allocation
 
     Parameters
@@ -311,3 +318,95 @@ def remove_empty_pad(mod):
     """Remove the empty pad operator."""
     mod["main"] = rewrite(remove_empty_pad_callback(), mod["main"])
     return mod
+
+
+class simplify_qnn_concat_in_func(DFPatternCallback):
+
+    """
+    Propagate qnn.concat's quantization params to its inputs,
+    and try to avoid redundant requantization while doing so.
+
+    Replace
+    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
+        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
+      %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
+      %1 = qnn.requantize(%q2, 0.000109401f, 0, 0.00345f, 0, axis=1, out_dtype="uint8");
+      %2 = (%0, %1, %q3);
+      %3 = (0.0425042f, 0.00345f, 0.0486874f);
+      %4 = (0, 0, 0);
+      qnn.concatenate(%2, %3, %4, 0.0486874f, 0, axis=1)
+    }
+
+    with
+
+    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
+        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
+      %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
+      %1 = qnn.requantize(%0, 0.0425042f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
+      %2 = qnn.requantize(%q2, 0.000109401f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
+      %3 = (%1, %2, %q3);
+      concatenate(%3, axis=1)
+    }
+    """
+
+    def __init__(self):
+        super(simplify_qnn_concat_in_func, self).__init__()
+        self.qvals = wildcard()
+        self.scales = wildcard()
+        self.zps = wildcard()
+        self.out_scale = wildcard()
+        self.out_zp = wildcard()
+        self.pattern = is_op("qnn.concatenate")(
+            self.qvals, self.scales, self.zps, self.out_scale, self.out_zp
+        )
+
+    def callback(self, pre, post, node_map):
+        in_qvals = node_map[self.qvals][0]
+        in_scales = node_map[self.scales][0]
+        in_zps = node_map[self.zps][0]
+        new_qvals = []
+        for i in range(len(in_qvals)):
+            new_requant_args = []
+            # TODO Generalize for all qnn ops
+            if isinstance(in_qvals[i], Call) and (in_qvals[i].op.name == "qnn.requantize"):
+                # propagate scale/zp of qnn.concat to this requantize op
+                for j in range(3):
+                    new_requant_args.append(in_qvals[i].args[j])
+                new_requant_args += [node_map[self.out_scale][0], node_map[self.out_zp][0]]
+                new_qvals.append(relay.qnn.op.requantize(*new_requant_args, **(in_qvals[i].attrs)))
+            else:
+                # simply create a new requantize op if there is a change in quantization params
+                # if not, just retain the old qval
+                if (in_scales[i] == node_map[self.out_scale][0]) and (
+                    in_zps[i] == node_map[self.out_zp][0]
+                ):
+                    new_qvals.append(in_qvals[i])
+                else:
+                    new_requant_args += [
+                        in_qvals[i],
+                        in_scales[i],
+                        in_zps[i],
+                        node_map[self.out_scale][0],
+                        node_map[self.out_zp][0],
+                    ]
+                    new_qvals.append(
+                        relay.qnn.op.requantize(
+                            *new_requant_args,
+                            axis=post.attrs["axis"],
+                            out_dtype=post.checked_type.dtype,
+                        )
+                    )
+
+        new_op = relay.op.concatenate(
+            new_qvals,
+            node_map[self.pattern][0].attrs["axis"],
+        )
+        return new_op
+
+
+# Right now context is ignored
+@tvm.transform.module_pass(opt_level=1)
+def simplify_qnn_concat(mod, _=None):
+    for global_var in mod.functions.keys():
+        mod[global_var] = rewrite(simplify_qnn_concat_in_func(), mod[global_var])
+    return mod
diff --git a/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py b/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py
new file mode 100644
index 000000000000..ad1d7592fc29
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-wildcard-import, invalid-name
+
+"""
+Test hexagon relay transform - qnn.concat optimization
+"""
+import tvm
+from tvm import relay, testing
+from tvm.contrib.hexagon.transform import simplify_qnn_concat
+
+
+def get_test_module():
+    """Creates a test relay module and returns it."""
+    q1 = relay.var("q1", shape=(1, 64, 35, 35), dtype="uint8")
+    q2 = relay.var("q2", shape=(1, 64, 35, 35), dtype="uint8")
+    q3 = relay.var("q3", shape=(1, 32, 35, 35), dtype="uint8")
+    s2 = relay.const(0.000109401, dtype="float32")
+    s3 = relay.const(0.0486874, dtype="float32")
+    s4 = relay.const(0.0425042, dtype="float32")
+    s5 = relay.const(0.00345, dtype="float32")
+    z1 = relay.const(0, dtype="int32")
+    r1 = relay.op.nn.max_pool2d(
+        q1,
+        pool_size=[3, 3],
+        strides=[1, 1],
+        padding=[1, 1],
+        dilation=[1, 1],
+        ceil_mode=False,
+        layout="NHWC",
+    )
+    r2 = relay.qnn.op.requantize(q2, s2, z1, s5, z1, axis=1, out_dtype="uint8")
+    q_tuple = relay.expr.Tuple([r1, r2, q3])
+    s_tuple = relay.expr.Tuple([s4, s5, s3])
+    z_tuple = relay.expr.Tuple([z1, z1, z1])
+    graph = relay.qnn.op.concatenate(q_tuple, s_tuple, z_tuple, s3, z1, axis=1)
+
+    func = relay.Function(relay.analysis.free_vars(graph), graph)
+    mod = tvm.IRModule.from_expr(func)
+    return mod
+
+
+def get_expected_output_module():
+    """Returns manually created expected output module."""
+    out_q1 = relay.var("q1", shape=(1, 64, 35, 35), dtype="uint8")
+    out_q2 = relay.var("q2", shape=(1, 64, 35, 35), dtype="uint8")
+    out_q3 = relay.var("q3", shape=(1, 32, 35, 35), dtype="uint8")
+    out_s2 = relay.const(0.000109401, dtype="float32")
+    out_s3 = relay.const(0.0486874, dtype="float32")
+    out_s4 = relay.const(0.0425042, dtype="float32")
+    out_z1 = relay.const(0, dtype="int32")
+    nn_max_pool = relay.op.nn.max_pool2d(
+        out_q1,
+        pool_size=[3, 3],
+        strides=[1, 1],
+        padding=[1, 1],
+        dilation=[1, 1],
+        ceil_mode=False,
+        layout="NHWC",
+    )
+    out_r1 = relay.qnn.op.requantize(
+        nn_max_pool, out_s4, out_z1, out_s3, out_z1, axis=1, out_dtype="uint8"
+    )
+    out_r2 = relay.qnn.op.requantize(
+        out_q2, out_s2, out_z1, out_s3, out_z1, axis=1, out_dtype="uint8"
+    )
+    out_q_tuple = relay.expr.Tuple([out_r1, out_r2, out_q3])
+    out_graph = relay.op.concatenate(out_q_tuple, axis=1)
+
+    out_func = relay.Function(relay.analysis.free_vars(out_graph), out_graph)
+    out_mod = tvm.IRModule.from_expr(out_func)
+    return out_mod
+
+
+def test_simplify_qnn_concat():
+    mod = get_test_module()
+    mod = tvm.relay.transform.InferType()(mod)
+    mod = simplify_qnn_concat(mod)
+
+    out_mod = get_expected_output_module()
+    out_mod = tvm.relay.transform.InferType()(out_mod)
+
+    assert tvm.ir.structural_equal(mod["main"], out_mod["main"])
+
+
+if __name__ == "__main__":
+    testing.main()
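
A minimal usage sketch, not part of the patch itself, mirroring what test_simplify_qnn_concat does above. The
module path and pass name come straight from the diff; run_pass is just an illustrative wrapper. The callback
reads post.checked_type.dtype, so InferType is expected to run before the rewrite, as the test does:

    import tvm
    from tvm.contrib.hexagon.transform import simplify_qnn_concat

    def run_pass(mod: tvm.IRModule) -> tvm.IRModule:
        # Type inference first: the rewrite uses post.checked_type.dtype when it
        # has to insert a fresh qnn.requantize for a non-requantize input.
        mod = tvm.relay.transform.InferType()(mod)
        # Module pass: folds the qnn.concatenate quantization params into its inputs
        # and replaces the op with a plain concatenate.
        return simplify_qnn_concat(mod)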