Commit cc7eb2f

[Relax] [PyTorch] Add support for torch.nn.Hardswish (#17084)
* add hardswish support to fx_frontend
* run ./tests/lint/git-black.sh -i --rev upstream/main
* fix ci lint error
1 parent ab02979 commit cc7eb2f

2 files changed (+47, -0 lines)


python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 11 additions & 0 deletions
@@ -243,6 +243,15 @@ def _gelu(self, node: fx.node.Node) -> relax.Expr:
         else:
             raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate))
 
+    def _hardswish(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        x2 = relax.op.divide(x1, relax.const(6, dtype))
+        return self.block_builder.emit(relax.op.multiply(x, x2))
+
     ########## Compare ##########
 
     def _lt(self, node: fx.node.Node) -> relax.Expr:
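
The converter above builds hardswish(x) = x * clip(x + 3, 0, 6) / 6 out of existing Relax ops. As a quick standalone sanity check (not part of the commit), the same decomposition can be compared against PyTorch's reference implementation, with torch.clamp standing in for relax.op.clip:

```python
# Standalone sanity check (not part of the commit): the decomposition used by
# _hardswish, x * clip(x + 3, 0, 6) / 6, should agree with PyTorch's hardswish.
import torch

x = torch.randn(1, 3, 10, 10)
decomposed = x * torch.clamp(x + 3, 0, 6) / 6
assert torch.allclose(decomposed, torch.nn.functional.hardswish(x), atol=1e-6)
```

The next two hunks register _hardswish for both the nn.Hardswish module and the functional "hardswish" call, so both spellings hit the same converter.
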
@@ -1358,6 +1367,7 @@ def create_convert_map(self):
             nn.Sigmoid: self._sigmoid,
             nn.Tanh: lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
             nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
+            nn.Hardswish: self._hardswish,
             nn.Flatten: self._flatten,
             nn.BatchNorm2d: self._batch_norm_2d,
             nn.LayerNorm: self._layer_norm,

@@ -1437,6 +1447,7 @@ def create_convert_map(self):
             "leaky_relu": self._leakyrelu,
             "gelu": self._gelu,
             "silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
+            "hardswish": self._hardswish,
             "interpolate": self._interpolate,
             "size": self._size,
             "getattr": self._getattr,

tests/python/relax/test_frontend_from_fx.py

Lines changed: 36 additions & 0 deletions
@@ -1416,6 +1416,42 @@ def main(
     verify_model(SiLU2(), input_info, {}, expected1)
 
 
+def test_hardswish():
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Hardswish(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.hs = torch.nn.Hardswish()
+
+        def forward(self, input):
+            return self.hs(input)
+
+    class Hardswish2(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.hardswish(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32"))
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    lv1, R.const(6, "float32")
+                )
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    verify_model(Hardswish(), input_info, {}, expected1)
+    verify_model(Hardswish2(), input_info, {}, expected1)
+
+
 def test_groupnorm():
     import torch
     from torch.nn import Module
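
Beyond the structural comparison done by verify_model, the imported module can also be compiled and executed to check numerics against PyTorch. A hedged end-to-end sketch, assuming an LLVM-enabled TVM build; LegalizeOps, relax.build, and relax.VirtualMachine are the standard Relax entry points, though the exact build pipeline may vary between versions:

```python
import numpy as np
import torch
import torch.fx
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

# Import the traced model (same input shape as the test above).
graph_model = torch.fx.symbolic_trace(torch.nn.Hardswish())
mod = from_fx(graph_model, [([1, 3, 10, 10], "float32")])

# Lower the high-level Relax ops to TIR, build for CPU, and run on the Relax VM.
mod = relax.transform.LegalizeOps()(mod)
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

x = np.random.randn(1, 3, 10, 10).astype("float32")
tvm_out = vm["main"](tvm.nd.array(x)).numpy()
torch_out = torch.nn.functional.hardswish(torch.from_numpy(x)).numpy()
np.testing.assert_allclose(tvm_out, torch_out, rtol=1e-5, atol=1e-5)
```
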
