@@ -359,6 +359,10 @@ def qnn_conv2d(expr):
359359 kernel_typ = args [1 ].checked_type
360360 if len (kernel_typ .shape ) != 4 or kernel_typ .dtype not in qnn_dtypes :
361361 return False
362+ if is_per_channel_quantization (
363+ zero_point = args [2 ], scale = args [4 ]
364+ ) or is_per_channel_quantization (zero_point = args [3 ], scale = args [5 ]):
365+ return False
362366 is_depthwise = is_depthwise_conv2d (
363367 data_typ .shape ,
364368 attrs ["data_layout" ],
@@ -422,6 +426,10 @@ def qnn_dense(expr):
422426 return False
423427 if attrs .out_dtype != "int32" :
424428 return False
429+ if is_per_channel_quantization (
430+ zero_point = args [2 ], scale = args [4 ]
431+ ) or is_per_channel_quantization (zero_point = args [3 ], scale = args [5 ]):
432+ return False
425433 return True
426434
427435
@@ -514,10 +522,24 @@ def qnn_add(expr):
514522 for typ in [args [0 ].checked_type , args [1 ].checked_type ]:
515523 if typ .dtype not in ["int8" , "uint8" ]:
516524 return False
517-
525+ if (
526+ is_per_channel_quantization (zero_point = args [3 ], scale = args [2 ])
527+ or is_per_channel_quantization (zero_point = args [5 ], scale = args [4 ])
528+ or is_per_channel_quantization (zero_point = args [7 ], scale = args [6 ])
529+ ):
530+ return False
518531 return True
519532
520533
534+ def is_per_channel_quantization (zero_point , scale ):
535+ """Check if the quantization is per-channel"""
536+ for value in [zero_point , scale ]:
537+ shape = value .checked_type .shape
538+ if len (shape ) != 0 and shape [0 ] != 1 :
539+ return True
540+ return False
541+
542+
521543class OpAttrContext (object ):
522544 """Temporarily changes the attr of an op."""
523545
0 commit comments