2121
2222import tvm
2323from tvm import relay
24- from tvm .relay .dataflow_pattern import DFPatternCallback , rewrite , wildcard
25- from tvm .relay .dataflow_pattern import is_constant , is_op , is_tuple
24+ from tvm .relay .dataflow_pattern import (
25+ DFPatternCallback ,
26+ is_constant ,
27+ is_op ,
28+ is_tuple ,
29+ rewrite ,
30+ wildcard ,
31+ )
32+ from tvm .relay .expr import Call
33+
2634from ..._ffi .registry import register_func
2735
2836### VTCM
@@ -43,7 +51,6 @@ def mem_info_vtcm():
4351
4452
4553def lower_vtcm_ (get_alloc , get_free , def_align , func , mod , ctx ): # pylint: disable=unused-argument
46-
4754 """Generic VTCM allocation
4855
4956 Parameters
@@ -311,3 +318,95 @@ def remove_empty_pad(mod):
311318 """Remove the empty pad operator."""
312319 mod ["main" ] = rewrite (remove_empty_pad_callback (), mod ["main" ])
313320 return mod
321+
322+
class simplify_qnn_concat_in_func(DFPatternCallback):
    """
    Propagate qnn.concat's quantization params to its inputs,
    and try to avoid redundant requantization while doing so.

    Replace
    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
        %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
        %1 = qnn.requantize(%q2, 0.000109401f, 0, 0.00345f, 0, axis=1, out_dtype="uint8");
        %2 = (%0, %1, %q3);
        %3 = (0.0425042f, 0.00345f, 0.0486874f);
        %4 = (0, 0, 0);
        qnn.concatenate(%2, %3, %4, 0.0486874f, 0, axis=1)
    }

    with

    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
        %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
        %1 = qnn.requantize(%0, 0.0425042f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
        %2 = qnn.requantize(%q2, 0.000109401f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
        %3 = (%1, %2, %q3);
        concatenate(%3, axis=1)
    }
    """

    def __init__(self):
        super(simplify_qnn_concat_in_func, self).__init__()
        # One wildcard per positional argument of qnn.concatenate:
        # (data tuple, input scales, input zero points, output scale, output zero point).
        self.qvals = wildcard()
        self.scales = wildcard()
        self.zps = wildcard()
        self.out_scale = wildcard()
        self.out_zp = wildcard()
        self.pattern = is_op("qnn.concatenate")(
            self.qvals, self.scales, self.zps, self.out_scale, self.out_zp
        )

    @staticmethod
    def _is_requantize(expr):
        """Return True when `expr` is a call to the qnn.requantize operator."""
        if not isinstance(expr, Call):
            return False
        # The callee of a Call is not always an operator (it can be a function or
        # a global var), and those have no `name` attribute. Look it up defensively
        # instead of raising AttributeError on such inputs.
        return getattr(expr.op, "name", None) == "qnn.requantize"

    @staticmethod
    def _same_qparam(lhs, rhs):
        """Return True when two quantization parameter expressions are equivalent.

        Structural equality is used so that two distinct constant nodes holding the
        same value are recognised as equal, not only the very same node.
        """
        return tvm.ir.structural_equal(lhs, rhs)

    def callback(self, pre, post, node_map):
        """Rewrite a matched qnn.concatenate into a plain concatenate.

        Parameters
        ----------
        pre : relay.Expr
            The matched expression before its children were rewritten.
        post : relay.Expr
            The matched expression after its children were rewritten.
        node_map : Map
            Mapping from each sub-pattern to the expressions it matched.

        Returns
        -------
        relay.Expr
            A concatenate call whose inputs have all been brought to the output
            scale and zero point of the original qnn.concatenate.
        """
        in_qvals = node_map[self.qvals][0]
        in_scales = node_map[self.scales][0]
        in_zps = node_map[self.zps][0]
        out_scale = node_map[self.out_scale][0]
        out_zp = node_map[self.out_zp][0]
        axis = post.attrs["axis"]

        new_qvals = []
        for i, qval in enumerate(in_qvals):
            # TODO Generalize for all qnn ops
            if self._is_requantize(qval):
                # propagate scale/zp of qnn.concat to this requantize op:
                # keep its data, input scale and input zero point, and swap in
                # the concat's output scale and zero point.
                requant_args = [qval.args[j] for j in range(3)]
                requant_args += [out_scale, out_zp]
                new_qvals.append(relay.qnn.op.requantize(*requant_args, **(qval.attrs)))
            elif self._same_qparam(in_scales[i], out_scale) and self._same_qparam(
                in_zps[i], out_zp
            ):
                # quantization params already match the output, retain the old qval
                new_qvals.append(qval)
            else:
                # quantization params differ, so insert a new requantize op.
                # The dtype is read from `pre`: `post` is a freshly built node whenever
                # one of its children was rewritten, and such a node has no checked_type.
                new_qvals.append(
                    relay.qnn.op.requantize(
                        qval,
                        in_scales[i],
                        in_zps[i],
                        out_scale,
                        out_zp,
                        axis=axis,
                        out_dtype=pre.checked_type.dtype,
                    )
                )

        return relay.op.concatenate(new_qvals, axis)
405+
406+
# The pass context argument is accepted but currently unused.
@tvm.transform.module_pass(opt_level=1)
def simplify_qnn_concat(mod, _=None):
    """Apply the qnn.concatenate simplification to every function in the module.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose functions are rewritten.
    _ : optional
        Pass context; ignored.

    Returns
    -------
    tvm.IRModule
        The same module, with each function replaced by its rewritten form.
    """
    for gvar in mod.functions.keys():
        original_func = mod[gvar]
        rewritten_func = rewrite(simplify_qnn_concat_in_func(), original_func)
        mod[gvar] = rewritten_func
    return mod
0 commit comments