From 089b90cb392facb8f3942486bcfa12e0291ae83b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 10 Aug 2021 12:44:29 +0900 Subject: [PATCH] [Torch] Fix ELU conversion --- python/tvm/relay/frontend/pytorch.py | 2 +- tests/python/frontend/pytorch/test_forward.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index da81280d76a3..ca24335d4fd0 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -771,7 +771,7 @@ def leaky_relu(self, inputs, input_types): def elu(self, inputs, input_types): data = inputs[0] dtype = input_types[0] - alpha = _expr.const(float(inputs[1]), dtype=dtype) + alpha = _expr.const(-float(inputs[1]), dtype=dtype) return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data) def celu(self, inputs, input_types): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2e6828f693b6..c924e73a9034 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -666,7 +666,7 @@ def test_forward_leakyrelu(): def test_forward_elu(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] - input_data = torch.rand(input_shape).float() + input_data = torch.randn(input_shape).float() verify_model(torch.nn.ELU().eval(), input_data=input_data) verify_model(torch.nn.ELU(alpha=0.3).eval(), input_data=input_data) verify_model(torch.nn.ELU(alpha=1.0).eval(), input_data=input_data)