From a0b57b0d1f021653be39efe7b3f2d9e5ff5bbf33 Mon Sep 17 00:00:00 2001
From: 夏鲁豫
Date: Tue, 23 Apr 2019 13:58:42 +0800
Subject: [PATCH] [BUGFIX] Fix NaN when computing the gradient of the ELU
 function (#14673)

* fix ELU
* fix
* fix
* fix
* fix
* fix
---
 python/mxnet/gluon/nn/activations.py | 3 ++-
 tests/python/unittest/test_gluon.py  | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py
index c7dc83176e14..8c51b0a52592 100644
--- a/python/mxnet/gluon/nn/activations.py
+++ b/python/mxnet/gluon/nn/activations.py
@@ -153,12 +153,13 @@ class ELU(HybridBlock):
     Outputs:
         - **out**: output tensor with the same shape as `data`.
     """
+
     def __init__(self, alpha=1.0, **kwargs):
         super(ELU, self).__init__(**kwargs)
         self._alpha = alpha
 
     def hybrid_forward(self, F, x):
-        return F.where(x > 0, x, self._alpha * (F.exp(x) - 1.0))
+        return F.LeakyReLU(x, act_type='elu', slope=self._alpha)
 
 
 class SELU(HybridBlock):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 8c60ef6745f1..efa04f4fa47a 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1180,7 +1180,7 @@ def swish_test(x):
     elu = mx.gluon.nn.ELU()
     def elu_test(x):
         def elu(x):
-            return 1.0 * (mx.nd.exp(x) - 1) if x < 0 else x
+            return mx.nd.expm1(x) if x <= 0.0 else x
        return [elu(x_i) for x_i in x]
     for test_point, ref_point in zip(elu_test(point_to_validate), elu(point_to_validate)):
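
For context, here is a minimal sketch of why the old `F.where`-based implementation produced NaN gradients and why the fused operator does not. It assumes the MXNet 1.x imperative `mx.nd`/`autograd` API, and the input value 100.0 is an illustrative choice, not taken from the patch: the backward pass of `where` propagates gradients through both branches, so `exp(x)` overflows to `inf` for a large positive `x` and the masked product `0 * inf` yields `NaN`.

# A minimal sketch (assumption: MXNet 1.x imperative API) reproducing
# the NaN gradient this patch fixes, then checking the fused operator.
import mxnet as mx
from mxnet import autograd, nd

x = nd.array([100.0])   # large positive input; exp(100.) overflows float32
x.attach_grad()

# Old implementation: F.where(x > 0, x, alpha * (F.exp(x) - 1.0)).
# The backward pass of `where` evaluates the gradient of *both*
# branches, so exp(x) -> inf on the positive side and the masked
# product 0 * inf yields NaN.
with autograd.record():
    y = nd.where(x > 0, x, 1.0 * (nd.exp(x) - 1.0))
y.backward()
print(x.grad)           # [nan]

# Fixed implementation: the fused ELU kernel never evaluates exp(x)
# for the positive branch, so the gradient stays finite.
x.grad[:] = 0
with autograd.record():
    y = nd.LeakyReLU(x, act_type='elu', slope=1.0)
y.backward()
print(x.grad)           # [1.]

The test-side change follows the same logic: `mx.nd.expm1(x)` computes `exp(x) - 1` in a numerically stable way, and evaluating it only when `x <= 0.0` keeps the reference implementation free of the overflow that the fused kernel avoids.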