@@ -50,12 +50,15 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
5050 if (data == nullptr || weight == nullptr ) return false ;
5151 const auto * param = attrs.as <Conv2DAttrs>();
5252 ICHECK (param != nullptr ) << " Conv2DAttrs cannot be nullptr." ;
53- ICHECK (data->dtype == DataType::Int (8 ) || data->dtype == DataType::UInt (8 ))
54- << " Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype ;
55- ICHECK (weight->dtype == DataType::Int (8 ) || weight->dtype == DataType::UInt (8 ))
56- << " Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype ;
57- ICHECK (param->out_dtype == DataType::Int (16 ) || param->out_dtype == DataType::Int (32 ))
58- << " Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype ;
53+ ICHECK (data->dtype == DataType::Int (8 ) || data->dtype == DataType::UInt (8 ) ||
54+ data->dtype == DataType::Int (16 ))
55+ << " Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype ;
56+ ICHECK (weight->dtype == DataType::Int (8 ) || weight->dtype == DataType::UInt (8 ) ||
57+ weight->dtype == DataType::Int (16 ))
58+ << " Expected qnn conv2d type(int8, uint8, int16) for weight but was " << weight->dtype ;
59+ ICHECK (param->out_dtype == DataType::Int (16 ) || param->out_dtype == DataType::Int (32 ) ||
60+ param->out_dtype == DataType::Int (64 ))
61+ << " Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype ;
5962 ICHECK (param->out_dtype .bits () > 0 ) << " Output dtype bits should be greater than 0." ;
6063
6164 // Check the types of scale and zero points.
@@ -190,19 +193,21 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const Conv2DA
190193 */
191194Expr Conv2DFallBack (const Expr& data, const Expr& weight, const Expr& input_zero_point,
192195 const Expr& kernel_zero_point, const Conv2DAttrs* param) {
193- // Upcast the zero point to Int16.
194- auto zp_data = Cast (input_zero_point, DataType::Int (16 ));
195- auto zp_kernel = Cast (kernel_zero_point, DataType::Int (16 ));
196+ // Upcast the parameters to be at least int32 to avoid overflow
197+ auto upcast_bits = param->out_dtype .bits () < 32 ? 32 : param->out_dtype .bits ();
196198
197- auto shifted_data = Cast (data, DataType::Int (16 ));
198- auto zero_scalar = MakeConstantScalar (DataType::Int (32 ), 0 );
199+ auto zp_data = Cast (input_zero_point, DataType::Int (upcast_bits));
200+ auto zp_kernel = Cast (kernel_zero_point, DataType::Int (upcast_bits));
201+
202+ auto shifted_data = Cast (data, DataType::Int (upcast_bits));
203+ auto zero_scalar = MakeConstantScalar (DataType::Int (upcast_bits), 0 );
199204 if (!IsEqualScalar (input_zero_point, zero_scalar)) {
200- shifted_data = Subtract (Cast (data, DataType::Int (16 )), zp_data);
205+ shifted_data = Subtract (Cast (data, DataType::Int (upcast_bits )), zp_data);
201206 }
202207
203- auto shifted_kernel = Cast (weight, DataType::Int (16 ));
208+ auto shifted_kernel = Cast (weight, DataType::Int (upcast_bits ));
204209 if (!IsEqualScalar (kernel_zero_point, zero_scalar)) {
205- shifted_kernel = Subtract (Cast (weight, DataType::Int (16 )), zp_kernel);
210+ shifted_kernel = Subtract (Cast (weight, DataType::Int (upcast_bits )), zp_kernel);
206211 }
207212
208213 return Conv2D (shifted_data, shifted_kernel, param->strides , param->padding , param->dilation ,
@@ -557,17 +562,19 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Con
557562 * \param in_channels The number of input channels.
558563 * \param kernel_h The height of kernel.
559564 * \param kernel_w The width of kernel.
565+ * \param param The qnn conv2d attributes.
560566 * \return The constant Relay expression for term4.
561567 * \note The term4 looks like this
562568 *
563569 * Sigma(c,r,s) zp_a * zp_w
564570 *
565571 */
566572Expr Conv2DFourthTerm (int input_zero_point_int, int kernel_zero_point_int, int in_channels,
567- int kernel_h, int kernel_w) {
573+ int kernel_h, int kernel_w, const Conv2DAttrs* param) {
574+ auto upcast_bits = param->out_dtype .bits () < 32 ? 32 : param->out_dtype .bits ();
568575 int scalar_term4 =
569576 input_zero_point_int * kernel_zero_point_int * in_channels * kernel_h * kernel_w;
570- return MakeConstantScalar (DataType::Int (32 ), scalar_term4);
577+ return MakeConstantScalar (DataType::Int (upcast_bits ), scalar_term4);
571578}
572579
573580/*
@@ -578,15 +585,18 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i
578585 * \param in_channels The number of input channels.
579586 * \param kernel_h The height of kernel.
580587 * \param kernel_w The width of kernel.
588+ * \param param The qnn conv2d attributes.
581589 * \return The sequence of Relay operators for term4.
582590 * \note The term4 looks like this
583591 *
584592 * Sigma(c,r,s) zp_a * zp_w
585593 *
586594 */
587595Expr Conv2DFourthTerm (const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels,
588- int kernel_h, int kernel_w) {
589- Expr scalar_term4 = MakeConstantScalar (DataType::Int (32 ), in_channels * kernel_h * kernel_w);
596+ int kernel_h, int kernel_w, const Conv2DAttrs* param) {
597+ auto upcast_bits = param->out_dtype .bits () < 32 ? 32 : param->out_dtype .bits ();
598+ Expr scalar_term4 =
599+ MakeConstantScalar (DataType::Int (upcast_bits), in_channels * kernel_h * kernel_w);
590600 Expr variable_term4 = Multiply (input_zero_point, kernel_zero_point);
591601 return Multiply (scalar_term4, variable_term4);
592602}
@@ -791,10 +801,11 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
791801 auto term3 = Conv2DThirdTerm (weight, input_zero_point, param, out_channels);
792802 Expr term4;
793803 if (dynamic_zp) {
794- term4 = Conv2DFourthTerm (input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w);
804+ term4 = Conv2DFourthTerm (input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w,
805+ param);
795806 } else {
796807 term4 = Conv2DFourthTerm (input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
797- kernel_w);
808+ kernel_w, param );
798809 }
799810 return Conv2DCombineTerms (term1, term2, term3, term4, input_zero_point_int,
800811 kernel_zero_point_int);
@@ -829,7 +840,7 @@ This operator convolves quantized weight with quantized data. The scale of the
829840output quantized tensor is the product of the weight_scale and input_scale of
830841the input quantized tensors. The zero point of the output quantized tensor is
8318420. By default, the dtype of output is int32. Please also refer to Requantize
832- operator to understand how to scale back the int32 output to (u)int8.
843+ operator to understand how to scale back the int32 output to (u)int8 or (u)int16 .
833844- **data**: This depends on the `layout` parameter. Input is 4D array of shape
834845 (batch_size, in_channels, height, width) if `layout` is `NCHW`.
835846- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
0 commit comments