Commit 64dae9d

[Relax][PyTorch] Support elu, celu, selu ops for ExportedProgram importer (#17738)
* Update exported_program_translator.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py
1 parent 363ebb4 commit 64dae9d

File tree

2 files changed: +126 −0 lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 3 additions & 0 deletions
@@ -192,12 +192,14 @@ def create_convert_map(
             "atanh.default": self._unary_op(relax.op.atanh),
             "bitwise_not.default": self._unary_op(relax.op.bitwise_not),
             "ceil.default": self._unary_op(relax.op.ceil),
+            "celu.default": self._celu,
             "clamp.default": self._clamp,
             "clamp_min.default": self._clamp_min,
             "clamp_max.default": self._clamp_max,
             "cos.default": self._unary_op(relax.op.cos),
             "cosh.default": self._unary_op(relax.op.cosh),
             "dropout.default": lambda node: self.env[node.args[0]],
+            "elu.default": self._elu,
             "erf.default": self._unary_op(relax.op.erf),
             "exp.default": self._unary_op(relax.op.exp),
             "floor.default": self._unary_op(relax.op.floor),
@@ -215,6 +217,7 @@ def create_convert_map(
             "relu.default": self._unary_op(relax.op.nn.relu),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
+            "selu.default": self._selu,
             "sigmoid.default": self._unary_op(relax.op.sigmoid),
             "sign.default": self._unary_op(relax.op.sign),
             "silu.default": self._unary_op(relax.op.nn.silu),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 123 additions & 0 deletions
@@ -126,6 +126,49 @@ def main(
 def test_extended_unary_ops():
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
 
+    # celu
+    class Celu1(Module):
+        def __init__(self):
+            super().__init__()
+            self.celu = torch.nn.CELU()
+
+        def forward(self, input):
+            return self.celu(input)
+
+    class Celu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.celu(input)
+
+    # alpha * min(0, exp(x / alpha) - 1) + max(0, x)
+    @tvm.script.ir_module
+    class expected_celu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    lv, R.const(1.0, "float32")
+                )
+                lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    lv_div, R.const(1.0, "float32")
+                )
+                lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(
+                    R.const(0.0, "float32"), lv_sub
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.0, "float32"), lv_min
+                )
+                lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_celu,)
+                R.output(gv)
+            return gv
+
+    verify_model(Celu1(), example_args, {}, expected_celu)
+    verify_model(Celu2(), example_args, {}, expected_celu)
+
     # clamp
     class Clamp(Module):
         def forward(self, input):
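The comment in `expected_celu` documents the decomposition celu(x) = alpha * min(0, exp(x / alpha) - 1) + max(0, x). As an illustration (not part of the commit), the same identity can be checked numerically against PyTorch's own op:

    import torch

    x = torch.randn(1, 3, 10, 10)
    alpha = 1.0
    # Same decomposition the expected IR encodes, written with torch ops.
    decomposed = alpha * torch.minimum(torch.zeros_like(x), torch.exp(x / alpha) - 1) + torch.relu(x)
    assert torch.allclose(decomposed, torch.nn.functional.celu(x, alpha=alpha), atol=1e-6)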
@@ -226,6 +269,46 @@ def main(
     verify_model(Dropout1(), example_args, {}, expected_dropout)
     verify_model(Dropout2(), example_args, {}, expected_dropout)
 
+    # elu
+    class Elu(Module):
+        def __init__(self):
+            super().__init__()
+            self.elu = torch.nn.ELU()
+
+        def forward(self, input):
+            return self.elu(input)
+
+    class Elu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.elu(input)
+
+    @tvm.script.ir_module
+    class expected_elu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    R.const(1.0, dtype="float32"), lv_exp
+                )
+                lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(
+                    lv_one_minus_exp
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp
+                )
+                lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_elu,)
+                R.output(gv)
+            return gv
+
+    verify_model(Elu(), example_args, {}, expected_elu)
+    verify_model(Elu2(), example_args, {}, expected_elu)
+
     # gelu
     class Gelu(Module):
         def __init__(self):
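`expected_elu` encodes elu(x) as -1 * relu(1 - exp(x)) + relu(x). The same identity, checked standalone in PyTorch (illustrative only, not part of the commit):

    import torch

    x = torch.randn(1, 3, 10, 10)
    # -relu(1 - exp(x)) gives exp(x) - 1 on the negative side, 0 elsewhere.
    decomposed = -1.0 * torch.relu(1.0 - torch.exp(x)) + torch.relu(x)
    assert torch.allclose(decomposed, torch.nn.functional.elu(x), atol=1e-6)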
@@ -358,6 +441,46 @@ def main(
     verify_model(ReLU0(), example_args, {}, expected_relu)
     verify_model(ReLU1(), example_args, {}, expected_relu)
 
+    # selu
+    class Selu1(Module):
+        def __init__(self):
+            super().__init__()
+            self.selu = torch.nn.SELU()
+
+        def forward(self, input):
+            return self.selu(input)
+
+    class Selu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.selu(input)
+
+    @tvm.script.ir_module
+    class expected_selu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    lv_exp, R.const(1.0, "float32")
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.6732631921768188, "float32"), lv_sub
+                )
+                lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled)
+                lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.0507010221481323, "float32"), lv_add
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_selu,)
+                R.output(gv)
+
+            return gv
+
+    verify_model(Selu1(), example_args, {}, expected_selu)
+    verify_model(Selu2(), example_args, {}, expected_selu)
+
     # sigmoid
     class Sigmoid(Module):
         def __init__(self):
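`expected_selu` corresponds to scale * (relu(x) + alpha * (exp(x) - 1)) with the fixed SELU constants alpha ≈ 1.6732632 and scale ≈ 1.0507010. A hypothetical `_selu` handler mirroring that IR, in the same style as the `_elu` sketch above (not the committed code):

    def _selu(self, node):
        # Hypothetical sketch mirroring expected_selu above.
        x = self.env[node.args[0]]
        alpha = relax.const(1.6732631921768188, "float32")
        scale = relax.const(1.0507010221481323, "float32")
        exp_minus_one = relax.op.subtract(relax.op.exp(x), relax.const(1.0, "float32"))
        inner = relax.op.add(relax.op.nn.relu(x), relax.op.multiply(alpha, exp_minus_one))
        return self.block_builder.emit(relax.op.multiply(scale, inner))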
