From 34705c4896534d0b1d5ede7b38ff4adccfeafcf9 Mon Sep 17 00:00:00 2001 From: yrx <421626388@qq.com> Date: Mon, 19 May 2025 15:32:50 +0800 Subject: [PATCH 1/2] [Relax][Frontend]Fix: Output tensor with zero dimension after torch.unbind to relax conversion --- .../relax/frontend/torch/base_fx_graph_translator.py | 3 +-- .../relax/test_frontend_from_exported_program.py | 12 ++++-------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 5c8d7095e511..50969e85a5ea 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1269,8 +1269,7 @@ def _unbind(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) assert isinstance(dim, int), "Expected 2nd argument of unbind as int" selections = self.shape_of(x)[dim].value - n_section = list(range(1, selections + 1)) - ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) + ret, split = [], self.block_builder.emit(relax.op.split(x, selections, dim)) for i in range(selections): ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) return self.block_builder.emit(relax.Tuple(ret)) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 80da6fcf19ad..aaaf7e6eacb6 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3108,8 +3108,7 @@ def main( R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((0, 3, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + ) = R.split(input_1, indices_or_sections=3, axis=0) lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] @@ -3152,8 +3151,7 @@ def main( R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 0, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + ) = R.split(input_1, indices_or_sections=3, axis=1) lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] @@ -3978,8 +3976,7 @@ def main( R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((0, 3, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + ) = R.split(input_1, indices_or_sections=3, axis=0) lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] @@ -4022,8 +4019,7 @@ def main( R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 0, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + ) = R.split(input_1, indices_or_sections=3, axis=1) lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] From 20859721995230bfecb7d822744c6e18108d5243 Mon Sep 17 00:00:00 2001 From: yrx <421626388@qq.com> Date: Mon, 19 May 2025 18:06:53 +0800 Subject: [PATCH 2/2] Fix: Output tensor with zero dimension after torch.unbind to relax conversion --- tests/python/relax/test_frontend_from_fx.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 7fb2bed328a8..789c5649e605 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3746,8 +3746,7 @@ def main( R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((0, 3, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + ) = R.split(input_1, indices_or_sections=3, axis=0) lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] @@ -3783,8 +3782,7 @@ def main( R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 0, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + ) = R.split(input_1, indices_or_sections=3, axis=1) lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]