diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5ed0f18deb9e..f9a5d9c33f02 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -550,7 +550,11 @@ def _flatten(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.reshape(x, new_shape)) def _permute(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) def _reshape(self, node: fx.node.Node) -> relax.Var: diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index dd2719f8ce91..46c079aa99cc 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3029,10 +3029,14 @@ def forward(self, x): def test_permute(): input_info = [([1, 2, 3, 4], "float32")] - class Permute(Module): + class Permute1(Module): def forward(self, x): return x.permute(0, 3, 2, 1) + class Permute2(Module): + def forward(self, x): + return torch.permute(x, (0, 3, 2, 1)) + @tvm.script.ir_module class expected1: @R.function @@ -3046,7 +3050,8 @@ def main( R.output(gv) return gv - verify_model(Permute(), input_info, {}, expected1) + verify_model(Permute1(), input_info, {}, expected1) + verify_model(Permute2(), input_info, {}, expected1) def test_reshape():