From 3665382d2ab0b2cbdbd2391578567ccbce9ef7a9 Mon Sep 17 00:00:00 2001 From: Sungho Shin <87514200+rebel-shshin@users.noreply.github.com> Date: Fri, 7 Jul 2023 16:16:36 +0900 Subject: [PATCH 1/4] TFLite frontend bug fix --- python/tvm/relay/frontend/tflite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 9e88a85e035d..94ca1de1a222 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3041,9 +3041,9 @@ def convert_batch_matmul(self, op): _op.concatenate([out_batch, _op.strided_slice(shape_b, [rank_b - 2], [rank_b])], 0) ) if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape): - input_a = _op.transform.broadcast_to(a, a_broadcasted_shape) + input_a = _op.transform.broadcast_to(input_a, a_broadcasted_shape) if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape): - input_b = _op.transform.broadcast_to(b, b_broadcasted_shape) + input_b = _op.transform.broadcast_to(input_b, b_broadcasted_shape) input_a = self.flatten_to_nd(input_a, shape_a, 3) input_b = self.flatten_to_nd(input_b, shape_b, 3) From 40c2e4f0d2594227d72b7e2080654e627efa7275 Mon Sep 17 00:00:00 2001 From: Sungho Shin <87514200+rebel-shshin@users.noreply.github.com> Date: Fri, 7 Jul 2023 16:46:18 +0900 Subject: [PATCH 2/4] Update tflite.py --- python/tvm/relay/frontend/tflite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 94ca1de1a222..d812164923d4 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3006,7 +3006,7 @@ def convert_batch_matmul(self, op): rank_diff = rank_a - rank_b new_b_shape = _op.concatenate( [ - _expr.const([1] * rank_diff, dtype=_infer_type(b_shape).checked_type.dtype), + _expr.const([1] * rank_diff, dtype=_infer_type(new_b_shape).checked_type.dtype), shape_b, ], 0, @@ -3015,7 +3015,7 @@ def convert_batch_matmul(self, op): rank_diff = rank_b - rank_a new_a_shape = _op.concatenate( [ - _expr.const([1] * rank_diff, dtype=_infer_type(a_shape).checked_type.dtype), + _expr.const([1] * rank_diff, dtype=_infer_type(new_a_shape).checked_type.dtype), shape_a, ], 0, From 6fd5403e206dbf30df846753347c418901864058 Mon Sep 17 00:00:00 2001 From: Sungho Shin <87514200+rebel-shshin@users.noreply.github.com> Date: Mon, 10 Jul 2023 18:31:47 +0900 Subject: [PATCH 3/4] lint --- python/tvm/relay/frontend/tflite.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d812164923d4..dfc7ed27a474 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3006,7 +3006,9 @@ def convert_batch_matmul(self, op): rank_diff = rank_a - rank_b new_b_shape = _op.concatenate( [ - _expr.const([1] * rank_diff, dtype=_infer_type(new_b_shape).checked_type.dtype), + _expr.const( + [1] * rank_diff, dtype=_infer_type(new_b_shape).checked_type.dtype + ), shape_b, ], 0, @@ -3015,7 +3017,9 @@ def convert_batch_matmul(self, op): rank_diff = rank_b - rank_a new_a_shape = _op.concatenate( [ - _expr.const([1] * rank_diff, dtype=_infer_type(new_a_shape).checked_type.dtype), + _expr.const( + [1] * rank_diff, dtype=_infer_type(new_a_shape).checked_type.dtype + ), shape_a, ], 0, From fa434104bd670333d53c41f8708555386c00c072 Mon Sep 17 00:00:00 2001 From: Sungho Shin <87514200+rebel-shshin@users.noreply.github.com> Date: Mon, 10 Jul 2023 21:40:12 +0900 Subject: [PATCH 4/4] Add pytest --- tests/python/frontend/tflite/test_forward.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index c65e48b40288..4ea82e5b4ce7 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -794,6 +794,15 @@ def test_forward_batch_matmul(config): adjoint_b=False, quantized=config[2], ) + _test_batch_matmul( + (2, 3, 5, 4), + (1, 3, 5, 4), + dtype=config[0], + out_dtype=config[1], + adjoint_a=True, + adjoint_b=False, + quantized=config[2], + ) _test_batch_matmul( (3, 5, 4), (3, 5, 4), @@ -803,6 +812,15 @@ def test_forward_batch_matmul(config): adjoint_b=True, quantized=config[2], ) + _test_batch_matmul( + (2, 3, 5, 4), + (1, 3, 5, 4), + dtype=config[0], + out_dtype=config[1], + adjoint_a=False, + adjoint_b=True, + quantized=config[2], + ) _test_batch_matmul( (3, 4, 5, 6), (3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2] )