Commit 176d01e

[Relax][PyTorch] Support more unary ops for ExportedProgram importer (#17421)
* support more unary ops
* support clamp
* support gelu
* support hardsigmoid
* support hardswish
* support hardtanh
* support leaky_relu
* support log_softmax
* support round
* support softmax
* support tril and triu
* skip flaky test
1 parent 42ff98b commit 176d01e

5 files changed: +812 −80 lines changed
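The unary ops listed in the commit message become available when a torch.export graph is imported into Relax. The snippet below is a minimal usage sketch, not part of this commit, assuming the public entry point tvm.relax.frontend.torch.from_exported_program and a PyTorch build that provides torch.export:

import torch
from torch.export import export

from tvm.relax.frontend.torch import from_exported_program


class UnaryOps(torch.nn.Module):
    def forward(self, x):
        # A few of the ops this commit teaches the ExportedProgram importer.
        x = torch.clamp(x, min=-1.0, max=1.0)
        x = torch.nn.functional.gelu(x)
        x = torch.nn.functional.hardswish(x)
        return torch.tril(x)


# Export the module and translate it to a Relax IRModule.
exported_program = export(UnaryOps(), (torch.randn(4, 4),))
mod = from_exported_program(exported_program)
mod.show()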

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 74 additions & 0 deletions
@@ -111,6 +111,80 @@ def convert(node: fx.Node) -> relax.Var:
 
         return convert
 
+    def _clamp(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
+        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
+        if not isinstance(a_min, (int, float)):
+            raise ValueError(
+                f"TVM only supports constant min value for torch.clamp/clip, "
+                f"but got {a_min} with type {type(a_min)}"
+            )
+        if not isinstance(a_max, (int, float)):
+            raise ValueError(
+                f"TVM only supports constant max value for torch.clamp/clip, "
+                f"but got {a_max} with type {type(a_max)}"
+            )
+        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+
+    def _gelu(self, node: fx.Node) -> relax.Expr:
+        approximate = node.kwargs.get("approximate", "none")
+        if approximate == "none":
+            return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
+        elif approximate == "tanh":
+            return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
+        else:
+            raise KeyError("Unrecognized approximate algorithm for gelu: {}.".format(approximate))
+
+    def _hardsigmoid(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))
+
+    def _hardswish(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        x2 = relax.op.divide(x1, relax.const(6, dtype))
+        return self.block_builder.emit(relax.op.multiply(x, x2))
+
+    def _leakyrelu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
+        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
+
+    def _log_softmax(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
+        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
+
+    def _round(self, node: fx.Node) -> relax.Expr:
+        if node.kwargs.get("decimals", 0) != 0:
+            raise ValueError("specifying decimals for round is not supported yet")
+        arg = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.round(arg))
+
+    def _softmax(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
+        return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+
+    def _tril_triu(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            x = self.env[node.args[0]]
+            k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0)
+            assert isinstance(k, int)
+            return self.block_builder.emit(op(x, k))
+
+        return convert
+
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
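Note that _hardsigmoid and _hardswish above are emitted as compositions of existing Relax ops (add, clip, divide, multiply) rather than as dedicated kernels. As a standalone sanity check, not part of this patch, plain NumPy and PyTorch can confirm that this decomposition matches PyTorch's reference definitions:

import numpy as np
import torch

x = torch.randn(1024)

# hardsigmoid(x) = clip(x + 3, 0, 6) / 6, mirroring the graph built in _hardsigmoid.
hard_sigmoid = np.clip(x.numpy() + 3.0, 0.0, 6.0) / 6.0
# hardswish(x) = x * hardsigmoid(x), mirroring the graph built in _hardswish.
hard_swish = x.numpy() * hard_sigmoid

np.testing.assert_allclose(
    hard_sigmoid, torch.nn.functional.hardsigmoid(x).numpy(), rtol=1e-6, atol=1e-6
)
np.testing.assert_allclose(
    hard_swish, torch.nn.functional.hardswish(x).numpy(), rtol=1e-6, atol=1e-6
)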

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 38 additions & 0 deletions
@@ -64,13 +64,51 @@ def create_input_vars(
 
         return parameters_buffers_constants, user_inputs
 
+    ########## Unary Ops ##########
+
+    def _hardtanh(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        x = args[0]
+        min_val = node.args[1] if len(args) > 1 else node.kwargs.get("min_val", -1.0)
+        max_val = node.args[2] if len(args) > 2 else node.kwargs.get("max_val", 1.0)
+        return self.block_builder.emit(relax.op.clip(x, min_val, max_val))
+
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
         return {
             # unary
+            "acos.default": self._unary_op(relax.op.acos),
+            "acosh.default": self._unary_op(relax.op.acosh),
+            "asin.default": self._unary_op(relax.op.asin),
+            "asinh.default": self._unary_op(relax.op.asinh),
+            "atan.default": self._unary_op(relax.op.atan),
+            "atanh.default": self._unary_op(relax.op.atanh),
+            "clamp.default": self._clamp,
+            "cos.default": self._unary_op(relax.op.cos),
+            "cosh.default": self._unary_op(relax.op.cosh),
             "dropout.default": lambda node: self.env[node.args[0]],
+            "exp.default": self._unary_op(relax.op.exp),
+            "gelu.default": self._gelu,
+            "hardsigmoid.default": self._hardsigmoid,
+            "hardswish.default": self._hardswish,
+            "hardtanh.default": self._hardtanh,
+            "leaky_relu.default": self._leakyrelu,
+            "log_softmax.int": self._log_softmax,
+            "neg.default": self._unary_op(relax.op.negative),
             "relu.default": self._unary_op(relax.op.nn.relu),
+            "round.default": self._round,
+            "rsqrt.default": self._unary_op(relax.op.rsqrt),
+            "sigmoid.default": self._unary_op(relax.op.sigmoid),
+            "silu.default": self._unary_op(relax.op.nn.silu),
+            "sin.default": self._unary_op(relax.op.sin),
+            "sinh.default": self._unary_op(relax.op.sinh),
+            "softmax.int": self._softmax,
+            "sqrt.default": self._unary_op(relax.op.sqrt),
+            "tan.default": self._unary_op(relax.op.tan),
+            "tanh.default": self._unary_op(relax.op.tanh),
+            "tril.default": self._tril_triu(relax.op.tril),
+            "triu.default": self._tril_triu(relax.op.triu),
             # neural network
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
             "conv2d.default": self._conv2d,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 0 additions & 74 deletions
@@ -62,82 +62,19 @@ def _fetch_attr(self, model, target: str):
 
     ########## Unary Ops ##########
 
-    def _clamp(self, node: fx.Node) -> relax.Expr:
-        args = self.retrieve_args(node)
-        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
-        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
-        if not isinstance(a_min, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant min value for torch.clamp/clip, "
-                f"but got {a_min} with type {type(a_min)}"
-            )
-        if not isinstance(a_max, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant max value for torch.clamp/clip, "
-                f"but got {a_max} with type {type(a_max)}"
-            )
-        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
-
-    def _gelu(self, node: fx.Node) -> relax.Expr:
-        approximate = node.kwargs.get("approximate", "none")
-        if approximate == "none":
-            return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
-        elif approximate == "tanh":
-            return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
-        else:
-            raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate))
-
-    def _hardsigmoid(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        dtype = x.struct_info.dtype
-        x0 = relax.op.add(x, relax.const(3, dtype))
-        x1 = relax.op.clip(x0, 0, 6)
-        return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))
-
-    def _hardswish(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        dtype = x.struct_info.dtype
-        x0 = relax.op.add(x, relax.const(3, dtype))
-        x1 = relax.op.clip(x0, 0, 6)
-        x2 = relax.op.divide(x1, relax.const(6, dtype))
-        return self.block_builder.emit(relax.op.multiply(x, x2))
-
-    def _leakyrelu(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
-        return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
-
     def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         alpha = module.negative_slope
         return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
 
-    def _log_softmax(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
-        return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
-
     def _log_softmax_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         dim = module.dim
         assert dim is not None
         return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
 
-    def _round(self, node: fx.Node) -> relax.Expr:
-        if node.kwargs.get("decimals", 0) != 0:
-            raise ValueError("specifying decimals for round is not supported yet")
-        arg = self.env[node.args[0]]
-        return self.block_builder.emit(relax.op.round(arg))
-
-    def _softmax(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
-        return self.block_builder.emit(relax.op.nn.softmax(x, dim))
-
     def _softmax_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]

@@ -159,17 +96,6 @@ def convert(node: fx.Node) -> relax.Var:
 
         return convert
 
-    def _tril_triu(self, op: Callable) -> Callable:
-        from torch import fx
-
-        def convert(node: fx.Node) -> relax.Var:
-            x = self.env[node.args[0]]
-            k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0)
-            assert isinstance(k, int)
-            return self.block_builder.emit(op(x, k))
-
-        return convert
-
     ########## Binary Ops ##########
 
     def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
