|
29 | 29 | from packaging import version as package_version |
30 | 30 | import pytest |
31 | 31 | import numpy as np |
| 32 | +import typing |
32 | 33 |
|
33 | 34 | from PIL import Image |
34 | 35 |
|
@@ -292,11 +293,11 @@ def run_tflite_graph(tflite_model_buf, input_data): |
292 | 293 |
|
293 | 294 |
|
294 | 295 | def compare_tflite_with_tvm( |
295 | | - in_data, |
296 | | - in_name, |
297 | | - input_tensors, |
298 | | - output_tensors, |
299 | | - init_global_variables=False, |
| 296 | + in_data: typing.List[np.ndarray], |
| 297 | + in_name: typing.List[str], |
| 298 | + input_tensors: typing.List, |
| 299 | + output_tensors: typing.List, |
| 300 | + init_global_variables: bool = False, |
300 | 301 | out_names=None, |
301 | 302 | quantized=False, |
302 | 303 | input_range=None, |
@@ -5301,140 +5302,30 @@ def _golden(): |
5301 | 5302 | _test_reshape_span() |
5302 | 5303 |
|
5303 | 5304 |
|
5304 | | -####################################################################### |
5305 | | -# Main |
5306 | | -# ---- |
| 5305 | +class TestConv2d: |
| 5306 | + input_shape, kernel_shape, padding = tvm.testing.parameters( |
| 5307 | + ((1, 128, 256, 6), (5, 5, 6, 10), "SAME"), |
| 5308 | + ((1, 128, 256, 6), (5, 5, 6, 10), "VALID"), |
| 5309 | + # conv2d_group cases |
| 5310 | + ((1, 30, 40, 6), (5, 5, 1, 6), "SAME"), |
| 5311 | + ((1, 30, 40, 6), (5, 5, 1, 6), "VALID"), |
| 5312 | + ) |
| 5313 | + |
| 5314 | + def test_conv2d(self, input_shape: tuple, kernel_shape: tuple, padding: str): |
| 5315 | + dtype = tf.float32 |
| 5316 | + kernel_in = np.ones(kernel_shape) |
| 5317 | + with tf.Graph().as_default(): |
| 5318 | + x = array_ops.placeholder(shape=input_shape, dtype=dtype.name, name="input") |
| 5319 | + kernel = tf.constant(kernel_in, dtype=dtype, name="filter_weight") |
| 5320 | + out = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding=padding, name="conv2d") |
| 5321 | + input_data = np.random.randn(*input_shape).astype(dtype.name) |
| 5322 | + compare_tflite_with_tvm( |
| 5323 | + [input_data], |
| 5324 | + ["input"], |
| 5325 | + [x], |
| 5326 | + [out], |
| 5327 | + ) |
| 5328 | + |
| 5329 | + |
5307 | 5330 | if __name__ == "__main__": |
5308 | | - # BatchToSpaceND |
5309 | | - test_forward_batch_to_space_nd() |
5310 | | - |
5311 | | - # SpaceToBatchND |
5312 | | - test_forward_space_to_batch_nd() |
5313 | | - |
5314 | | - # Split |
5315 | | - test_forward_split() |
5316 | | - |
5317 | | - # Transpose |
5318 | | - test_forward_transpose() |
5319 | | - |
5320 | | - # Cast |
5321 | | - test_forward_cast() |
5322 | | - |
5323 | | - # BatchMatMul |
5324 | | - test_forward_batch_matmul() |
5325 | | - |
5326 | | - # Tile |
5327 | | - test_forward_tile() |
5328 | | - |
5329 | | - # Query |
5330 | | - test_forward_shape() |
5331 | | - |
5332 | | - # Transforms |
5333 | | - test_forward_concatenation() |
5334 | | - test_forward_pad() |
5335 | | - test_forward_pack() |
5336 | | - test_forward_unpack() |
5337 | | - test_forward_reshape() |
5338 | | - test_all_resize() |
5339 | | - test_forward_range() |
5340 | | - test_forward_squeeze() |
5341 | | - test_forward_slice() |
5342 | | - test_forward_topk() |
5343 | | - test_forward_gather() |
5344 | | - test_forward_gather_nd() |
5345 | | - test_forward_stridedslice() |
5346 | | - test_forward_depthtospace() |
5347 | | - test_forward_spacetodepth() |
5348 | | - test_forward_reverse_sequence() |
5349 | | - test_forward_sparse_to_dense() |
5350 | | - test_forward_select() |
5351 | | - test_forward_quantize_dequantize() |
5352 | | - test_forward_arg_min_max() |
5353 | | - test_forward_expand_dims() |
5354 | | - test_forward_reverse_v2() |
5355 | | - test_forward_matrix_set_diag() |
5356 | | - test_forward_matrix_diag() |
5357 | | - |
5358 | | - # NN |
5359 | | - test_forward_convolution() |
5360 | | - test_forward_transpose_conv() |
5361 | | - test_forward_logistic() |
5362 | | - test_forward_pooling() |
5363 | | - test_forward_l2_pool2d() |
5364 | | - test_forward_softmax() |
5365 | | - test_forward_tanh() |
5366 | | - test_forward_relu() |
5367 | | - test_forward_relu6() |
5368 | | - test_forward_leaky_relu() |
5369 | | - test_forward_relu_n1_to_1() |
5370 | | - test_forward_log_softmax() |
5371 | | - test_forward_fully_connected() |
5372 | | - test_forward_l2_normalization() |
5373 | | - test_forward_local_response_normalization() |
5374 | | - test_forward_prelu() |
5375 | | - test_forward_unidirectional_sequence_lstm() |
5376 | | - |
5377 | | - # Elemwise |
5378 | | - test_all_elemwise() |
5379 | | - test_forward_add_n() |
5380 | | - |
5381 | | - # Unary elemwise |
5382 | | - test_all_unary_elemwise() |
5383 | | - # Zeros Like |
5384 | | - test_forward_zeros_like() |
5385 | | - |
5386 | | - # Fill |
5387 | | - test_forward_fill() |
5388 | | - |
5389 | | - # Reduce |
5390 | | - test_all_reduce() |
5391 | | - |
5392 | | - # Logical |
5393 | | - test_all_logical() |
5394 | | - |
5395 | | - # Detection_PostProcess |
5396 | | - test_detection_postprocess() |
5397 | | - |
5398 | | - # NonMaxSuppressionV5 |
5399 | | - test_forward_nms_v5() |
5400 | | - |
5401 | | - # Overwrite Converter |
5402 | | - test_custom_op_converter() |
5403 | | - |
5404 | | - # test structural_equal and span information |
5405 | | - test_structure_and_span() |
5406 | | - |
5407 | | - # End to End |
5408 | | - test_forward_mobilenet_v1() |
5409 | | - test_forward_mobilenet_v2() |
5410 | | - test_forward_mobilenet_v3() |
5411 | | - test_forward_inception_v3_net() |
5412 | | - test_forward_inception_v4_net() |
5413 | | - test_forward_inception_v4_net_batched() |
5414 | | - test_forward_coco_ssd_mobilenet_v1() |
5415 | | - test_forward_mediapipe_hand_landmark() |
5416 | | - |
5417 | | - # End to End Sparse models |
5418 | | - test_forward_sparse_mobilenet_v1() |
5419 | | - test_forward_sparse_mobilenet_v2() |
5420 | | - |
5421 | | - # End to End quantized |
5422 | | - test_forward_qnn_inception_v1_net() |
5423 | | - test_forward_qnn_mobilenet_v1_net() |
5424 | | - test_forward_qnn_mobilenet_v2_net() |
5425 | | - # This also fails with a segmentation fault in my run |
5426 | | - # with Tflite 1.15.2 |
5427 | | - test_forward_qnn_mobilenet_v3_net() |
5428 | | - test_forward_qnn_coco_ssd_mobilenet_v1() |
5429 | | - |
5430 | | - # TFLite 2.1.0 quantized tests |
5431 | | - test_forward_quantized_convolution() |
5432 | | - test_forward_quantized_depthwise_convolution() |
5433 | | - test_forward_tflite2_qnn_resnet50() |
5434 | | - test_forward_tflite2_qnn_inception_v1() |
5435 | | - test_forward_tflite2_qnn_mobilenet_v2() |
5436 | | - |
5437 | | - test_forward_tflite_float16() |
5438 | | - |
5439 | | - test_forward_tflite_int16() |
5440 | | - test_forward_ds_cnn_int16() |
| 5331 | + tvm.testing.main() |
0 commit comments