@@ -340,6 +340,62 @@ def helper_change_dtypes_to_int8(attrs, inputs, types, relay_op):
     )
 
 
+def helper_change_dtypes_to_uint8(attrs, inputs, types, relay_op):
+    """Helper function to change dtypes to uint8 x uint8.
+    Legalizes the QNN dense op for the Hexagon DSP, which supports a fast u8 x u8 vrmpy instruction.
+
+    Converting from int8 to uint8 can be done in the following manner:
+
+    Original equation:
+      scale * (QA - zp_a)
+      scale * (QA + 128 - 128 - zp_a)
+      scale * ((QA + 128) - (zp_a + 128))
+
+    Replacing (QA + 128) with QA' and (zp_a + 128) with zp_a' gives the new
+    quantized uint8 tensor: scale * (QA' - zp_a')
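+
+    For example, with scale = 0.05, QA = -5 and zp_a = 10, we get QA' = 123 and
+    zp_a' = 138, and 0.05 * (123 - 138) == 0.05 * (-5 - 10), so the represented
+    real value is unchanged.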
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of the current operator
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+    relay_op : callable
+        The QNN Relay op (e.g. relay.qnn.op.dense) to call with the modified inputs
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    # Collect the dtypes.
+    data_dtype = types[0].dtype
+    kernel_dtype = types[1].dtype
+
+    # Do nothing since both operands are already uint8.
+    if data_dtype == "uint8" and kernel_dtype == "uint8":
+        return None
+
+    # Collect the input exprs.
+    data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs
+
+    # Shift input if necessary.
+    if data_dtype == "int8":
+        # Compute (QA + 128) and (zp_a + 128)
+        data, input_zero_point = _shift(data, input_zero_point, "uint8")
+
+    # Shift kernel if necessary.
+    if kernel_dtype == "int8":
+        # Apply the same shift to the kernel and its zero point.
+        kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, "uint8")
+
+    # Call qnn.conv2d/qnn.dense with the modified inputs and zero points.
+    new_attrs = dict(attrs)
+    return relay_op(
+        data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs
+    )
+
+
 # Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
 def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
     """Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
@@ -555,3 +611,54 @@ def _qnn_conv2d_legalize_hexagon(attrs, inputs, types):
         return out
 
     return None
+
+
+@qnn_dense_legalize.register("hexagon")
+def _qnn_dense_legalize_hexagon(attrs, inputs, types):
+    """Legalize the qnn.dense op for vrmpy tensorization.
+
+    The N dimension of the weights should be aligned to the vector length. If it is not,
+    the N dimension is padded up to a multiple of 32.
+    """
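+    # For example, with OUT_CHANNEL_VECTOR_LENGTH == 32, a weight tensor with N == 50 is
+    # padded to N_padded == 64 (diff == 14); the padded output columns are removed again
+    # by the strided_slice below.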
+    assert len(types) == 7
+    assert len(inputs) == 6
+
+    data_tensor, kernel_tensor = types[0], types[1]
+    if "int8" not in data_tensor.dtype or "int8" not in kernel_tensor.dtype:
+        return None
+
+    N, _ = kernel_tensor.shape
+
+    if N % OUT_CHANNEL_VECTOR_LENGTH != 0:
+        N_padded = helper_align_up(N, OUT_CHANNEL_VECTOR_LENGTH)
+        diff = N_padded - N
+
+        # Pad the weights along N by 'diff'.
+        padded_kernel = relay.nn.pad(inputs[1], pad_width=((0, diff), (0, 0)))
+
+        # If units is explicitly specified, it is used to compute the output shape.
+        # We need to update units after padding to prevent a type error.
+        new_attrs = dict(attrs)
+        if attrs["units"] is not None:
+            new_attrs["units"] = N + diff
+
+        new_inputs = (inputs[0], padded_kernel, *inputs[2:])
+
+        # TODO: enable legalization u8i8i32 -> u8u8i32 for qnn.dense. Code:
+        # """
+        # new_types = (types[0], relay.TensorType([N + diff, C], types[1].dtype), *types[2:])
+        # out = helper_change_dtypes_to_uint8(new_attrs, new_inputs, new_types, relay.qnn.op.dense)
+        # if out is None:
+        #     out = relay.qnn.op.dense(*new_inputs, **new_attrs)
+        # """
+        out = relay.qnn.op.dense(*new_inputs, **new_attrs)
+
+        output_tensor = types[6]
+        out = relay.strided_slice(out, begin=[0, 0], end=list(output_tensor.shape))
+        return out
+
+    # TODO: enable legalization u8i8i32 -> u8u8i32 for qnn.dense. Code:
+    # """
+    # return helper_change_dtypes_to_uint8(attrs, inputs, types, relay.qnn.op.dense)
+    # """
+    return None
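
As a sanity check on the +128 trick used by helper_change_dtypes_to_uint8, the following minimal NumPy sketch (independent of this patch, with made-up scale and zero-point values) shows that shifting both the quantized tensor and its zero point into uint8 leaves the dequantized values unchanged:

import numpy as np

# Made-up quantization parameters, for illustration only.
scale = 0.05
zp_a = 10

# A hypothetical int8 quantized tensor QA.
qa = np.array([-128, -5, 0, 77, 127], dtype=np.int8)

# Original dequantization: scale * (QA - zp_a).
dequant_int8 = scale * (qa.astype(np.int32) - zp_a)

# Shift data and zero point by 128: QA' = QA + 128, zp_a' = zp_a + 128.
# QA' now fits in uint8, which is the operand type vrmpy prefers.
qa_u8 = (qa.astype(np.int32) + 128).astype(np.uint8)
zp_a_u8 = zp_a + 128

# Dequantizing the shifted tensor gives exactly the same real values.
dequant_uint8 = scale * (qa_u8.astype(np.int32) - zp_a_u8)

assert np.allclose(dequant_int8, dequant_uint8)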