Skip to content

Commit 53166a0

Browse files
committed
Fix TFLite frontend bug and add test
1 parent a8c5804 commit 53166a0

File tree

2 files changed

+36
-142
lines changed

2 files changed

+36
-142
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2146,7 +2146,7 @@ def convert_conv(self, op, conv_type):
21462146
_, kernel_h, kernel_w, in_channels = to_int_list(self.get_tensor_shape(weight_tensor))
21472147
assert in_channels == input_c * depth_multiplier
21482148
else:
2149-
output_channels, kernel_h, kernel_w, _ = to_int_list(
2149+
output_channels, kernel_h, kernel_w, in_channels = to_int_list(
21502150
self.get_tensor_shape(weight_tensor)
21512151
)
21522152

@@ -2170,6 +2170,9 @@ def convert_conv(self, op, conv_type):
21702170
else:
21712171
params["channels"] = int(output_channels)
21722172
params["kernel_layout"] = "HWIO"
2173+
if input_c != in_channels:
2174+
assert input_c % in_channels == 0
2175+
params["groups"] = int(input_c / in_channels)
21732176

21742177
# weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
21752178
weight_tensor_type = weight_tensor.tensor.Type()

tests/python/frontend/tflite/test_forward.py

Lines changed: 32 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from packaging import version as package_version
3030
import pytest
3131
import numpy as np
32+
import typing
3233

3334
from PIL import Image
3435

@@ -292,11 +293,11 @@ def run_tflite_graph(tflite_model_buf, input_data):
292293

293294

294295
def compare_tflite_with_tvm(
295-
in_data,
296-
in_name,
297-
input_tensors,
298-
output_tensors,
299-
init_global_variables=False,
296+
in_data: typing.List[np.ndarray],
297+
in_name: typing.List[str],
298+
input_tensors: typing.List,
299+
output_tensors: typing.List,
300+
init_global_variables: bool = False,
300301
out_names=None,
301302
quantized=False,
302303
input_range=None,
@@ -5301,140 +5302,30 @@ def _golden():
53015302
_test_reshape_span()
53025303

53035304

5304-
#######################################################################
5305-
# Main
5306-
# ----
5305+
class TestConv2d:
5306+
input_shape, kernel_shape, padding = tvm.testing.parameters(
5307+
((1, 128, 256, 6), (5, 5, 6, 10), "SAME"),
5308+
((1, 128, 256, 6), (5, 5, 6, 10), "VALID"),
5309+
# conv2d_group cases
5310+
((1, 30, 40, 6), (5, 5, 1, 6), "SAME"),
5311+
((1, 30, 40, 6), (5, 5, 1, 6), "VALID"),
5312+
)
5313+
5314+
def test_conv2d(self, input_shape: tuple, kernel_shape: tuple, padding: str):
5315+
dtype = tf.float32
5316+
kernel_in = np.ones(kernel_shape)
5317+
with tf.Graph().as_default():
5318+
x = array_ops.placeholder(shape=input_shape, dtype=dtype.name, name="input")
5319+
kernel = tf.constant(kernel_in, dtype=dtype, name="filter_weight")
5320+
out = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding=padding, name="conv2d")
5321+
input_data = np.random.randn(*input_shape).astype(dtype.name)
5322+
compare_tflite_with_tvm(
5323+
[input_data],
5324+
["input"],
5325+
[x],
5326+
[out],
5327+
)
5328+
5329+
53075330
if __name__ == "__main__":
5308-
# BatchToSpaceND
5309-
test_forward_batch_to_space_nd()
5310-
5311-
# SpaceToBatchND
5312-
test_forward_space_to_batch_nd()
5313-
5314-
# Split
5315-
test_forward_split()
5316-
5317-
# Transpose
5318-
test_forward_transpose()
5319-
5320-
# Cast
5321-
test_forward_cast()
5322-
5323-
# BatchMatMul
5324-
test_forward_batch_matmul()
5325-
5326-
# Tile
5327-
test_forward_tile()
5328-
5329-
# Query
5330-
test_forward_shape()
5331-
5332-
# Transforms
5333-
test_forward_concatenation()
5334-
test_forward_pad()
5335-
test_forward_pack()
5336-
test_forward_unpack()
5337-
test_forward_reshape()
5338-
test_all_resize()
5339-
test_forward_range()
5340-
test_forward_squeeze()
5341-
test_forward_slice()
5342-
test_forward_topk()
5343-
test_forward_gather()
5344-
test_forward_gather_nd()
5345-
test_forward_stridedslice()
5346-
test_forward_depthtospace()
5347-
test_forward_spacetodepth()
5348-
test_forward_reverse_sequence()
5349-
test_forward_sparse_to_dense()
5350-
test_forward_select()
5351-
test_forward_quantize_dequantize()
5352-
test_forward_arg_min_max()
5353-
test_forward_expand_dims()
5354-
test_forward_reverse_v2()
5355-
test_forward_matrix_set_diag()
5356-
test_forward_matrix_diag()
5357-
5358-
# NN
5359-
test_forward_convolution()
5360-
test_forward_transpose_conv()
5361-
test_forward_logistic()
5362-
test_forward_pooling()
5363-
test_forward_l2_pool2d()
5364-
test_forward_softmax()
5365-
test_forward_tanh()
5366-
test_forward_relu()
5367-
test_forward_relu6()
5368-
test_forward_leaky_relu()
5369-
test_forward_relu_n1_to_1()
5370-
test_forward_log_softmax()
5371-
test_forward_fully_connected()
5372-
test_forward_l2_normalization()
5373-
test_forward_local_response_normalization()
5374-
test_forward_prelu()
5375-
test_forward_unidirectional_sequence_lstm()
5376-
5377-
# Elemwise
5378-
test_all_elemwise()
5379-
test_forward_add_n()
5380-
5381-
# Unary elemwise
5382-
test_all_unary_elemwise()
5383-
# Zeros Like
5384-
test_forward_zeros_like()
5385-
5386-
# Fill
5387-
test_forward_fill()
5388-
5389-
# Reduce
5390-
test_all_reduce()
5391-
5392-
# Logical
5393-
test_all_logical()
5394-
5395-
# Detection_PostProcess
5396-
test_detection_postprocess()
5397-
5398-
# NonMaxSuppressionV5
5399-
test_forward_nms_v5()
5400-
5401-
# Overwrite Converter
5402-
test_custom_op_converter()
5403-
5404-
# test structural_equal and span information
5405-
test_structure_and_span()
5406-
5407-
# End to End
5408-
test_forward_mobilenet_v1()
5409-
test_forward_mobilenet_v2()
5410-
test_forward_mobilenet_v3()
5411-
test_forward_inception_v3_net()
5412-
test_forward_inception_v4_net()
5413-
test_forward_inception_v4_net_batched()
5414-
test_forward_coco_ssd_mobilenet_v1()
5415-
test_forward_mediapipe_hand_landmark()
5416-
5417-
# End to End Sparse models
5418-
test_forward_sparse_mobilenet_v1()
5419-
test_forward_sparse_mobilenet_v2()
5420-
5421-
# End to End quantized
5422-
test_forward_qnn_inception_v1_net()
5423-
test_forward_qnn_mobilenet_v1_net()
5424-
test_forward_qnn_mobilenet_v2_net()
5425-
# This also fails with a segmentation fault in my run
5426-
# with Tflite 1.15.2
5427-
test_forward_qnn_mobilenet_v3_net()
5428-
test_forward_qnn_coco_ssd_mobilenet_v1()
5429-
5430-
# TFLite 2.1.0 quantized tests
5431-
test_forward_quantized_convolution()
5432-
test_forward_quantized_depthwise_convolution()
5433-
test_forward_tflite2_qnn_resnet50()
5434-
test_forward_tflite2_qnn_inception_v1()
5435-
test_forward_tflite2_qnn_mobilenet_v2()
5436-
5437-
test_forward_tflite_float16()
5438-
5439-
test_forward_tflite_int16()
5440-
test_forward_ds_cnn_int16()
5331+
tvm.testing.main()

0 commit comments

Comments
 (0)