@@ -340,6 +340,62 @@ def helper_change_dtypes_to_int8(attrs, inputs, types, relay_op):
     )


+def helper_change_dtypes_to_uint8(attrs, inputs, types, relay_op):
+    """Helper function to change dtypes to uint8 x uint8.
+    Legalizes a QNN op for the Hexagon DSP, which supports the fast u8 x u8 vrmpy instruction.
+
+    Converting from int8 to uint8 can be done in the following manner:
+
+    Original equation
+      scale * (QA - zp_a)
+      scale * (QA + 128 - 128 - zp_a)
+      scale * ((QA + 128) - (zp_a + 128))
+
+    Replacing QA + 128 with QA' and (zp_a + 128) with zp_a',
+    we get the new quantized uint8 tensor: scale * (QA' - zp_a')
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of the current operator
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    # Collect the dtypes.
+    data_dtype = types[0].dtype
+    kernel_dtype = types[1].dtype
+
+    # Do nothing since it is already uint8.
+    if data_dtype == "uint8" and kernel_dtype == "uint8":
+        return None
+
+    # Collect the input exprs.
+    data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs
+
+    # Shift input if necessary.
+    if data_dtype == "int8":
+        # Compute (QA + 128) and (zp_a + 128)
+        data, input_zero_point = _shift(data, input_zero_point, "uint8")
+
+    # Shift kernel if necessary.
+    if kernel_dtype == "int8":
+        # Compute (QK + 128) and (zp_k + 128)
+        kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, "uint8")
+
+    # Call qnn.conv2d/qnn.dense with modified inputs and zero points.
+    new_attrs = dict(attrs)
+    return relay_op(
+        data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs
+    )
+
+
 # Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
 def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
     """Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
@@ -520,7 +576,7 @@ def _qnn_conv2d_legalize_hexagon(attrs, inputs, types):
     out_channel = kernel_tensor.shape[0].value
     ic_modified = False
     oc_modified = False
-    data, kernel, input_zp, output_zp, input_scale, output_scale = inputs
+    data, kernel, data_zp, kernel_zp, data_scale, kernel_scale = inputs

     if in_channel % IN_CHANNEL_VECTOR_LENGTH != 0:
         new_in_channel = helper_align_up(in_channel, IN_CHANNEL_VECTOR_LENGTH)
@@ -537,21 +593,93 @@ def _qnn_conv2d_legalize_hexagon(attrs, inputs, types):
         kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0)))
         oc_modified = True

+        # Pad kernel zero point by 'diff' elements of 0 if it is not scalar
+        kernel_zp_tensor = types[3]
+        if len(kernel_zp_tensor.shape) != 0:
+            assert isinstance(kernel_zp, relay.Constant)
+            padded_kernel_zp_np = np.append(kernel_zp.data.numpy(), [0] * diff)
+            kernel_zp = relay.const(padded_kernel_zp_np)
+
+        # Pad kernel scale by 'diff' elements of 1.0 if it is not scalar
+        kernel_scale_tensor = types[5]
+        if len(kernel_scale_tensor.shape) != 0:
+            assert isinstance(kernel_scale, relay.Constant)
+            padded_kernel_scale_np = np.append(kernel_scale.data.numpy(), [1.0] * diff)
+            kernel_scale = relay.const(padded_kernel_scale_np)
+
     if ic_modified is True or oc_modified is True:
         new_attrs = dict(attrs)
         if oc_modified:
             new_attrs["channels"] = new_out_channel
             out = relay.qnn.op.conv2d(
-                data, kernel, input_zp, output_zp, input_scale, output_scale, **new_attrs
+                data, kernel, data_zp, kernel_zp, data_scale, kernel_scale, **new_attrs
             )
             output_tensor = types[6]
             original_out_shape = list(output_tensor.shape)
             out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
         else:
             out = relay.qnn.op.conv2d(
-                data, kernel, input_zp, output_zp, input_scale, output_scale, **new_attrs
+                data, kernel, data_zp, kernel_zp, data_scale, kernel_scale, **new_attrs
             )

         return out

     return None
+
+
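To make the channel-padding arithmetic above concrete, here is a small numpy sketch (illustrative values only; `align_up` re-states what `helper_align_up` is assumed to do, and OUT_CHANNEL_VECTOR_LENGTH is taken as 32, matching the dense legalization below):

    import numpy as np

    def align_up(value, factor):
        # Round `value` up to the next multiple of `factor`.
        return ((value + factor - 1) // factor) * factor

    out_channel = 37
    new_out_channel = align_up(out_channel, 32)   # 64
    diff = new_out_channel - out_channel          # 27 padded channels

    # Per-channel kernel params must grow by the same amount: zero points
    # are padded with 0 and scales with 1.0, so the extra channels
    # dequantize to exactly 0 and are later cut away by strided_slice.
    kernel_zp = np.zeros(out_channel, dtype=np.int32)
    padded_zp = np.append(kernel_zp, [0] * diff)
    assert padded_zp.shape == (new_out_channel,)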
+@qnn_dense_legalize.register("hexagon")
+def _qnn_dense_legalize_hexagon(attrs, inputs, types):
+    """Legalize the qnn.dense op for vrmpy tensorization.
+
+    The N dimension of the weights should be aligned to the vector length. If it is not,
+    the N dimension is padded up to a multiple of 32.
+    """
+    assert len(types) == 7
+    assert len(inputs) == 6
+
+    data_tensor, kernel_tensor = types[0], types[1]
+    if "int8" not in data_tensor.dtype or "int8" not in kernel_tensor.dtype:
+        return None
+
+    N, _ = kernel_tensor.shape
+
+    if N % OUT_CHANNEL_VECTOR_LENGTH != 0:
+        N_padded = helper_align_up(N, OUT_CHANNEL_VECTOR_LENGTH)
+        diff = N_padded - N
+
+        data, kernel, data_zp, kernel_zp, data_scale, kernel_scale = inputs
+
+        # Pad weights by 'diff'
+        padded_kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0)))
+
+        kernel_zp_tensor, kernel_scale_tensor = types[3], types[5]
+
+        # Pad kernel zero point by 'diff' elements of 0 if it is not scalar
+        if len(kernel_zp_tensor.shape) != 0:
+            assert isinstance(kernel_zp, relay.Constant)
+            assert isinstance(diff, tvm.tir.IntImm)
+            padded_kernel_zp_np = np.append(kernel_zp.data.numpy(), [0] * diff.value)
+            kernel_zp = relay.const(padded_kernel_zp_np)
+
+        # Pad kernel scale by 'diff' elements of 1.0 if it is not scalar
+        if len(kernel_scale_tensor.shape) != 0:
+            assert isinstance(kernel_scale, relay.Constant)
+            assert isinstance(diff, tvm.tir.IntImm)
+            padded_kernel_scale_np = np.append(kernel_scale.data.numpy(), [1.0] * diff.value)
+            kernel_scale = relay.const(padded_kernel_scale_np)
+
+        # If units is explicitly specified, it is used to compute the output shape.
+        # We need to update units after padding to prevent a type error.
+        new_attrs = dict(attrs)
+        if attrs["units"] is not None:
+            new_attrs["units"] = N + diff
+
+        new_inputs = (data, padded_kernel, data_zp, kernel_zp, data_scale, kernel_scale)
+
+        out = relay.qnn.op.dense(*new_inputs, **new_attrs)
+
+        output_tensor = types[6]
+        out = relay.strided_slice(out, begin=[0, 0], end=list(output_tensor.shape))
+        return out
+
+    return None
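For intuition, the shape bookkeeping of this dense legalization can be checked with a plain numpy sketch (hypothetical sizes; the variable names mirror the patch but nothing here is part of it):

    import numpy as np

    M, K, N = 8, 64, 100                 # N is not 32-aligned
    N_padded = ((N + 31) // 32) * 32     # 128, what helper_align_up computes
    diff = N_padded - N                  # 28

    # Weights grow from (N, K) to (N_padded, K) by zero padding ...
    kernel = np.zeros((N, K), dtype=np.int8)
    padded_kernel = np.pad(kernel, ((0, diff), (0, 0)))
    assert padded_kernel.shape == (N_padded, K)

    # ... so the dense output grows from (M, N) to (M, N_padded); the
    # trailing `diff` columns are then dropped again, which is what the
    # relay.strided_slice at the end restores.
    out = np.zeros((M, N_padded))
    assert out[:, :N].shape == (M, N)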