Skip to content

Commit f78e9cd

Browse files
committed
fix tensor size
1 parent 7521211 commit f78e9cd

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,7 +1248,7 @@ def _full(self, node: fx.Node) -> relax.Var:
12481248
import torch
12491249

12501250
args = self.retrieve_args(node)
1251-
size = relax.ShapeExpr((args[0],) if isinstance(args[0], (list, tuple)) else args[0])
1251+
size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
12521252
dtype = self._convert_data_type(
12531253
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
12541254
)
@@ -1303,7 +1303,7 @@ def _ones(self, node: fx.Node) -> relax.Var:
13031303
import torch
13041304

13051305
args = self.retrieve_args(node)
1306-
size = relax.ShapeExpr((args[0],) if isinstance(args[0], (list, tuple)) else args[0])
1306+
size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
13071307
dtype = self._convert_data_type(
13081308
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
13091309
)

0 commit comments

Comments
 (0)