@@ -50,12 +50,14 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
5050 if (data == nullptr || weight == nullptr ) return false ;
5151 const auto * param = attrs.as <Conv2DAttrs>();
5252 ICHECK (param != nullptr ) << " Conv2DAttrs cannot be nullptr." ;
53- ICHECK (data->dtype == DataType::Int (8 ) || data->dtype == DataType::UInt (8 ))
54- << " Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype ;
53+ ICHECK (data->dtype == DataType::Int (8 ) || data->dtype == DataType::UInt (8 ) ||
54+ data->dtype == DataType::Int (16 ))
55+ << " Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype ;
5556 ICHECK (weight->dtype == DataType::Int (8 ) || weight->dtype == DataType::UInt (8 ))
56- << " Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype ;
57- ICHECK (param->out_dtype == DataType::Int (16 ) || param->out_dtype == DataType::Int (32 ))
58- << " Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype ;
57+ << " Expected qnn conv2d type(int8, uint8, int16) for weight but was " << weight->dtype ;
58+ ICHECK (param->out_dtype == DataType::Int (16 ) || param->out_dtype == DataType::Int (32 ) ||
59+ param->out_dtype == DataType::Int (64 ))
60+ << " Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype ;
5961 ICHECK (param->out_dtype .bits () > 0 ) << " Output dtype bits should be greater than 0." ;
6062
6163 // Check the types of scale and zero points.
@@ -190,19 +192,21 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const Conv2DA
190192 */
191193Expr Conv2DFallBack (const Expr& data, const Expr& weight, const Expr& input_zero_point,
192194 const Expr& kernel_zero_point, const Conv2DAttrs* param) {
193- // Upcast the zero point to Int16.
194- auto zp_data = Cast (input_zero_point, DataType::Int (16 ));
195- auto zp_kernel = Cast (kernel_zero_point, DataType::Int (16 ));
195+ // Upcast the parameters to be at least int32 to avoid overflow
196+ auto upcast_bits = param->out_dtype .bits () < 32 ? 32 : param->out_dtype .bits ();
196197
197- auto shifted_data = Cast (data, DataType::Int (16 ));
198- auto zero_scalar = MakeConstantScalar (DataType::Int (32 ), 0 );
198+ auto zp_data = Cast (input_zero_point, DataType::Int (upcast_bits));
199+ auto zp_kernel = Cast (kernel_zero_point, DataType::Int (upcast_bits));
200+
201+ auto shifted_data = Cast (data, DataType::Int (upcast_bits));
202+ auto zero_scalar = MakeConstantScalar (DataType::Int (upcast_bits), 0 );
199203 if (!IsEqualScalar (input_zero_point, zero_scalar)) {
200- shifted_data = Subtract (Cast (data, DataType::Int (16 )), zp_data);
204+ shifted_data = Subtract (Cast (data, DataType::Int (upcast_bits )), zp_data);
201205 }
202206
203- auto shifted_kernel = Cast (weight, DataType::Int (16 ));
207+ auto shifted_kernel = Cast (weight, DataType::Int (upcast_bits ));
204208 if (!IsEqualScalar (kernel_zero_point, zero_scalar)) {
205- shifted_kernel = Subtract (Cast (weight, DataType::Int (16 )), zp_kernel);
209+ shifted_kernel = Subtract (Cast (weight, DataType::Int (upcast_bits )), zp_kernel);
206210 }
207211
208212 return Conv2D (shifted_data, shifted_kernel, param->strides , param->padding , param->dilation ,
@@ -557,17 +561,19 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Con
557561 * \param in_channels The number of input channels.
558562 * \param kernel_h The height of kernel.
559563 * \param kernel_w The width of kernel.
564+ * \param param The qnn conv2d attributes.
560565 * \return The sequence of Relay operators for term4.
561566 * \note The term4 looks like this
562567 *
563568 * Sigma(c,r,s) zp_a * zp_w
564569 *
565570 */
566571Expr Conv2DFourthTerm (int input_zero_point_int, int kernel_zero_point_int, int in_channels,
567- int kernel_h, int kernel_w) {
572+ int kernel_h, int kernel_w, const Conv2DAttrs* param) {
573+ auto upcast_bits = param->out_dtype .bits () < 32 ? 32 : param->out_dtype .bits ();
568574 int scalar_term4 =
569575 input_zero_point_int * kernel_zero_point_int * in_channels * kernel_h * kernel_w;
570- return MakeConstantScalar (DataType::Int (32 ), scalar_term4);
576+ return MakeConstantScalar (DataType::Int (upcast_bits ), scalar_term4);
571577}
572578
573579/*
@@ -578,15 +584,18 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i
578584 * \param in_channels The number of input channels.
579585 * \param kernel_h The height of kernel.
580586 * \param kernel_w The width of kernel.
587+ * \param param The qnn conv2d attributes.
581588 * \return The sequence of Relay operators for term4.
582589 * \note The term4 looks like this
583590 *
584591 * Sigma(c,r,s) zp_a * zp_w
585592 *
586593 */
587594Expr Conv2DFourthTerm (const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels,
588- int kernel_h, int kernel_w) {
589- Expr scalar_term4 = MakeConstantScalar (DataType::Int (32 ), in_channels * kernel_h * kernel_w);
595+ int kernel_h, int kernel_w, const Conv2DAttrs* param) {
596+ auto upcast_bits = param->out_dtype .bits () < 32 ? 32 : param->out_dtype .bits ();
597+ Expr scalar_term4 =
598+ MakeConstantScalar (DataType::Int (upcast_bits), in_channels * kernel_h * kernel_w);
590599 Expr variable_term4 = Multiply (input_zero_point, kernel_zero_point);
591600 return Multiply (scalar_term4, variable_term4);
592601}
@@ -791,10 +800,11 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
791800 auto term3 = Conv2DThirdTerm (weight, input_zero_point, param, out_channels);
792801 Expr term4;
793802 if (dynamic_zp) {
794- term4 = Conv2DFourthTerm (input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w);
803+ term4 = Conv2DFourthTerm (input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w,
804+ param);
795805 } else {
796806 term4 = Conv2DFourthTerm (input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
797- kernel_w);
807+ kernel_w, param );
798808 }
799809 return Conv2DCombineTerms (term1, term2, term3, term4, input_zero_point_int,
800810 kernel_zero_point_int);
@@ -829,7 +839,7 @@ This operator convolves quantized weight with quantized data. The scale of the
829839output quantized tensor is the product of the weight_scale and input_scale of
830840the input quantized tensors. The zero point of the output quantized tensor is
8318410. By default, the dtype of output is int32. Please also refer to Requantize
832- operator to understand how to scale back the int32 output to (u)int8.
842+ operator to understand how to scale back the int32 output to (u)int8 or (u)int16 .
833843- **data**: This depends on the `layout` parameter. Input is 4D array of shape
834844 (batch_size, in_channels, height, width) if `layout` is `NCHW`.
835845- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
0 commit comments