|
19 | 19 | # pylint: disable=import-outside-toplevel |
20 | 20 | """PyTorch FX frontend of Relax.""" |
21 | 21 | from typing import Callable, Dict, List, Optional, Tuple, Union |
22 | | -from functools import reduce |
| 22 | +from functools import partial, reduce |
23 | 23 |
|
24 | 24 | import tvm |
25 | 25 | from tvm import relax |
@@ -240,66 +240,26 @@ def convert(node: fx.Node) -> relax.Var: |
240 | 240 |
|
241 | 241 | return convert |
242 | 242 |
|
243 | | - ########## Arithmetic ########## |
| 243 | + ########## Binary Ops ########## |
244 | 244 |
|
245 | | - def _add(self, node: fx.Node) -> relax.Expr: |
246 | | - lhs, rhs = self.retrieve_args(node) |
247 | | - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): |
248 | | - return self._call_binary_op(relax.op.add, lhs, rhs) |
249 | | - elif isinstance(lhs, relax.expr.Constant): |
250 | | - return self._call_binary_op( |
251 | | - relax.op.add, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype) |
252 | | - ) |
253 | | - elif isinstance(rhs, relax.expr.Constant): |
254 | | - return self._call_binary_op( |
255 | | - relax.op.add, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs |
256 | | - ) |
257 | | - return lhs + rhs |
258 | | - |
259 | | - def _max(self, node: fx.Node) -> relax.Expr: |
260 | | - lhs, rhs = self.retrieve_args(node) |
261 | | - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): |
262 | | - return self._call_binary_op(relax.op.maximum, lhs, rhs) |
263 | | - |
264 | | - def _floordiv(self, node: fx.Node) -> relax.Expr: |
265 | | - lhs, rhs = self.retrieve_args(node) |
266 | | - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): |
267 | | - return self._call_binary_op(relax.op.floor_divide, lhs, rhs) |
268 | | - return lhs // rhs |
269 | | - |
270 | | - def _mul(self, node: fx.Node) -> relax.Expr: |
271 | | - lhs, rhs = self.retrieve_args(node) |
272 | | - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): |
273 | | - return self._call_binary_op(relax.op.multiply, lhs, rhs) |
274 | | - return lhs * rhs |
275 | | - |
276 | | - def _pow(self, node: fx.Node) -> relax.Expr: |
277 | | - lhs, rhs = self.retrieve_args(node) |
278 | | - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): |
279 | | - return self._call_binary_op(relax.op.power, lhs, rhs) |
280 | | - return lhs**rhs |
281 | | - |
282 | | - def _sub(self, node: fx.Node) -> relax.Expr: |
283 | | - lhs, rhs = self.retrieve_args(node) |
284 | | - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): |
285 | | - return self._call_binary_op(relax.op.subtract, lhs, rhs) |
286 | | - return lhs - rhs |
287 | | - |
288 | | - def _truediv(self, node: fx.Node) -> relax.Expr: |
289 | | - lhs, rhs = self.retrieve_args(node) |
290 | | - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): |
291 | | - return self._call_binary_op(relax.op.divide, lhs, rhs) |
292 | | - return lhs / rhs |
293 | | - |
294 | | - ########## Compare ########## |
295 | | - |
296 | | - def _lt(self, node: fx.Node) -> relax.Expr: |
297 | | - lhs, rhs = self.retrieve_args(node) |
298 | | - return self._call_binary_op(relax.op.less, lhs, rhs) |
299 | | - |
300 | | - def _eq(self, node: fx.Node) -> relax.Expr: |
301 | | - lhs, rhs = self.retrieve_args(node) |
302 | | - return self._call_binary_op(relax.op.equal, lhs, rhs) |
def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
    """Build a converter for a binary FX node.

    The returned converter emits ``relax_op`` when at least one operand is a
    Relax ``Var``; promotes a plain Python operand to a ``relax.const`` (using
    the other operand's dtype) when the other side is a Relax ``Constant``;
    and otherwise falls back to ``intrinsic_op`` on the raw Python values.
    """
    from torch import fx

    def convert(node: fx.Node) -> relax.Var:
        lhs, rhs = self.retrieve_args(node)
        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
            return self._call_binary_op(relax_op, lhs, rhs)
        if isinstance(lhs, relax.expr.Constant):
            # rhs is a plain Python scalar here; match lhs's dtype.
            rhs = relax.const(rhs, dtype=lhs.struct_info.dtype)
            return self._call_binary_op(relax_op, lhs, rhs)
        if isinstance(rhs, relax.expr.Constant):
            # Symmetric case: promote lhs to rhs's dtype.
            lhs = relax.const(lhs, dtype=rhs.struct_info.dtype)
            return self._call_binary_op(relax_op, lhs, rhs)
        # Neither side is a Relax expression: fold with the plain Python op.
        return intrinsic_op(lhs, rhs)

    return convert
303 | 263 |
|
304 | 264 | ########## Creation ########## |
305 | 265 |
|
@@ -486,14 +446,6 @@ def _to(self, node: fx.Node) -> relax.Var: |
def _matmul_impl(self, a: relax.Expr, b: relax.Expr):
    # Emit a matmul whose output dtype is pinned to float32.
    product = relax.op.linear_algebra.matmul(a, b, out_dtype="float32")
    return self.block_builder.emit(product)
488 | 448 |
|
489 | | - def _matmul(self, node: fx.Node) -> relax.Var: |
490 | | - args = self.retrieve_args(node) |
491 | | - res = self._matmul_impl( |
492 | | - args[0], |
493 | | - args[1], |
494 | | - ) |
495 | | - return res |
496 | | - |
497 | 449 | def _addmm(self, node: fx.Node) -> relax.Var: |
498 | 450 | x = self.env[node.args[0]] |
499 | 451 | y = self.env[node.args[1]] |
@@ -1568,6 +1520,7 @@ def _getitem(self, node: fx.Node) -> relax.Var: |
1568 | 1520 | assert False |
1569 | 1521 |
|
1570 | 1522 | def create_convert_map(self): |
| 1523 | + import operator |
1571 | 1524 | from torch import nn |
1572 | 1525 | from torch import fx |
1573 | 1526 |
|
@@ -1641,23 +1594,27 @@ def create_convert_map(self): |
1641 | 1594 | "triu_": self._inplace_tril_triu(relax.op.triu), |
1642 | 1595 | "triu": self._tril_triu(relax.op.triu), |
1643 | 1596 | # binary |
1644 | | - "add": self._add, |
1645 | | - "eq": self._eq, |
1646 | | - "floordiv": self._floordiv, |
1647 | | - "iadd": self._add, |
1648 | | - "lt": self._lt, |
1649 | | - "matmul": self._matmul, |
1650 | | - "max": self._max, |
1651 | | - "mul": self._mul, |
1652 | | - "pow": self._pow, |
1653 | | - "sub": self._sub, |
1654 | | - "truediv": self._truediv, |
| 1597 | + "add": self._binary_op(relax.op.add, operator.add), |
| 1598 | + "eq": self._binary_op(relax.op.equal, operator.eq), |
| 1599 | + "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv), |
| 1600 | + "iadd": self._binary_op(relax.op.add, operator.add), |
| 1601 | + "lt": self._binary_op(relax.op.less, operator.lt), |
| 1602 | + "matmul": self._binary_op( |
| 1603 | + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul |
| 1604 | + ), |
| 1605 | + "max": self._binary_op(relax.op.maximum, max), |
| 1606 | + "mul": self._binary_op(relax.op.multiply, operator.mul), |
| 1607 | + "pow": self._binary_op(relax.op.power, operator.pow), |
| 1608 | + "sub": self._binary_op(relax.op.subtract, operator.sub), |
| 1609 | + "truediv": self._binary_op(relax.op.divide, operator.truediv), |
1655 | 1610 | # neural network |
1656 | 1611 | "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), |
1657 | 1612 | "addmm": self._addmm, |
1658 | 1613 | "avg_pool2d": self._avg_pool2d, |
1659 | 1614 | "baddbmm": self._baddbmm, |
1660 | | - "bmm": self._matmul, |
| 1615 | + "bmm": self._binary_op( |
| 1616 | + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul |
| 1617 | + ), |
1661 | 1618 | "conv_transpose1d": self._conv1d_transpose_functional, |
1662 | 1619 | "conv_transpose2d": self._conv2d_transpose_functional, |
1663 | 1620 | "conv1d": self._conv1d_functional, |
|
0 commit comments