Skip to content

Commit 6eef8db

Browse files
committed
cleanup
1 parent 2e7b18e commit 6eef8db

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -119,23 +119,6 @@ def _retrieve_args(self, node):
119119
else:
120120
return node
121121

122-
@staticmethod
123-
def _promote_binary_op_args(lhs, rhs):
124-
if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
125-
return lhs, rhs
126-
elif isinstance(lhs, relax.Expr):
127-
assert isinstance(lhs.struct_info, relax.TensorStructInfo)
128-
return lhs, relax.const(rhs, lhs.struct_info.dtype)
129-
elif isinstance(rhs, relax.Expr):
130-
assert isinstance(rhs.struct_info, relax.TensorStructInfo)
131-
return relax.const(lhs, rhs.struct_info.dtype), rhs
132-
else:
133-
assert False
134-
135-
def _call_binary_op(self, op, lhs, rhs):
136-
lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs)
137-
return self.block_builder.emit(op(lhs, rhs))
138-
139122
########## Unary Ops ##########
140123

141124
def _unary_op(self, op: Callable) -> Callable:
@@ -246,17 +229,29 @@ def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
246229
from torch import fx
247230

248231
def convert(node: fx.Node) -> relax.Var:
232+
def promote_binary_op_args(lhs, rhs):
233+
if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
234+
return lhs, rhs
235+
elif isinstance(lhs, relax.Expr):
236+
assert isinstance(lhs.struct_info, relax.TensorStructInfo)
237+
return lhs, relax.const(rhs, lhs.struct_info.dtype)
238+
elif isinstance(rhs, relax.Expr):
239+
assert isinstance(rhs.struct_info, relax.TensorStructInfo)
240+
return relax.const(lhs, rhs.struct_info.dtype), rhs
241+
else:
242+
assert False
243+
244+
def call_binary_op(op, lhs, rhs):
245+
lhs, rhs = promote_binary_op_args(lhs, rhs)
246+
return self.block_builder.emit(op(lhs, rhs))
247+
249248
lhs, rhs = self.retrieve_args(node)
250249
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
251-
return self._call_binary_op(relax_op, lhs, rhs)
250+
return call_binary_op(relax_op, lhs, rhs)
252251
elif isinstance(lhs, relax.expr.Constant):
253-
return self._call_binary_op(
254-
relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)
255-
)
252+
return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
256253
elif isinstance(rhs, relax.expr.Constant):
257-
return self._call_binary_op(
258-
relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs
259-
)
254+
return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
260255
return intrinsic_op(lhs, rhs)
261256

262257
return convert

0 commit comments

Comments
 (0)