diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d84993c68d4e..e601f1818101 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -127,6 +127,28 @@ def _clamp(self, node: fx.Node) -> relax.Expr:
             )
         return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
 
+    def _elu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.0)
+        dtype = x.struct_info.dtype
+
+        if isinstance(alpha, (int, float)):
+            alpha = relax.const(alpha, dtype)
+        else:
+            if not isinstance(alpha, relax.Var):
+                alpha = self.block_builder.emit(relax.const(alpha, dtype))
+
+        # ELU(x) = ReLU(x) − α⋅ReLU(1−exp(x))
+        return self.block_builder.emit(
+            relax.op.subtract(
+                relax.op.nn.relu(x),
+                relax.op.multiply(
+                    alpha,
+                    relax.op.nn.relu(relax.op.subtract(relax.const(1, dtype), relax.op.exp(x))),
+                ),
+            )
+        )
+
     def _gelu(self, node: fx.Node) -> relax.Expr:
         approximate = node.kwargs.get("approximate", "none")
         if approximate == "none":
@@ -153,6 +175,13 @@ def _hardswish(self, node: fx.Node) -> relax.Var:
         x2 = relax.op.divide(x1, relax.const(6, dtype))
         return self.block_builder.emit(relax.op.multiply(x, x2))
 
+    def _hardtanh(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        x = args[0]
+        min_val = node.kwargs.get("min_val", -1.0)
+        max_val = node.kwargs.get("max_val", 1.0)
+        return self.block_builder.emit(relax.op.clip(x, min_val, max_val))
+
     def _leakyrelu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index dffe2b60eb31..bbad7c0c704d 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -581,9 +581,11 @@ def create_convert_map(
             ## call_module
             # unary
             nn.Dropout: lambda node: self.env[node.args[0]],
+            nn.ELU: self._elu,
             nn.GELU: self._gelu,
             nn.Hardsigmoid: self._hardsigmoid,
             nn.Hardswish: self._hardswish,
+            nn.Hardtanh: self._hardtanh,
             nn.Identity: lambda node: self.env[node.args[0]],
             nn.LeakyReLU: self._leakyrelu_module,
             nn.LogSoftmax: self._log_softmax_module,
@@ -627,12 +629,14 @@ def create_convert_map(
             "cos": self._unary_op(relax.op.cos),
             "cosh": self._unary_op(relax.op.cosh),
             "dropout": lambda node: self.env[node.args[0]],
+            "elu": self._elu,
             "erf": self._unary_op(relax.op.erf),
             "exp": self._unary_op(relax.op.exp),
             "floor": self._unary_op(relax.op.floor),
             "gelu": self._gelu,
             "hardsigmoid": self._hardsigmoid,
             "hardswish": self._hardswish,
+            "hardtanh": self._hardtanh,
             "isfinite": self._unary_op(relax.op.isfinite),
             "isinf": self._unary_op(relax.op.isinf),
             "isnan": self._unary_op(relax.op.isnan),
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 9e7e1ff2ea0a..797ce05a3fa0 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1989,6 +1989,46 @@ def main(
     verify_model(Dropout1(), input_info, {}, expected_dropout)
     verify_model(Dropout2(), input_info, {}, expected_dropout)
 
+    # elu
+    class Elu(Module):
+        def __init__(self):
+            super().__init__()
+            self.elu = torch.nn.ELU()
+
+        def forward(self, input):
+            return self.elu(input)
+
+    class Elu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.elu(input)
+
+    @tvm.script.ir_module
+    class expected_elu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    R.const(1.0, dtype="float32"), lv_exp
+                )
+                lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(
+                    lv_one_minus_exp
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.0, dtype="float32"), lv_relu_one_minus_exp
+                )
+                lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lv_relu_x, lv_scaled)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_elu
+                R.output(gv)
+            return gv
+
+    verify_model(Elu(), input_info, {}, expected_elu)
+    verify_model(Elu2(), input_info, {}, expected_elu)
+
     # gelu
     class Gelu(Module):
         def __init__(self):
@@ -2086,6 +2126,36 @@ def main(
     verify_model(Hardswish(), input_info, {}, expected_hardswish)
     verify_model(Hardswish2(), input_info, {}, expected_hardswish)
 
+    # hardtanh
+    class Hardtanh(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.ht = torch.nn.Hardtanh()
+
+        def forward(self, input):
+            return self.ht(input)
+
+    class Hardtanh2(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.hardtanh(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0))
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Hardtanh(), input_info, {}, expected1)
+    verify_model(Hardtanh2(), input_info, {}, expected1)
+
     # logical_not
     class LogicalNot(Module):
        def forward(self, input):
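
Not part of the patch: the translator lowers ELU and Hardtanh through element-wise decompositions, ELU(x) = ReLU(x) − α⋅ReLU(1 − exp(x)) and Hardtanh(x) = clip(x, min_val, max_val). The sketch below checks those identities in plain PyTorch against torch.nn.functional, independently of TVM; the helper names (elu_decomposed, hardtanh_decomposed) are invented for this check, and a PyTorch version providing torch.testing.assert_close is assumed.

# Standalone sanity check (not part of the patch): confirm that the element-wise
# decompositions used by _elu and _hardtanh reproduce PyTorch's reference ops.
import torch
import torch.nn.functional as F


def elu_decomposed(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # ELU(x) = ReLU(x) - alpha * ReLU(1 - exp(x))
    return torch.relu(x) - alpha * torch.relu(1.0 - torch.exp(x))


def hardtanh_decomposed(x: torch.Tensor, min_val: float = -1.0, max_val: float = 1.0) -> torch.Tensor:
    # Hardtanh is a plain clip to [min_val, max_val].
    return torch.clamp(x, min_val, max_val)


x = torch.randn(1, 3, 10, 10)
torch.testing.assert_close(elu_decomposed(x), F.elu(x))
torch.testing.assert_close(elu_decomposed(x, alpha=0.5), F.elu(x, alpha=0.5))
torch.testing.assert_close(hardtanh_decomposed(x), F.hardtanh(x))
torch.testing.assert_close(hardtanh_decomposed(x, -2.0, 2.0), F.hardtanh(x, min_val=-2.0, max_val=2.0))
print("ELU and Hardtanh decompositions match PyTorch")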