diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 1895119e79f4..53b1fdd22c61 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1275,9 +1275,13 @@ def _unbind(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) assert isinstance(dim, int), "Expected 2nd argument of unbind as int" selections = self.shape_of(x)[dim].value - ret, split = [], self.block_builder.emit(relax.op.split(x, selections, dim)) - for i in range(selections): - ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) + ret = [] + if selections == 1: + ret.append(self.block_builder.emit(relax.op.squeeze(x, axis=dim))) + else: + split = self.block_builder.emit(relax.op.split(x, selections, dim)) + for i in range(selections): + ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) return self.block_builder.emit(relax.Tuple(ret)) ########## Statistical ########## diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index ead341de287a..65a72412179a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3251,9 +3251,25 @@ def main( R.output(gv) return gv + @tvm.script.ir_module + class expected3: + @R.function + def main( + data: R.Tensor((3, 1, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 3), dtype="float32") = R.squeeze(data, axis=[1]) + lv1: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv,) + lv2: R.Tensor((3, 3), dtype="float32") = lv1[0] + gv: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv2,) + R.output(gv) + return gv + example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),) verify_model(Unbind1(), example_args, {}, expected1) verify_model(Unbind2(), example_args, {}, expected2) + single_dim_args = (torch.randn(3, 1, 3, dtype=torch.float32),) + verify_model(Unbind2(), single_dim_args, {}, expected3) def test_interpolate():