Skip to content

Commit 8940807

Browse files
committed
[TFLite] Add support to int16 data type in TFLite frontend
Add support for int16 data type and int64 biases/accumulators in the TFLite frontend. Adjusts TFLite tests to cover int16 convolutions and element-wise; Fixes a minor typo negtive->negative in the element-wise tests.
1 parent ff7efe7 commit 8940807

File tree

6 files changed

+183
-97
lines changed

6 files changed

+183
-97
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ def get_tensor_type_as_numpy(self, tensor_wrapper):
390390
return {
391391
TensorType.UINT8: np.uint8,
392392
TensorType.INT8: np.int8,
393+
TensorType.INT16: np.int16,
393394
TensorType.FLOAT16: np.float16,
394395
TensorType.FLOAT32: np.float32,
395396
TensorType.INT32: np.int32,
@@ -430,6 +431,8 @@ def get_tensor_type_str(self, tensor_type):
430431

431432
if tensor_type == TensorType.INT8:
432433
return "int8"
434+
if tensor_type == TensorType.INT16:
435+
return "int16"
433436
if tensor_type == TensorType.UINT8:
434437
return "uint8"
435438
if tensor_type == TensorType.FLOAT16:
@@ -2149,7 +2152,9 @@ def convert_conv(self, op, conv_type):
21492152
qnn_conv2d_params = dict(params)
21502153
qnn_conv2d_params["input_zero_point"] = input_tensor.qnn_params["zero_point"]
21512154
qnn_conv2d_params["kernel_zero_point"] = weight_tensor.qnn_params["zero_point"]
2152-
qnn_conv2d_params["out_dtype"] = "int32"
2155+
qnn_conv2d_params["out_dtype"] = (
2156+
"int64" if output_tensor_type_str == "int16" else "int32"
2157+
)
21532158
qnn_conv2d_params["input_scale"] = input_tensor.qnn_params["scale"]
21542159
qnn_conv2d_params["kernel_scale"] = weight_tensor.qnn_params["scale"]
21552160
out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params)
@@ -2160,8 +2165,8 @@ def convert_conv(self, op, conv_type):
21602165
if len(input_tensors) == 3:
21612166
bias_tensor = input_tensors[2]
21622167
bias_tensor_type = bias_tensor.tensor.Type()
2163-
# bias tensor type should be INT32 (quantization) or FLOAT32
2164-
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
2168+
# bias tensor type should be INT32 (int8 qnn) or INT64 (int16 qnn) or FLOAT32
2169+
assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
21652170
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
21662171
if self.has_expr(bias_tensor.tensor_idx):
21672172
bias_expr = self.get_expr(bias_tensor.tensor_idx)

src/relay/qnn/op/convolution.cc

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,15 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
5050
if (data == nullptr || weight == nullptr) return false;
5151
const auto* param = attrs.as<Conv2DAttrs>();
5252
ICHECK(param != nullptr) << "Conv2DAttrs cannot be nullptr.";
53-
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
54-
<< "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
55-
ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
56-
<< "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
57-
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
58-
<< "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
53+
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
54+
data->dtype == DataType::Int(16))
55+
<< "Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype;
56+
ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8) ||
57+
weight->dtype == DataType::Int(16))
58+
<< "Expected qnn conv2d type(int8, uint8, int16) for weight but was " << weight->dtype;
59+
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
60+
param->out_dtype == DataType::Int(64))
61+
<< "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
5962
ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
6063

6164
// Check the types of scale and zero points.
@@ -190,19 +193,21 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const Conv2DA
190193
*/
191194
Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point,
192195
const Expr& kernel_zero_point, const Conv2DAttrs* param) {
193-
// Upcast the zero point to Int16.
194-
auto zp_data = Cast(input_zero_point, DataType::Int(16));
195-
auto zp_kernel = Cast(kernel_zero_point, DataType::Int(16));
196+
// Upcast the parameters to be at least int32 to avoid overflow
197+
auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
196198

197-
auto shifted_data = Cast(data, DataType::Int(16));
198-
auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
199+
auto zp_data = Cast(input_zero_point, DataType::Int(upcast_bits));
200+
auto zp_kernel = Cast(kernel_zero_point, DataType::Int(upcast_bits));
201+
202+
auto shifted_data = Cast(data, DataType::Int(upcast_bits));
203+
auto zero_scalar = MakeConstantScalar(DataType::Int(upcast_bits), 0);
199204
if (!IsEqualScalar(input_zero_point, zero_scalar)) {
200-
shifted_data = Subtract(Cast(data, DataType::Int(16)), zp_data);
205+
shifted_data = Subtract(Cast(data, DataType::Int(upcast_bits)), zp_data);
201206
}
202207

203-
auto shifted_kernel = Cast(weight, DataType::Int(16));
208+
auto shifted_kernel = Cast(weight, DataType::Int(upcast_bits));
204209
if (!IsEqualScalar(kernel_zero_point, zero_scalar)) {
205-
shifted_kernel = Subtract(Cast(weight, DataType::Int(16)), zp_kernel);
210+
shifted_kernel = Subtract(Cast(weight, DataType::Int(upcast_bits)), zp_kernel);
206211
}
207212

208213
return Conv2D(shifted_data, shifted_kernel, param->strides, param->padding, param->dilation,
@@ -557,17 +562,19 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Con
557562
* \param in_channels The number of input channels.
558563
* \param kernel_h The height of kernel.
559564
* \param kernel_w The width of kernel.
565+
* \param param The qnn conv2d attributes.
560566
* \return The sequence of Relay operators for term4.
561567
* \note The term4 looks like this
562568
*
563569
* Sigma(c,r,s) zp_a * zp_w
564570
*
565571
*/
566572
Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int in_channels,
567-
int kernel_h, int kernel_w) {
573+
int kernel_h, int kernel_w, const Conv2DAttrs* param) {
574+
auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
568575
int scalar_term4 =
569576
input_zero_point_int * kernel_zero_point_int * in_channels * kernel_h * kernel_w;
570-
return MakeConstantScalar(DataType::Int(32), scalar_term4);
577+
return MakeConstantScalar(DataType::Int(upcast_bits), scalar_term4);
571578
}
572579

573580
/*
@@ -578,15 +585,18 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i
578585
* \param in_channels The number of input channels.
579586
* \param kernel_h The height of kernel.
580587
* \param kernel_w The width of kernel.
588+
* \param param The qnn conv2d attributes.
581589
* \return The sequence of Relay operators for term4.
582590
* \note The term4 looks like this
583591
*
584592
* Sigma(c,r,s) zp_a * zp_w
585593
*
586594
*/
587595
Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels,
588-
int kernel_h, int kernel_w) {
589-
Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), in_channels * kernel_h * kernel_w);
596+
int kernel_h, int kernel_w, const Conv2DAttrs* param) {
597+
auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
598+
Expr scalar_term4 =
599+
MakeConstantScalar(DataType::Int(upcast_bits), in_channels * kernel_h * kernel_w);
590600
Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point);
591601
return Multiply(scalar_term4, variable_term4);
592602
}
@@ -791,10 +801,11 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
791801
auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels);
792802
Expr term4;
793803
if (dynamic_zp) {
794-
term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w);
804+
term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w,
805+
param);
795806
} else {
796807
term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
797-
kernel_w);
808+
kernel_w, param);
798809
}
799810
return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
800811
kernel_zero_point_int);
@@ -829,7 +840,7 @@ This operator convolves quantized weight with quantized data. The scale of the
829840
output quantized tensor is the product of the weight_scale and input_scale of
830841
the input quantized tensors. The zero point of the output quantized tensor is
831842
0. By default, the dtype of output is int32. Please also refer to Requantize
832-
operator to understand how to scale back the int32 output to (u)int8.
843+
operator to understand how to scale back the int32 output to (u)int8 or (u)int16.
833844
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
834845
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
835846
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])

src/relay/qnn/op/dequantize.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
4747

4848
const auto input_dtype = data->dtype;
4949
ICHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) ||
50-
input_dtype == DataType::Int(32))
51-
<< "Input type should be one of the quantized types [unit8, int8, int32] but was "
50+
input_dtype == DataType::Int(16) || input_dtype == DataType::Int(32))
51+
<< "Input type should be one of the quantized types [unit8, int8, int16, int32] but was "
5252
<< input_dtype;
5353

5454
const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();

src/relay/qnn/op/quantize.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
7676
const Array<tvm::PrimExpr> oshape = data->shape;
7777
const DataType out_dtype = quantize_attrs->out_dtype;
7878
ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
79-
out_dtype == DataType::Int(32))
80-
<< "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
79+
out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
80+
<< "Output type should be one of [int8, unit8, int16, int32] but was " << out_dtype;
8181
// assign output type
8282
reporter->Assign(types[3], TensorType(oshape, out_dtype));
8383
return true;

src/relay/qnn/op/requantize.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -480,8 +480,8 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
480480
}
481481
const auto in_dtype = data->dtype;
482482
ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
483-
in_dtype == DataType::Int(32))
484-
<< "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
483+
in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64))
484+
<< "Input type should be one of [int8, uint8, int32, int64] but was " << in_dtype;
485485

486486
const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
487487
int axis = requantize_attrs->axis;
@@ -507,8 +507,8 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
507507
// assign output type
508508
auto out_dtype = requantize_attrs->out_dtype;
509509
ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
510-
out_dtype == DataType::Int(32))
511-
<< "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
510+
out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
511+
<< "Output type should be one of [int8, uint8, int16, int32] but was " << out_dtype;
512512
reporter->Assign(types[5], TensorType(oshape, out_dtype));
513513
return true;
514514
}

0 commit comments

Comments
 (0)