
Commit 12e72c2

[Relax][PyTorch] Add support for elu, hardtanh ops (#17694)
* Update fx_translator.py
* Update test_frontend_from_fx.py
* Update base_fx_graph_translator.py
* Update fx_translator.py
* Update test_frontend_from_fx.py
* Update fx_translator.py
* Update fx_translator.py
* lint
* lint
* lint
1 parent 5fc254e commit 12e72c2

File tree: 3 files changed (+103, −0 lines)
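As context for the diffs below, here is a minimal sketch (assumed usage, mirroring the tests' input_info convention rather than anything in this commit) of driving the two newly supported ops through the Relax FX frontend:

# Assumed usage sketch: trace a small model containing the newly supported
# ops and convert it with the Relax FX frontend.
from torch import fx, nn
from tvm.relax.frontend.torch import from_fx

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.elu = nn.ELU()
        self.hardtanh = nn.Hardtanh()

    def forward(self, x):
        return self.hardtanh(self.elu(x))

graph_module = fx.symbolic_trace(Model())
# input_info is a list of (shape, dtype) pairs, as in the tests below.
mod = from_fx(graph_module, [([1, 3, 10, 10], "float32")])
mod.show()  # print the converted Relax IRModule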

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 29 additions & 0 deletions
@@ -127,6 +127,28 @@ def _clamp(self, node: fx.Node) -> relax.Expr:
         )
         return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))

+    def _elu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.0)
+        dtype = x.struct_info.dtype
+
+        if isinstance(alpha, (int, float)):
+            alpha = relax.const(alpha, dtype)
+        else:
+            if not isinstance(alpha, relax.Var):
+                alpha = self.block_builder.emit(relax.const(alpha, dtype))
+
+        # α⋅ReLU(1−exp(x))+ReLU(x)
+        return self.block_builder.emit(
+            relax.op.add(
+                relax.op.multiply(
+                    alpha,
+                    relax.op.nn.relu(relax.op.subtract(relax.const(1, dtype), relax.op.exp(x))),
+                ),
+                relax.op.nn.relu(x),
+            )
+        )
+
     def _gelu(self, node: fx.Node) -> relax.Expr:
         approximate = node.kwargs.get("approximate", "none")
         if approximate == "none":
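For reference, torch.nn.ELU computes x for x > 0 and alpha * (exp(x) - 1) otherwise; a plain NumPy sketch of that textbook definition (for comparison only, separate from the Relax lowering above):

# Reference definition of ELU (PyTorch semantics), independent of the commit.
import numpy as np

def elu_reference(x: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))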
@@ -153,6 +175,13 @@ def _hardswish(self, node: fx.Node) -> relax.Var:
         x2 = relax.op.divide(x1, relax.const(6, dtype))
         return self.block_builder.emit(relax.op.multiply(x, x2))

+    def _hardtanh(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        x = args[0]
+        min_val = node.kwargs.get("min_val", -1.0)
+        max_val = node.kwargs.get("max_val", 1.0)
+        return self.block_builder.emit(relax.op.clip(x, min_val, max_val))
+
     def _leakyrelu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
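Since the default nn.Hardtanh simply clamps its input to [-1, 1], lowering it to relax.op.clip is a direct mapping. A quick PyTorch-side sanity check of that equivalence (illustrative only, not part of the commit):

# Illustrative check: default hardtanh is a clamp to [-1, 1].
import torch

x = torch.randn(1, 3, 10, 10)
assert torch.allclose(
    torch.nn.functional.hardtanh(x),
    torch.clamp(x, min=-1.0, max=1.0),
)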

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 4 additions & 0 deletions
@@ -581,9 +581,11 @@ def create_convert_map(
             ## call_module
             # unary
             nn.Dropout: lambda node: self.env[node.args[0]],
+            nn.ELU: self._elu,
             nn.GELU: self._gelu,
             nn.Hardsigmoid: self._hardsigmoid,
             nn.Hardswish: self._hardswish,
+            nn.Hardtanh: self._hardtanh,
             nn.Identity: lambda node: self.env[node.args[0]],
             nn.LeakyReLU: self._leakyrelu_module,
             nn.LogSoftmax: self._log_softmax_module,
@@ -627,12 +629,14 @@ def create_convert_map(
             "cos": self._unary_op(relax.op.cos),
             "cosh": self._unary_op(relax.op.cosh),
             "dropout": lambda node: self.env[node.args[0]],
+            "elu": self._elu,
             "erf": self._unary_op(relax.op.erf),
             "exp": self._unary_op(relax.op.exp),
             "floor": self._unary_op(relax.op.floor),
             "gelu": self._gelu,
             "hardsigmoid": self._hardsigmoid,
             "hardswish": self._hardswish,
+            "hardtanh": self._hardtanh,
             "isfinite": self._unary_op(relax.op.isfinite),
             "isinf": self._unary_op(relax.op.isinf),
             "isnan": self._unary_op(relax.op.isnan),

tests/python/relax/test_frontend_from_fx.py

Lines changed: 70 additions & 0 deletions
@@ -1989,6 +1989,46 @@ def main(
     verify_model(Dropout1(), input_info, {}, expected_dropout)
     verify_model(Dropout2(), input_info, {}, expected_dropout)

+    # elu
+    class Elu(Module):
+        def __init__(self):
+            super().__init__()
+            self.elu = torch.nn.ELU()
+
+        def forward(self, input):
+            return self.elu(input)
+
+    class Elu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.elu(input)
+
+    @tvm.script.ir_module
+    class expected_elu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    R.const(1.0, dtype="float32"), lv_exp
+                )
+                lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(
+                    lv_one_minus_exp
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.0, dtype="float32"), lv_relu_one_minus_exp
+                )
+                lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_elu
+                R.output(gv)
+            return gv
+
+    verify_model(Elu(), input_info, {}, expected_elu)
+    verify_model(Elu2(), input_info, {}, expected_elu)
+
     # gelu
     class Gelu(Module):
         def __init__(self):
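For reference, a hedged sketch of the kind of check verify_model presumably performs here (the helper itself is defined elsewhere in this test file): convert the traced module with from_fx and compare the result structurally against the expected TVMScript module.

# Hedged sketch only; assumes Elu and expected_elu as defined above and that
# verify_model reduces to a structural comparison of IRModules.
from torch import fx
import tvm
from tvm.relax.frontend.torch import from_fx

graph_module = fx.symbolic_trace(Elu())
converted = from_fx(graph_module, [([1, 3, 10, 10], "float32")])
tvm.ir.assert_structural_equal(converted, expected_elu, map_free_vars=True)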
@@ -2086,6 +2126,36 @@ def main(
     verify_model(Hardswish(), input_info, {}, expected_hardswish)
     verify_model(Hardswish2(), input_info, {}, expected_hardswish)

+    # hardtanh
+    class Hardtanh(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.ht = torch.nn.Hardtanh()
+
+        def forward(self, input):
+            return self.ht(input)
+
+    class Hardtanh2(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.hardtanh(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0))
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Hardtanh(), input_info, {}, expected1)
+    verify_model(Hardtanh2(), input_info, {}, expected1)
+
     # logical_not
     class LogicalNot(Module):
         def forward(self, input):
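To exercise the new cases locally, one option (assuming a TVM development checkout with pytest installed) is to run the whole frontend test file; the elu and hardtanh cases above are part of it:

# Assumed local workflow, run from the repository root.
import pytest

pytest.main(["tests/python/relax/test_frontend_from_fx.py", "-v"])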
