diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 8d18cc2962ae..b696bd6d056b 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -390,6 +390,7 @@ def get_tensor_type_as_numpy(self, tensor_wrapper):
             return {
                 TensorType.UINT8: np.uint8,
                 TensorType.INT8: np.int8,
+                TensorType.INT16: np.int16,
                 TensorType.FLOAT16: np.float16,
                 TensorType.FLOAT32: np.float32,
                 TensorType.INT32: np.int32,
@@ -430,6 +431,8 @@ def get_tensor_type_str(self, tensor_type):
 
         if tensor_type == TensorType.INT8:
             return "int8"
+        if tensor_type == TensorType.INT16:
+            return "int16"
         if tensor_type == TensorType.UINT8:
             return "uint8"
         if tensor_type == TensorType.FLOAT16:
@@ -2149,7 +2152,9 @@ def convert_conv(self, op, conv_type):
             qnn_conv2d_params = dict(params)
             qnn_conv2d_params["input_zero_point"] = input_tensor.qnn_params["zero_point"]
             qnn_conv2d_params["kernel_zero_point"] = weight_tensor.qnn_params["zero_point"]
-            qnn_conv2d_params["out_dtype"] = "int32"
+            qnn_conv2d_params["out_dtype"] = (
+                "int64" if output_tensor_type_str == "int16" else "int32"
+            )
             qnn_conv2d_params["input_scale"] = input_tensor.qnn_params["scale"]
             qnn_conv2d_params["kernel_scale"] = weight_tensor.qnn_params["scale"]
             out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params)
@@ -2160,8 +2165,8 @@ def convert_conv(self, op, conv_type):
         if len(input_tensors) == 3:
             bias_tensor = input_tensors[2]
             bias_tensor_type = bias_tensor.tensor.Type()
-            # bias tensor type should be INT32 (quantization) or FLOAT32
-            assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
+            # bias tensor type should be INT32 (int8 qnn) or INT64 (int16 qnn) or FLOAT32
+            assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
             if self.has_expr(bias_tensor.tensor_idx):
                 bias_expr = self.get_expr(bias_tensor.tensor_idx)
diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index 8a7521e8ee50..42e4540f0f2c 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -50,12 +50,14 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (data == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<Conv2DAttrs>();
   ICHECK(param != nullptr) << "Conv2DAttrs cannot be nullptr.";
-  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
-      << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
+         data->dtype == DataType::Int(16))
+      << "Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype;
   ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
       << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
-  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
-      << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
+  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
+         param->out_dtype == DataType::Int(64))
+      << "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
 
   // Check the types of scale and zero points.
@@ -190,19 +192,21 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const Conv2DA
  *
  */
 Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point,
                     const Expr& kernel_zero_point, const Conv2DAttrs* param) {
-  // Upcast the zero point to Int16.
-  auto zp_data = Cast(input_zero_point, DataType::Int(16));
-  auto zp_kernel = Cast(kernel_zero_point, DataType::Int(16));
+  // Upcast the parameters to be at least int32 to avoid overflow
+  auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
 
-  auto shifted_data = Cast(data, DataType::Int(16));
-  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
+  auto zp_data = Cast(input_zero_point, DataType::Int(upcast_bits));
+  auto zp_kernel = Cast(kernel_zero_point, DataType::Int(upcast_bits));
+
+  auto shifted_data = Cast(data, DataType::Int(upcast_bits));
+  auto zero_scalar = MakeConstantScalar(DataType::Int(upcast_bits), 0);
   if (!IsEqualScalar(input_zero_point, zero_scalar)) {
-    shifted_data = Subtract(Cast(data, DataType::Int(16)), zp_data);
+    shifted_data = Subtract(Cast(data, DataType::Int(upcast_bits)), zp_data);
   }
-  auto shifted_kernel = Cast(weight, DataType::Int(16));
+  auto shifted_kernel = Cast(weight, DataType::Int(upcast_bits));
   if (!IsEqualScalar(kernel_zero_point, zero_scalar)) {
-    shifted_kernel = Subtract(Cast(weight, DataType::Int(16)), zp_kernel);
+    shifted_kernel = Subtract(Cast(weight, DataType::Int(upcast_bits)), zp_kernel);
   }
 
   return Conv2D(shifted_data, shifted_kernel, param->strides, param->padding, param->dilation,
@@ -557,6 +561,7 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Con
  * \param in_channels The number of input channels.
  * \param kernel_h The height of kernel.
  * \param kernel_w The width of kernel.
+ * \param param The qnn conv2d attributes.
  * \return The sequence of Relay operators for term4.
  * \note The term4 looks like this
  *
@@ -564,10 +569,11 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Con
  *
  */
 Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int in_channels,
-                      int kernel_h, int kernel_w) {
+                      int kernel_h, int kernel_w, const Conv2DAttrs* param) {
+  auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
   int scalar_term4 =
       input_zero_point_int * kernel_zero_point_int * in_channels * kernel_h * kernel_w;
-  return MakeConstantScalar(DataType::Int(32), scalar_term4);
+  return MakeConstantScalar(DataType::Int(upcast_bits), scalar_term4);
 }
 
 /*
@@ -578,6 +584,7 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i
  * \param in_channels The number of input channels.
  * \param kernel_h The height of kernel.
 * \param kernel_w The width of kernel.
+ * \param param The qnn conv2d attributes.
  * \return The sequence of Relay operators for term4.
  * \note The term4 looks like this
  *
@@ -585,8 +592,10 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i
  *
  */
 Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels,
-                      int kernel_h, int kernel_w) {
-  Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), in_channels * kernel_h * kernel_w);
+                      int kernel_h, int kernel_w, const Conv2DAttrs* param) {
+  auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
+  Expr scalar_term4 =
+      MakeConstantScalar(DataType::Int(upcast_bits), in_channels * kernel_h * kernel_w);
   Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point);
   return Multiply(scalar_term4, variable_term4);
 }
@@ -791,10 +800,11 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels);
   Expr term4;
   if (dynamic_zp) {
-    term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w);
+    term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w,
+                             param);
   } else {
     term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
-                             kernel_w);
+                             kernel_w, param);
   }
   return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
                             kernel_zero_point_int);
@@ -829,7 +839,7 @@ This operator convolves quantized weight with quantized data. The scale of the
 output quantized tensor is the product of the weight_scale and input_scale of
 the input quantized tensors. The zero point of the output quantized tensor is
 0. By default, the dtype of output is int32. Please also refer to Requantize
-operator to understand how to scale back the int32 output to (u)int8.
+operator to understand how to scale back the int32 output to (u)int8 or (u)int16.
 - **data**: This depends on the `layout` parameter. Input is 4D array of shape
             (batch_size, in_channels, height, width) if `layout` is `NCHW`.
 - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc
index 9a9c60d9ea6f..1ddcde81234d 100644
--- a/src/relay/qnn/op/dequantize.cc
+++ b/src/relay/qnn/op/dequantize.cc
@@ -47,8 +47,8 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   const auto input_dtype = data->dtype;
   ICHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) ||
-         input_dtype == DataType::Int(32))
-      << "Input type should be one of the quantized types [unit8, int8, int32] but was "
+         input_dtype == DataType::Int(16) || input_dtype == DataType::Int(32))
+      << "Input type should be one of the quantized types [uint8, int8, int16, int32] but was "
       << input_dtype;
 
   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc
index 1a4c853d8929..06a73ee91cbf 100644
--- a/src/relay/qnn/op/quantize.cc
+++ b/src/relay/qnn/op/quantize.cc
@@ -76,8 +76,8 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   const Array<tvm::PrimExpr> oshape = data->shape;
   const DataType out_dtype = quantize_attrs->out_dtype;
   ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
-         out_dtype == DataType::Int(32))
-      << "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
+         out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
+      << "Output type should be one of [int8, uint8, int16, int32] but was " << out_dtype;
   // assign output type
   reporter->Assign(types[3], TensorType(oshape, out_dtype));
   return true;
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index ea143fe41713..8601264f5313 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -480,8 +480,8 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
   const auto in_dtype = data->dtype;
   ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
-         in_dtype == DataType::Int(32))
-      << "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
+         in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64))
+      << "Input type should be one of [int8, uint8, int32, int64] but was " << in_dtype;
 
   const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
   int axis = requantize_attrs->axis;
@@ -507,8 +507,8 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // assign output type
   auto out_dtype = requantize_attrs->out_dtype;
   ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
-         out_dtype == DataType::Int(32))
-      << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
+         out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
+      << "Output type should be one of [int8, uint8, int16, int32] but was " << out_dtype;
   reporter->Assign(types[5], TensorType(oshape, out_dtype));
   return true;
 }
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 80cdcf327f4b..8c8ca0eab2ff 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -139,19 +139,38 @@ def vmobj_to_list(o):
 
 
 def _quantize_keras_model(
-    keras_model, representative_data_gen, is_float_input=False, is_float_output=False
+    keras_model,
+    representative_data_gen,
+    is_float_input=False,
+    is_float_output=False,
+    int_quant_dtype=tf.int8,
 ):
     """Utility function to quantize a Keras model using TFLite converter."""
     converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
-    converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
-    converter.representative_dataset = representative_data_gen
-    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+    if int_quant_dtype == tf.int8:
+        converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+        converter.representative_dataset = representative_data_gen
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        inference_dtype = tf.uint8
+    elif int_quant_dtype == tf.int16:
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_data_gen
+        converter.target_spec.supported_ops = [
+            tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+        ]
+        inference_dtype = tf.uint16
+    else:
+        raise RuntimeError(
+            f"Invalid quantized dtype {int_quant_dtype}. Supported types: int8, int16."
+        )
+
     # NOTE: If representative dataset is provided, and inference input type is not set,
     #       then converter will self add quant & dequant Op accordingly.
     if not is_float_input:
-        converter.inference_input_type = tf.uint8
+        converter.inference_input_type = inference_dtype
     if not is_float_output:
-        converter.inference_output_type = tf.uint8
+        converter.inference_output_type = inference_dtype
+
     return converter.convert()
 
 
@@ -271,6 +290,7 @@ def compare_tflite_with_tvm(
     mode="graph_executor",
     experimental_new_converter=False,
     fp16_quantized=False,
+    int_quant_dtype=tf.int8,
 ):
     """Generic function to generate and compare TFLite and TVM output"""
     in_data = convert_to_list(in_data)
@@ -287,7 +307,15 @@ def compare_tflite_with_tvm(
         converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
         converter.experimental_new_converter = experimental_new_converter
         if quantized:
-            converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
+            if int_quant_dtype == tf.int16:
+                converter.optimizations = [tf.lite.Optimize.DEFAULT]
+                converter.target_spec.supported_ops = [
+                    tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+                ]
+            else:
+                # default to int8 quantization
+                converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
+
             input_arrays = converter.get_input_arrays()
             input_stats = {}
             # calculate the mean and quantization scale for every input tensor,
@@ -875,7 +903,7 @@ def test_forward_l2_pool2d():
 
 
 def _test_tflite2_quantized_convolution(
-    input_shape, kernel_shape, dilations, strides, padding, data_format
+    input_shape, kernel_shape, filters, padding="valid", data_format=None, int_quant_dtype=tf.int8
 ):
     """One iteration of TFLite2 quantized convolution with given shapes and attributes"""
     data_format = "channels_last" if "NHWC" else "channels_first"
@@ -882,27 +910,30 @@ def _test_tflite2_quantized_convolution(
     data = np.random.uniform(0, 1, input_shape).astype("float32")
     kernel = np.random.uniform(0, 1, kernel_shape).astype("float32")
 
     data_in = tf.keras.layers.Input(shape=data.shape[1:])
     conv = tf.keras.layers.Conv2D(
-        filters=kernel_shape[3],
+        filters=filters,
         kernel_size=(kernel_shape[0], kernel_shape[1]),
-        strides=strides,
+        activation=tf.nn.relu,
         padding=padding,
         data_format=data_format,
-        activation="relu",
-        use_bias=False,
     )(data_in)
     keras_model = tf.keras.models.Model(data_in, conv)
-    keras_model.layers[1].set_weights([kernel])
 
     # To create quantized values with dynamic range of activations, needs representative dataset
     def representative_data_gen():
        for i in range(1):
            yield [data]
 
-    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
+    tflite_model_quant = _quantize_keras_model(
+        keras_model,
+        representative_data_gen,
+        is_float_input=True,
+        is_float_output=True,
+        int_quant_dtype=int_quant_dtype,
+    )
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
     tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", ""))
     tvm.testing.assert_allclose(
         np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
@@ -909,6 +940,25 @@ def representative_data_gen():
     )
 
 
+def test_forward_quantized_convolution():
+    for int_quant_dtype in [tf.int8, tf.int16]:
+        _test_tflite2_quantized_convolution(
+            (1, 28, 28, 1),
+            (1, 1),
+            12,
+            data_format="NHWC",
+            int_quant_dtype=int_quant_dtype,
+        )
+
+        _test_tflite2_quantized_convolution(
+            (1, 1, 28, 28),
+            (1, 1),
+            12,
+            data_format="NCHW",
+            int_quant_dtype=int_quant_dtype,
+        )
+
+
 def _test_tflite2_quantized_depthwise_convolution(
     input_shape, kernel_shape, dilations, strides, padding, data_format, depth_multiplier
 ):
@@ -1046,7 +1096,6 @@ def _test_convolution(
                 quantized=quantized,
                 input_range=input_range,
                 experimental_new_converter=True,
-                fp16_quantized=fp16_quantized,
             )
     else:
         data_array = np.reshape(data_array, tensor_in_sizes).astype("float32")
@@ -1765,7 +1814,7 @@ def test_forward_concatenation():
 # --------------
 
 
-def _test_unary_elemwise(math_op, data, quantized, quant_range=[-6, 6]):
+def _test_unary_elemwise(math_op, data, quantized, quant_range=[-6, 6], int_quant_dtype=tf.int8):
     """One iteration of unary elemwise"""
     if quantized:
         with tf.Graph().as_default():
@@ -1787,6 +1836,7 @@ def _test_unary_elemwise(math_op, data, quantized, quant_range=[-6, 6]):
                 quantized=True,
                 input_range=input_range,
                 experimental_new_converter=True,
+                int_quant_dtype=int_quant_dtype,
             )
     else:
         with tf.Graph().as_default():
@@ -1795,14 +1845,20 @@ def _test_unary_elemwise(math_op, data, quantized, quant_range=[-6, 6]):
             compare_tflite_with_tvm(data, ["in:0"], [in_data], [out])
 
 
-def _unary_elewise_create_model(math_op, data, offset=0):
+def _unary_elewise_create_model(math_op, data, offset=0, int_quant_dtype=tf.int8):
     class Model(tf.Module):
         @tf.function
         def tf_function(self, x):
             op = math_op(x)
             return op
 
-    dtype = "int8"
+    if int_quant_dtype in (tf.int8, tf.uint8):
+        dtype = "int8"
+    elif int_quant_dtype in (tf.int16, tf.uint16):
+        dtype = "int16"
+    else:
+        raise Exception(f"Unsupported dtype '{int_quant_dtype}' for unary elementwise test.")
+
     model = Model()
 
     # Save the model
@@ -1824,9 +1880,17 @@ def representative_dataset():
     converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
     converter.optimizations = [tf.lite.Optimize.DEFAULT]
     converter.representative_dataset = representative_dataset
-    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
-    converter.inference_input_type = tf.int8
-    converter.inference_output_type = tf.int8
+
+    if int_quant_dtype in (tf.int16, tf.uint16):
+        converter.target_spec.supported_ops = [
+            tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+        ]
+    else:
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+
+    converter.inference_input_type = int_quant_dtype
+    converter.inference_output_type = int_quant_dtype
+
     tflite_model = converter.convert()
     return tflite_model
 
@@ -1836,24 +1900,28 @@ def representative_dataset():
 # ----
 
 
-def _test_abs(data, quantized):
+def _test_abs(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of abs"""
     if quantized:
-        tflite_model_quant = _unary_elewise_create_model(tf.math.abs, data, offset=1)
+        tflite_model_quant = _unary_elewise_create_model(
+            tf.math.abs, data, offset=1, int_quant_dtype=int_quant_dtype
+        )
         tflite_output = run_tflite_graph(tflite_model_quant, data)
 
         # TFLite 2.6.x upgrade support
         if tf.__version__ < LooseVersion("2.6.1"):
             in_node = ["serving_default_input_int8"]
         else:
-            in_node = ["tfl.quantize"]
+            in_node = (
+                ["serving_default_input_int16"] if int_quant_dtype == tf.int16 else ["tfl.quantize"]
+            )
 
         tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
         tvm.testing.assert_allclose(
             np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
         )
     else:
-        return _test_unary_elemwise(math_ops.abs, data, quantized)
+        return _test_unary_elemwise(math_ops.abs, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
 #######################################################################
@@ -1861,14 +1929,18 @@ def _test_abs(data, quantized):
 # ----
 
 
-def _test_rsqrt(data, quantized):
+def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of rsqrt"""
 
     # tensorflow version upgrade support
     if tf.__version__ < LooseVersion("2.6.1") or not quantized:
-        return _test_unary_elemwise(math_ops.rsqrt, data, quantized, quant_range=[1, 6])
+        return _test_unary_elemwise(
+            math_ops.rsqrt, data, quantized, quant_range=[1, 6], int_quant_dtype=int_quant_dtype
+        )
     else:
-        tflite_model_quant = _unary_elewise_create_model(tf.math.rsqrt, data)
+        tflite_model_quant = _unary_elewise_create_model(
+            tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype
+        )
         tflite_output = run_tflite_graph(tflite_model_quant, data)
         in_node = ["tfl.quantize"]
         tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
@@ -1883,9 +1955,9 @@ def _test_rsqrt(data, quantized):
 # ----
 
 
-def _test_ceil(data, quantized):
+def _test_ceil(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of ceil"""
-    return _test_unary_elemwise(math_ops.ceil, data, quantized)
+    return _test_unary_elemwise(math_ops.ceil, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
 #######################################################################
@@ -1893,9 +1965,9 @@ def _test_ceil(data, quantized):
 # -----
 
 
-def _test_floor(data, quantized):
+def _test_floor(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of floor"""
-    return _test_unary_elemwise(math_ops.floor, data, quantized)
+    return _test_unary_elemwise(math_ops.floor, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
 #######################################################################
@@ -1903,9 +1975,9 @@ def _test_floor(data, quantized):
 # -----
 
 
-def _test_round(data, quantized):
+def _test_round(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of round"""
-    return _test_unary_elemwise(math_ops.round, data, quantized)
+    return _test_unary_elemwise(math_ops.round, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
 #######################################################################
@@ -1913,9 +1985,9 @@ def _test_round(data, quantized):
 # ---
 
 
-def _test_exp(data, quantized):
+def _test_exp(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of exp"""
-    return _test_unary_elemwise(math_ops.exp, data, quantized)
+    return _test_unary_elemwise(math_ops.exp, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
 #######################################################################
@@ -1923,9 +1995,11 @@ def _test_exp(data, quantized):
 # ---
 
 
-def _test_log(data, quantized):
+def _test_log(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of log"""
-    return _test_unary_elemwise(math_ops.log, data, quantized, quant_range=[1, 6])
+    return _test_unary_elemwise(
+        math_ops.log, data, quantized, quant_range=[1, 6], int_quant_dtype=int_quant_dtype
+    )
 
 
 #######################################################################
@@ -1933,9 +2007,9 @@ def _test_log(data, quantized):
 # ---
 
 
-def _test_sin(data, quantized):
+def _test_sin(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of sin"""
-    return _test_unary_elemwise(math_ops.sin, data, quantized)
+    return _test_unary_elemwise(math_ops.sin, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
 #######################################################################
@@ -1943,10 +2017,12 @@ def _test_sin(data, quantized):
 # ---
 
 
-def _test_cos(data, quantized):
+def _test_cos(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of cos"""
     if quantized:
-        tflite_model_quant = _unary_elewise_create_model(tf.math.cos, data)
+        tflite_model_quant = _unary_elewise_create_model(
+            tf.math.cos, data, int_quant_dtype=int_quant_dtype
+        )
         tflite_output = run_tflite_graph(tflite_model_quant, data)
         in_node = ["tfl.quantize"]
         tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
@@ -1962,9 +2038,9 @@ def _test_cos(data, quantized):
 # ---
 
 
-def _test_tan(data, quantized):
+def _test_tan(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of tan"""
-    return _test_unary_elemwise(math_ops.tan, data, quantized)
+    return _test_unary_elemwise(math_ops.tan, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
 #######################################################################
@@ -1972,9 +2048,9 @@ def _test_tan(data, quantized):
 # ------
 
 
-def _test_square(data, quantized):
+def _test_square(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of square"""
-    return _test_unary_elemwise(math_ops.square, data, quantized)
+    return _test_unary_elemwise(math_ops.square, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
 #######################################################################
@@ -1982,19 +2058,21 @@ def _test_square(data, quantized):
 # ------
 
 
-def _test_neg(data, quantized):
+def _test_neg(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of neg"""
-    return _test_unary_elemwise(math_ops.neg, data, quantized)
+    return _test_unary_elemwise(math_ops.neg, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
 #######################################################################
-# Neg
+# Sqrt
 # ------
 
 
-def _test_sqrt(data, quantized):
+def _test_sqrt(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of sqrt"""
-    return _test_unary_elemwise(math_ops.sqrt, data, quantized, quant_range=[1, 6])
+    return _test_unary_elemwise(
+        math_ops.sqrt, data, quantized, quant_range=[1, 6], int_quant_dtype=int_quant_dtype
+    )
 
 
 #######################################################################
@@ -2002,28 +2080,29 @@ def _test_sqrt(data, quantized):
 # ---
 
 
-def _test_elu(data, quantized):
+def _test_elu(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of elu"""
-    return _test_unary_elemwise(nn_ops.elu, data, quantized)
+    return _test_unary_elemwise(nn_ops.elu, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
-def _test_forward_unary_elemwise(test_op, quant_dtype=None, quantized=True, negtive=True):
+def _test_forward_unary_elemwise(test_op, int_quant_dtype=None, quantized=True, negative=True):
     # input data
     in_data, inq_data = [], []
 
+    np_dtype = int_quant_dtype.as_numpy_dtype if int_quant_dtype else np.uint8
+
     # quantized input data
     if quantized:
-        quant_dtype = quant_dtype or np.uint8
-        inq_data.append(np.arange(1, 240, 40, dtype=quant_dtype))
-        inq_data.append(np.arange(1, 240, 40, dtype=quant_dtype).reshape((2, 1, 3)))
-        if quant_dtype == np.int8:
+        inq_data.append(np.arange(1, 240, 40, dtype=np_dtype))
+        inq_data.append(np.arange(1, 240, 40, dtype=np_dtype).reshape((2, 1, 3)))
+        if int_quant_dtype == np.int8:
             inq_data.append(np.arange(-128, 127, 45, dtype=np.int8))
 
        for data in inq_data:
-            test_op(data, quantized=True)
+            test_op(data, quantized=True, int_quant_dtype=int_quant_dtype)
 
     # normal input data
-    if negtive:
+    if negative:
         in_data.append(np.arange(-2.0, 4.0, dtype=np.float32))
         in_data.append(np.arange(-2.0, 4.0, dtype=np.float32).reshape((2, 1, 3)))
     else:
@@ -2031,30 +2110,31 @@ def _test_forward_unary_elemwise(test_op, quant_dtype=None, quantized=True, negt
         in_data.append(np.arange(1.0, 7.0, dtype=np.float32))
         in_data.append(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)))
 
     for data in in_data:
-        test_op(data, quantized=False)
+        test_op(data, quantized=False, int_quant_dtype=int_quant_dtype)
 
 
 def test_all_unary_elemwise():
-    _test_forward_unary_elemwise(_test_abs, quant_dtype=np.int8)
+    _test_forward_unary_elemwise(_test_abs, int_quant_dtype=tf.int8)
+    _test_forward_unary_elemwise(_test_abs, int_quant_dtype=tf.int16)
     _test_forward_unary_elemwise(_test_floor)
     _test_forward_unary_elemwise(_test_exp)
-    _test_forward_unary_elemwise(_test_log, negtive=False)
+    _test_forward_unary_elemwise(_test_log, negative=False)
     _test_forward_unary_elemwise(_test_square)
     _test_forward_unary_elemwise(_test_sin)
     _test_forward_unary_elemwise(_test_neg)
-    _test_forward_unary_elemwise(_test_sqrt, negtive=False)
+    _test_forward_unary_elemwise(_test_sqrt, negative=False)
     # tensorflow version upgrade support
     if tf.__version__ < LooseVersion("2.6.1"):
-        _test_forward_unary_elemwise(_test_rsqrt, negtive=False, quant_dtype=np.uint8)
+        _test_forward_unary_elemwise(_test_rsqrt, negative=False, int_quant_dtype=tf.uint8)
     else:
-        _test_forward_unary_elemwise(_test_rsqrt, negtive=False, quant_dtype=np.int8)
+        _test_forward_unary_elemwise(_test_rsqrt, negative=False, int_quant_dtype=tf.int8)
     # ceil and cos come with TFLite 1.14.0.post1 fbs schema
     if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         _test_forward_unary_elemwise(_test_ceil)
         if tf.__version__ < LooseVersion("2.6.1"):
             _test_forward_unary_elemwise(_test_cos, quantized=False)
         else:
-            _test_forward_unary_elemwise(_test_cos, quant_dtype=np.int8)
+            _test_forward_unary_elemwise(_test_cos, int_quant_dtype=tf.int8)
     _test_forward_unary_elemwise(_test_round)
     # This fails with TF and Tflite 1.15.2, this could not have been tested
     # in CI or anywhere else. The failure mode is that we see a backtrace
@@ -4572,6 +4652,47 @@ def test_forward_tflite_float16():
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
 
+def test_forward_mobilenet_int16():
+    """Test int16 quantized model"""
+    # MobilenetV1
+    model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz",
+        "mobilenet_v1_0.25_128_frozen.pb",
+    )
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This causes the final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image, instead of random inputs.
+    #
+    # According to the TFLite documentation, despite the quantization being done to make this model
+    # use int16 types, inputs and outputs are kept float32 by default.
+    # https://www.tensorflow.org/lite/performance/post_training_integer_quant_16x8
+    data = get_real_image(128, 128, quantized=False)
+
+    converter = tf.lite.TFLiteConverter.from_frozen_graph(
+        model_file, ["input"], ["MobilenetV1/Predictions/Reshape_1"]
+    )
+
+    def representative_dataset():
+        for _ in range(1):
+            yield [data]
+
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_ops = [
+        tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+    ]
+    converter.representative_dataset = representative_dataset
+    tflite_model_buf = converter.convert()
+
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tflite_predictions = np.squeeze(tflite_output)
+    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
+    tvm_predictions = np.squeeze(tvm_output)
+    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+
 #######################################################################
 # Quantized SSD Mobilenet
 # -----------------------
@@ -4867,3 +4988,5 @@ def test_prevent_tensorflow_dynamic_range():
     test_forward_tflite2_qnn_mobilenet_v2()
 
     test_forward_tflite_float16()
+
+    test_forward_mobilenet_int16()