From a59930cc7a32111696fb9b45894c984a274089da Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 12 Jun 2024 19:21:33 +0900 Subject: [PATCH 1/3] add hardswish support to fx_frontend --- .../tvm/relax/frontend/torch/fx_translator.py | 11 ++++++ tests/python/relax/test_frontend_from_fx.py | 34 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e26e9bc7dc4c..a5efcce27859 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -243,6 +243,15 @@ def _gelu(self, node: fx.node.Node) -> relax.Expr: else: raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) + def _hardswish(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + x0 = relax.op.add(x, relax.const(3, dtype)) + x1 = relax.op.clip(x0, 0, 6) + x2 = relax.op.divide(x1, relax.const(6, dtype)) + return self.block_builder.emit(relax.op.multiply(x, x2)) + ########## Compare ########## def _lt(self, node: fx.node.Node) -> relax.Expr: @@ -1358,6 +1367,7 @@ def create_convert_map(self): nn.Sigmoid: self._sigmoid, nn.Tanh: lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])), nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), + nn.Hardswish: self._hardswish, nn.Flatten: self._flatten, nn.BatchNorm2d: self._batch_norm_2d, nn.LayerNorm: self._layer_norm, @@ -1437,6 +1447,7 @@ def create_convert_map(self): "leaky_relu": self._leakyrelu, "gelu": self._gelu, "silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), + "hardswish": self._hardswish, "interpolate": self._interpolate, "size": self._size, "getattr": self._getattr, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index dfa5cad4a5a7..79ce81f38f79 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1416,6 +1416,40 @@ def main( verify_model(SiLU2(), input_info, {}, expected1) +def test_hardswish(): + input_info = [([1, 3, 10, 10], "float32")] + + class Hardswish(torch.nn.Module): + def __init__(self): + super().__init__() + self.hs = torch.nn.Hardswish() + + def forward(self, input): + return self.hs(input) + + class Hardswish2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardswish(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lv1, R.const(6, "float32")) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv3 + R.output(gv) + return gv + + verify_model(Hardswish(), input_info, {}, expected1) + verify_model(Hardswish2(), input_info, {}, expected1) + + def test_groupnorm(): import torch from torch.nn import Module From e83ba158fd36af1461be479e227bba45d3f53818 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 12 Jun 2024 19:45:24 +0900 Subject: [PATCH 2/3] run ./tests/lint/git-black.sh -i --rev upstream/main --- tests/python/relax/test_frontend_from_fx.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 79ce81f38f79..8eea9114a515 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1440,7 +1440,9 @@ def main( with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) - lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lv1, R.const(6, "float32")) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1, R.const(6, "float32") + ) lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2) gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv3 R.output(gv) @@ -2633,9 +2635,7 @@ def forward(self, input): @tvm.script.ir_module class expected1: @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), @@ -2691,9 +2691,7 @@ def forward(self, input): @tvm.script.ir_module class Expected: @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple( + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), From 6b14541f9e04b05377a5f9a1f13c2c3cf7f4cb4f Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 12 Jun 2024 20:16:46 +0900 Subject: [PATCH 3/3] fix ci lint error --- tests/python/relax/test_frontend_from_fx.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 8eea9114a515..49131b5ff891 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -2635,7 +2635,9 @@ def forward(self, input): @tvm.script.ir_module class expected1: @R.function - def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple( + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), @@ -2691,7 +2693,9 @@ def forward(self, input): @tvm.script.ir_module class Expected: @R.function - def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple( + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"), R.Tensor((1, 1, 10, 10), dtype="float32"),