Skip to content

Commit 2e7b18e

Browse files
committed
introduce _binary_op()
1 parent 5265d21 commit 2e7b18e

File tree

1 file changed

+37
-80
lines changed

1 file changed

+37
-80
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 37 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# pylint: disable=import-outside-toplevel
2020
"""PyTorch FX frontend of Relax."""
2121
from typing import Callable, Dict, List, Optional, Tuple, Union
22-
from functools import reduce
22+
from functools import partial, reduce
2323

2424
import tvm
2525
from tvm import relax
@@ -240,66 +240,26 @@ def convert(node: fx.Node) -> relax.Var:
240240

241241
return convert
242242

243-
########## Arithmetic ##########
243+
########## Binary Ops ##########
244244

245-
def _add(self, node: fx.Node) -> relax.Expr:
246-
lhs, rhs = self.retrieve_args(node)
247-
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
248-
return self._call_binary_op(relax.op.add, lhs, rhs)
249-
elif isinstance(lhs, relax.expr.Constant):
250-
return self._call_binary_op(
251-
relax.op.add, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)
252-
)
253-
elif isinstance(rhs, relax.expr.Constant):
254-
return self._call_binary_op(
255-
relax.op.add, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs
256-
)
257-
return lhs + rhs
258-
259-
def _max(self, node: fx.Node) -> relax.Expr:
260-
lhs, rhs = self.retrieve_args(node)
261-
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
262-
return self._call_binary_op(relax.op.maximum, lhs, rhs)
263-
264-
def _floordiv(self, node: fx.Node) -> relax.Expr:
265-
lhs, rhs = self.retrieve_args(node)
266-
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
267-
return self._call_binary_op(relax.op.floor_divide, lhs, rhs)
268-
return lhs // rhs
269-
270-
def _mul(self, node: fx.Node) -> relax.Expr:
271-
lhs, rhs = self.retrieve_args(node)
272-
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
273-
return self._call_binary_op(relax.op.multiply, lhs, rhs)
274-
return lhs * rhs
275-
276-
def _pow(self, node: fx.Node) -> relax.Expr:
277-
lhs, rhs = self.retrieve_args(node)
278-
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
279-
return self._call_binary_op(relax.op.power, lhs, rhs)
280-
return lhs**rhs
281-
282-
def _sub(self, node: fx.Node) -> relax.Expr:
283-
lhs, rhs = self.retrieve_args(node)
284-
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
285-
return self._call_binary_op(relax.op.subtract, lhs, rhs)
286-
return lhs - rhs
287-
288-
def _truediv(self, node: fx.Node) -> relax.Expr:
289-
lhs, rhs = self.retrieve_args(node)
290-
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
291-
return self._call_binary_op(relax.op.divide, lhs, rhs)
292-
return lhs / rhs
293-
294-
########## Compare ##########
295-
296-
def _lt(self, node: fx.Node) -> relax.Expr:
297-
lhs, rhs = self.retrieve_args(node)
298-
return self._call_binary_op(relax.op.less, lhs, rhs)
299-
300-
def _eq(self, node: fx.Node) -> relax.Expr:
301-
lhs, rhs = self.retrieve_args(node)
302-
return self._call_binary_op(relax.op.equal, lhs, rhs)
245+
def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
    """Build a converter for a binary FX node.

    The returned closure dispatches on what kind of operands the node
    carries: it emits *relax_op* when at least one side is already a relax
    expression (promoting a bare Python scalar to a ``relax.const`` of the
    other side's dtype so the op sees matching types), and falls back to
    *intrinsic_op* — the plain Python operator — when both sides are
    compile-time Python values.
    """
    from torch import fx  # local import: only needed for the annotation below

    def convert(node: fx.Node) -> relax.Var:
        lhs, rhs = self.retrieve_args(node)
        # At least one relax Var: emit the relax op directly.
        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
            return self._call_binary_op(relax_op, lhs, rhs)
        # One side is a relax Constant, the other a Python scalar:
        # wrap the scalar so the dtypes agree before emitting.
        if isinstance(lhs, relax.expr.Constant):
            rhs_const = relax.const(rhs, dtype=lhs.struct_info.dtype)
            return self._call_binary_op(relax_op, lhs, rhs_const)
        if isinstance(rhs, relax.expr.Constant):
            lhs_const = relax.const(lhs, dtype=rhs.struct_info.dtype)
            return self._call_binary_op(relax_op, lhs_const, rhs)
        # Both operands are plain Python values — fold at conversion time.
        return intrinsic_op(lhs, rhs)

    return convert
303263

304264
########## Creation ##########
305265

@@ -486,14 +446,6 @@ def _to(self, node: fx.Node) -> relax.Var:
486446
def _matmul_impl(self, a: relax.Expr, b: relax.Expr):
487447
return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32"))
488448

489-
def _matmul(self, node: fx.Node) -> relax.Var:
490-
args = self.retrieve_args(node)
491-
res = self._matmul_impl(
492-
args[0],
493-
args[1],
494-
)
495-
return res
496-
497449
def _addmm(self, node: fx.Node) -> relax.Var:
498450
x = self.env[node.args[0]]
499451
y = self.env[node.args[1]]
@@ -1568,6 +1520,7 @@ def _getitem(self, node: fx.Node) -> relax.Var:
15681520
assert False
15691521

15701522
def create_convert_map(self):
1523+
import operator
15711524
from torch import nn
15721525
from torch import fx
15731526

@@ -1641,23 +1594,27 @@ def create_convert_map(self):
16411594
"triu_": self._inplace_tril_triu(relax.op.triu),
16421595
"triu": self._tril_triu(relax.op.triu),
16431596
# binary
1644-
"add": self._add,
1645-
"eq": self._eq,
1646-
"floordiv": self._floordiv,
1647-
"iadd": self._add,
1648-
"lt": self._lt,
1649-
"matmul": self._matmul,
1650-
"max": self._max,
1651-
"mul": self._mul,
1652-
"pow": self._pow,
1653-
"sub": self._sub,
1654-
"truediv": self._truediv,
1597+
"add": self._binary_op(relax.op.add, operator.add),
1598+
"eq": self._binary_op(relax.op.equal, operator.eq),
1599+
"floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
1600+
"iadd": self._binary_op(relax.op.add, operator.add),
1601+
"lt": self._binary_op(relax.op.less, operator.lt),
1602+
"matmul": self._binary_op(
1603+
partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
1604+
),
1605+
"max": self._binary_op(relax.op.maximum, max),
1606+
"mul": self._binary_op(relax.op.multiply, operator.mul),
1607+
"pow": self._binary_op(relax.op.power, operator.pow),
1608+
"sub": self._binary_op(relax.op.subtract, operator.sub),
1609+
"truediv": self._binary_op(relax.op.divide, operator.truediv),
16551610
# neural network
16561611
"adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
16571612
"addmm": self._addmm,
16581613
"avg_pool2d": self._avg_pool2d,
16591614
"baddbmm": self._baddbmm,
1660-
"bmm": self._matmul,
1615+
"bmm": self._binary_op(
1616+
partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
1617+
),
16611618
"conv_transpose1d": self._conv1d_transpose_functional,
16621619
"conv_transpose2d": self._conv2d_transpose_functional,
16631620
"conv1d": self._conv1d_functional,

0 commit comments

Comments
 (0)