Skip to content

Commit 1b92fb7

Browse files
committed
support torch.permute
1 parent f22958e commit 1b92fb7

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,8 +550,14 @@ def _flatten(self, node: fx.node.Node) -> relax.Var:
550550
return self.block_builder.emit(relax.op.reshape(x, new_shape))
551551

552552
def _permute(self, node: fx.node.Node) -> relax.Var:
553-
args = self.retrieve_args(node)
554-
return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:]))
553+
import torch # type: ignore
554+
555+
args = self.retrieve_args(node)
556+
if isinstance(args[1], (torch.Size, tuple, list)):
557+
return self.block_builder.emit(
558+
relax.op.permute_dims(args[0], tuple(args[1]))
559+
)
560+
return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:]))
555561

556562
def _reshape(self, node: fx.node.Node) -> relax.Var:
557563
import torch # type: ignore

0 commit comments

Comments
 (0)