Skip to content

Commit 103bd6e

Browse files
committed
cleanup datatype ops
1 parent 7624b3d commit 103bd6e

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,32 @@ def convert(node: fx.Node):
913913

914914
return convert
915915

916+
########## DataType ##########
917+
918+
def _float(self, node: fx.Node) -> relax.Var:
919+
return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32"))
920+
921+
def _half(self, node: fx.Node) -> relax.Var:
922+
return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
923+
924+
def _to(self, node: fx.Node) -> relax.Var:
925+
import torch
926+
927+
x = self.env[node.args[0]]
928+
if len(node.args) == 2:
929+
if isinstance(node.args[1], torch.dtype):
930+
dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
931+
return self.block_builder.emit(relax.op.astype(x, dtype))
932+
elif "dtype" in node.kwargs:
933+
dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env)
934+
return self.block_builder.emit(relax.op.astype(x, dtype))
935+
return x
936+
937+
def _type(self, node: fx.Node) -> relax.Var:
938+
x = self.env[node.args[0]]
939+
dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
940+
return self.block_builder.emit(relax.op.astype(x, dtype))
941+
916942
########## Creation ##########
917943

918944
def _arange(self, node: fx.Node) -> relax.Var:
@@ -1051,32 +1077,6 @@ def _full(self, node: fx.Node) -> relax.Var:
10511077
)
10521078
)
10531079

1054-
########## DataType ##########
1055-
1056-
def _float(self, node: fx.Node) -> relax.Var:
1057-
return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32"))
1058-
1059-
def _half(self, node: fx.Node) -> relax.Var:
1060-
return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
1061-
1062-
def _type(self, node: fx.Node) -> relax.Var:
1063-
x = self.env[node.args[0]]
1064-
dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
1065-
return self.block_builder.emit(relax.op.astype(x, dtype))
1066-
1067-
def _to(self, node: fx.Node) -> relax.Var:
1068-
import torch
1069-
1070-
x = self.env[node.args[0]]
1071-
if len(node.args) == 2:
1072-
if isinstance(node.args[1], torch.dtype):
1073-
dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
1074-
return self.block_builder.emit(relax.op.astype(x, dtype))
1075-
elif "dtype" in node.kwargs:
1076-
dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env)
1077-
return self.block_builder.emit(relax.op.astype(x, dtype))
1078-
return x
1079-
10801080
########## Manipulation ##########
10811081

10821082
def _cat(self, node: fx.Node) -> relax.Var:

0 commit comments

Comments
 (0)