
Commit 7c28c86

[Relax][PyTorch] Support binary, statistical and search ops for ExportedProgram importer (#17424)
* support binary ops
* support mean
* support sum
* support argmax and argmin
1 parent 176d01e commit 7c28c86

File tree

4 files changed: +599 -62 lines changed
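As a rough usage sketch (not part of this commit; the from_exported_program entry point and exact call shapes are assumed from tvm.relax.frontend.torch), the newly supported conversions can be exercised end to end like this:

    # Hypothetical end-to-end sketch: export a small model with torch.export
    # and import it through the ExportedProgram frontend.
    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class SmallModel(torch.nn.Module):
        def forward(self, a, b):
            x = a + b                        # lowered via "add.Tensor"
            x = x * 2.0                      # "mul.Tensor"; scalar promoted to relax.const
            m = x.mean(dim=1)                # "mean.dim"
            s = x.sum(dim=1, keepdim=True)   # "sum.dim_IntList"
            return torch.argmax(m), s        # "argmax.default"

    exported = export(SmallModel(), (torch.randn(4, 8), torch.randn(4, 8)))
    mod = from_exported_program(exported)    # Relax IRModule
    mod.show()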

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 62 additions & 0 deletions

@@ -185,6 +185,39 @@ def convert(node: fx.Node) -> relax.Var:
 
         return convert
 
+    ########## Binary Ops ##########
+
+    def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            def promote_binary_op_args(lhs, rhs):
+                if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
+                    return lhs, rhs
+                elif isinstance(lhs, relax.Expr):
+                    assert isinstance(lhs.struct_info, relax.TensorStructInfo)
+                    return lhs, relax.const(rhs, lhs.struct_info.dtype)
+                elif isinstance(rhs, relax.Expr):
+                    assert isinstance(rhs.struct_info, relax.TensorStructInfo)
+                    return relax.const(lhs, rhs.struct_info.dtype), rhs
+                else:
+                    assert False
+
+            def call_binary_op(op, lhs, rhs):
+                lhs, rhs = promote_binary_op_args(lhs, rhs)
+                return self.block_builder.emit(op(lhs, rhs))
+
+            lhs, rhs = self.retrieve_args(node)
+            if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+                return call_binary_op(relax_op, lhs, rhs)
+            elif isinstance(lhs, relax.expr.Constant):
+                return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
+            elif isinstance(rhs, relax.expr.Constant):
+                return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
+            return intrinsic_op(lhs, rhs)
+
+        return convert
+
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
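The promotion rule in _binary_op can be illustrated in isolation: when exactly one operand is a Relax expression, the Python scalar on the other side is wrapped in a relax.const carrying the tensor's dtype so the emitted binary op type-checks. A minimal sketch (variable names are illustrative):

    from tvm import relax

    lhs = relax.Var("lhs", relax.TensorStructInfo((2, 2), "float32"))
    rhs = 3.0  # plain Python scalar coming out of the FX graph
    if isinstance(lhs, relax.Expr) and not isinstance(rhs, relax.Expr):
        rhs = relax.const(rhs, lhs.struct_info.dtype)  # 0-dim float32 constant
    # both sides are now relax.Expr, so e.g. relax.op.add(lhs, rhs) is well-typed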
@@ -283,6 +316,35 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var:
 
         return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
 
+    ########## Statistical ##########
+
+    def _mean(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
+
+    def _sum(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
+        if len(args) == 1:
+            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
+        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
+
+    ########## Search ##########
+
+    def _argmax_argmin(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node):
+            x = self.env[node.args[0]]
+            dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+            keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+            return self.block_builder.emit(op(x, dim, keepdim))
+
+        return convert
+
     ########## Manipulation ##########
 
     def _reshape(self, node: fx.Node) -> relax.Var:
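In both _mean and _argmax_argmin, positional args take precedence and kwargs are the fallback for dim/keepdim. A rough sketch of what these helpers end up emitting for x.mean(dim=1, keepdim=True) and x.argmax(dim=1) (shapes and the int64 result dtype are my reading of the Relax op defaults):

    from tvm import relax

    x = relax.Var("x", relax.TensorStructInfo((4, 8), "float32"))
    mean = relax.op.mean(x, 1, keepdims=True)  # R.Tensor((4, 1), "float32")
    amax = relax.op.argmax(x, axis=1)          # R.Tensor((4,), "int64")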

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 25 additions & 0 deletions

@@ -19,6 +19,7 @@
 # pylint: disable=import-outside-toplevel
 """PyTorch ExportedProgram of Relax."""
 from collections import ChainMap, OrderedDict
+from functools import partial
 from typing import Callable, Dict, List, Tuple
 
 import torch
@@ -76,6 +77,8 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr:
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
+        import operator
+
         return {
             # unary
             "acos.default": self._unary_op(relax.op.acos),
@@ -109,11 +112,33 @@ def create_convert_map(
             "tanh.default": self._unary_op(relax.op.tanh),
             "tril.default": self._tril_triu(relax.op.tril),
             "triu.default": self._tril_triu(relax.op.triu),
+            # binary
+            "add.Tensor": self._binary_op(relax.op.add, operator.add),
+            "div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
+            "eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
+            "eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
+            "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv),
+            "lt.Scalar": self._binary_op(relax.op.less, operator.lt),
+            "lt.Tensor": self._binary_op(relax.op.less, operator.lt),
+            "matmul.default": self._binary_op(
+                partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
+            ),
+            "max.other": self._binary_op(relax.op.maximum, max),
+            "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
+            "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
+            "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
+            "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
             # neural network
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
             "conv2d.default": self._conv2d,
             "linear.default": self._linear,
             "max_pool2d.default": self._max_pool2d,
+            # statistical
+            "mean.dim": self._mean,
+            "sum.dim_IntList": self._sum,
+            # search
+            "argmax.default": self._argmax_argmin(relax.op.argmax),
+            "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
             "view.default": self._reshape,
         }
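The matmul.default entry binds out_dtype up front with functools.partial so the resulting callable takes the same two arguments as every other relax binary op and slots into _binary_op unchanged. A small sketch of the same trick in isolation (shapes and dtypes are illustrative):

    from functools import partial
    from tvm import relax

    matmul_f32 = partial(relax.op.linear_algebra.matmul, out_dtype="float32")
    a = relax.Var("a", relax.TensorStructInfo((4, 8), "float16"))
    b = relax.Var("b", relax.TensorStructInfo((8, 2), "float16"))
    call = matmul_f32(a, b)  # float32-accumulating matmul, built like any binary op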

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 0 additions & 62 deletions

@@ -96,39 +96,6 @@ def convert(node: fx.Node) -> relax.Var:
 
         return convert
 
-    ########## Binary Ops ##########
-
-    def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
-        from torch import fx
-
-        def convert(node: fx.Node) -> relax.Var:
-            def promote_binary_op_args(lhs, rhs):
-                if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
-                    return lhs, rhs
-                elif isinstance(lhs, relax.Expr):
-                    assert isinstance(lhs.struct_info, relax.TensorStructInfo)
-                    return lhs, relax.const(rhs, lhs.struct_info.dtype)
-                elif isinstance(rhs, relax.Expr):
-                    assert isinstance(rhs.struct_info, relax.TensorStructInfo)
-                    return relax.const(lhs, rhs.struct_info.dtype), rhs
-                else:
-                    assert False
-
-            def call_binary_op(op, lhs, rhs):
-                lhs, rhs = promote_binary_op_args(lhs, rhs)
-                return self.block_builder.emit(op(lhs, rhs))
-
-            lhs, rhs = self.retrieve_args(node)
-            if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
-                return call_binary_op(relax_op, lhs, rhs)
-            elif isinstance(lhs, relax.expr.Constant):
-                return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
-            elif isinstance(rhs, relax.expr.Constant):
-                return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
-            return intrinsic_op(lhs, rhs)
-
-        return convert
-
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
@@ -794,35 +761,6 @@ def _unbind(self, node: fx.Node) -> relax.Var:
             ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
         return self.block_builder.emit(relax.Tuple(ret))
 
-    ########## Statistical ##########
-
-    def _mean(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
-        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
-        return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
-
-    def _sum(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
-        if len(args) == 1:
-            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
-        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
-
-    ########## Search ##########
-
-    def _argmax_argmin(self, op: Callable) -> Callable:
-        from torch import fx
-
-        def convert(node: fx.Node):
-            x = self.env[node.args[0]]
-            dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
-            keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
-            return self.block_builder.emit(op(x, dim, keepdim))
-
-        return convert
-
     ########## Manipulation ##########
 
     def _cat(self, node: fx.Node) -> relax.Var:
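These deletions are a pure move: the same helpers now live on the shared base translator in base_fx_graph_translator.py, so the original FX importer keeps its behavior. A sketch of the unchanged FX route, assuming the public from_fx API:

    import torch
    from torch import fx
    from tvm.relax.frontend.torch import from_fx

    class M(torch.nn.Module):
        def forward(self, x):
            return (x + 1.0).sum(dim=-1)  # same _binary_op/_sum helpers, FX route

    mod = from_fx(fx.symbolic_trace(M()), [((2, 3), "float32")])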
