Skip to content

Commit b0582ee

Browse files
committed
[TFLite] Add support to int16 data type in TFLite frontend
Add support for int16 data type and int64 biases/accumulators in the TFLite frontend. Adjusts TFLite tests to cover int16 convolutions and element-wise; Fixes a minor typo negtive->negative in the element-wise tests.
1 parent de21c8f commit b0582ee

File tree

6 files changed

+235
-99
lines changed

6 files changed

+235
-99
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ def get_tensor_type_as_numpy(self, tensor_wrapper):
390390
return {
391391
TensorType.UINT8: np.uint8,
392392
TensorType.INT8: np.int8,
393+
TensorType.INT16: np.int16,
393394
TensorType.FLOAT16: np.float16,
394395
TensorType.FLOAT32: np.float32,
395396
TensorType.INT32: np.int32,
@@ -430,6 +431,8 @@ def get_tensor_type_str(self, tensor_type):
430431

431432
if tensor_type == TensorType.INT8:
432433
return "int8"
434+
if tensor_type == TensorType.INT16:
435+
return "int16"
433436
if tensor_type == TensorType.UINT8:
434437
return "uint8"
435438
if tensor_type == TensorType.FLOAT16:
@@ -2149,7 +2152,9 @@ def convert_conv(self, op, conv_type):
21492152
qnn_conv2d_params = dict(params)
21502153
qnn_conv2d_params["input_zero_point"] = input_tensor.qnn_params["zero_point"]
21512154
qnn_conv2d_params["kernel_zero_point"] = weight_tensor.qnn_params["zero_point"]
2152-
qnn_conv2d_params["out_dtype"] = "int32"
2155+
qnn_conv2d_params["out_dtype"] = (
2156+
"int64" if output_tensor_type_str == "int16" else "int32"
2157+
)
21532158
qnn_conv2d_params["input_scale"] = input_tensor.qnn_params["scale"]
21542159
qnn_conv2d_params["kernel_scale"] = weight_tensor.qnn_params["scale"]
21552160
out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params)
@@ -2160,8 +2165,8 @@ def convert_conv(self, op, conv_type):
21602165
if len(input_tensors) == 3:
21612166
bias_tensor = input_tensors[2]
21622167
bias_tensor_type = bias_tensor.tensor.Type()
2163-
# bias tensor type should be INT32 (quantization) or FLOAT32
2164-
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
2168+
# bias tensor type should be INT32 (int8 qnn) or INT64 (int16 qnn) or FLOAT32
2169+
assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
21652170
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
21662171
if self.has_expr(bias_tensor.tensor_idx):
21672172
bias_expr = self.get_expr(bias_tensor.tensor_idx)

src/relay/qnn/op/convolution.cc

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
5050
if (data == nullptr || weight == nullptr) return false;
5151
const auto* param = attrs.as<Conv2DAttrs>();
5252
ICHECK(param != nullptr) << "Conv2DAttrs cannot be nullptr.";
53-
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
54-
<< "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
53+
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
54+
data->dtype == DataType::Int(16))
55+
<< "Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype;
5556
ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
56-
<< "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
57-
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
58-
<< "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
57+
<< "Expected qnn conv2d type(int8, uint8, int16) for weight but was " << weight->dtype;
58+
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
59+
param->out_dtype == DataType::Int(64))
60+
<< "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
5961
ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
6062

6163
// Check the types of scale and zero points.
@@ -190,19 +192,21 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const Conv2DA
190192
*/
191193
Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point,
192194
const Expr& kernel_zero_point, const Conv2DAttrs* param) {
193-
// Upcast the zero point to Int16.
194-
auto zp_data = Cast(input_zero_point, DataType::Int(16));
195-
auto zp_kernel = Cast(kernel_zero_point, DataType::Int(16));
195+
// Upcast the parameters to be at least int32 to avoid overflow
196+
auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
196197

197-
auto shifted_data = Cast(data, DataType::Int(16));
198-
auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
198+
auto zp_data = Cast(input_zero_point, DataType::Int(upcast_bits));
199+
auto zp_kernel = Cast(kernel_zero_point, DataType::Int(upcast_bits));
200+
201+
auto shifted_data = Cast(data, DataType::Int(upcast_bits));
202+
auto zero_scalar = MakeConstantScalar(DataType::Int(upcast_bits), 0);
199203
if (!IsEqualScalar(input_zero_point, zero_scalar)) {
200-
shifted_data = Subtract(Cast(data, DataType::Int(16)), zp_data);
204+
shifted_data = Subtract(Cast(data, DataType::Int(upcast_bits)), zp_data);
201205
}
202206

203-
auto shifted_kernel = Cast(weight, DataType::Int(16));
207+
auto shifted_kernel = Cast(weight, DataType::Int(upcast_bits));
204208
if (!IsEqualScalar(kernel_zero_point, zero_scalar)) {
205-
shifted_kernel = Subtract(Cast(weight, DataType::Int(16)), zp_kernel);
209+
shifted_kernel = Subtract(Cast(weight, DataType::Int(upcast_bits)), zp_kernel);
206210
}
207211

208212
return Conv2D(shifted_data, shifted_kernel, param->strides, param->padding, param->dilation,
@@ -557,17 +561,19 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Con
557561
* \param in_channels The number of input channels.
558562
* \param kernel_h The height of kernel.
559563
* \param kernel_w The width of kernel.
564+
* \param param The qnn conv2d attributes.
560565
* \return The sequence of Relay operators for term4.
561566
* \note The term4 looks like this
562567
*
563568
* Sigma(c,r,s) zp_a * zp_w
564569
*
565570
*/
566571
Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int in_channels,
567-
int kernel_h, int kernel_w) {
572+
int kernel_h, int kernel_w, const Conv2DAttrs* param) {
573+
auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
568574
int scalar_term4 =
569575
input_zero_point_int * kernel_zero_point_int * in_channels * kernel_h * kernel_w;
570-
return MakeConstantScalar(DataType::Int(32), scalar_term4);
576+
return MakeConstantScalar(DataType::Int(upcast_bits), scalar_term4);
571577
}
572578

573579
/*
@@ -578,15 +584,18 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i
578584
* \param in_channels The number of input channels.
579585
* \param kernel_h The height of kernel.
580586
* \param kernel_w The width of kernel.
587+
* \param param The qnn conv2d attributes.
581588
* \return The sequence of Relay operators for term4.
582589
* \note The term4 looks like this
583590
*
584591
* Sigma(c,r,s) zp_a * zp_w
585592
*
586593
*/
587594
Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels,
588-
int kernel_h, int kernel_w) {
589-
Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), in_channels * kernel_h * kernel_w);
595+
int kernel_h, int kernel_w, const Conv2DAttrs* param) {
596+
auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
597+
Expr scalar_term4 =
598+
MakeConstantScalar(DataType::Int(upcast_bits), in_channels * kernel_h * kernel_w);
590599
Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point);
591600
return Multiply(scalar_term4, variable_term4);
592601
}
@@ -791,10 +800,11 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
791800
auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels);
792801
Expr term4;
793802
if (dynamic_zp) {
794-
term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w);
803+
term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w,
804+
param);
795805
} else {
796806
term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
797-
kernel_w);
807+
kernel_w, param);
798808
}
799809
return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
800810
kernel_zero_point_int);
@@ -829,7 +839,7 @@ This operator convolves quantized weight with quantized data. The scale of the
829839
output quantized tensor is the product of the weight_scale and input_scale of
830840
the input quantized tensors. The zero point of the output quantized tensor is
831841
0. By default, the dtype of output is int32. Please also refer to Requantize
832-
operator to understand how to scale back the int32 output to (u)int8.
842+
operator to understand how to scale back the int32 output to (u)int8 or (u)int16.
833843
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
834844
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
835845
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])

src/relay/qnn/op/dequantize.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
4747

4848
const auto input_dtype = data->dtype;
4949
ICHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) ||
50-
input_dtype == DataType::Int(32))
51-
<< "Input type should be one of the quantized types [unit8, int8, int32] but was "
50+
input_dtype == DataType::Int(16) || input_dtype == DataType::Int(32))
51+
<< "Input type should be one of the quantized types [unit8, int8, int16, int32] but was "
5252
<< input_dtype;
5353

5454
const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();

src/relay/qnn/op/quantize.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
7676
const Array<tvm::PrimExpr> oshape = data->shape;
7777
const DataType out_dtype = quantize_attrs->out_dtype;
7878
ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
79-
out_dtype == DataType::Int(32))
80-
<< "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
79+
out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
80+
<< "Output type should be one of [int8, unit8, int16, int32] but was " << out_dtype;
8181
// assign output type
8282
reporter->Assign(types[3], TensorType(oshape, out_dtype));
8383
return true;

src/relay/qnn/op/requantize.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -480,8 +480,8 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
480480
}
481481
const auto in_dtype = data->dtype;
482482
ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
483-
in_dtype == DataType::Int(32))
484-
<< "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
483+
in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64))
484+
<< "Input type should be one of [int8, uint8, int32, int64] but was " << in_dtype;
485485

486486
const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
487487
int axis = requantize_attrs->axis;
@@ -507,8 +507,8 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
507507
// assign output type
508508
auto out_dtype = requantize_attrs->out_dtype;
509509
ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
510-
out_dtype == DataType::Int(32))
511-
<< "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
510+
out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
511+
<< "Output type should be one of [int8, uint8, int16, int32] but was " << out_dtype;
512512
reporter->Assign(types[5], TensorType(oshape, out_dtype));
513513
return true;
514514
}

0 commit comments

Comments
 (0)