diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py
index c7dc83176e14..8c51b0a52592 100644
--- a/python/mxnet/gluon/nn/activations.py
+++ b/python/mxnet/gluon/nn/activations.py
@@ -153,12 +153,13 @@ class ELU(HybridBlock):
     Outputs:
         - **out**: output tensor with the same shape as `data`.
     """
+
     def __init__(self, alpha=1.0, **kwargs):
         super(ELU, self).__init__(**kwargs)
         self._alpha = alpha
 
     def hybrid_forward(self, F, x):
-        return F.where(x > 0, x, self._alpha * (F.exp(x) - 1.0))
+        return F.LeakyReLU(x, act_type='elu', slope=self._alpha)
 
 
 class SELU(HybridBlock):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 8c60ef6745f1..efa04f4fa47a 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1180,7 +1180,7 @@ def swish_test(x):
     elu = mx.gluon.nn.ELU()
     def elu_test(x):
         def elu(x):
-            return 1.0 * (mx.nd.exp(x) - 1) if x < 0 else x
+            return mx.nd.expm1(x) if x <= 0.0 else x
         return [elu(x_i) for x_i in x]
 
     for test_point, ref_point in zip(elu_test(point_to_validate), elu(point_to_validate)):
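
For reviewers who want to sanity-check the change locally, here is a minimal sketch (not part of the patch; the test values are arbitrary) comparing the fused `LeakyReLU(act_type='elu')` path introduced above against the elementwise definition the test now encodes via `expm1`:

```python
import numpy as np
import mxnet as mx

# Illustrative only, not part of the patch: compare the fused ELU kernel
# against the elementwise definition x if x > 0 else alpha * (exp(x) - 1).
alpha = 1.0
x = mx.nd.array([-2.0, -0.5, 0.0, 0.5, 2.0])  # arbitrary test points

fused = mx.nd.LeakyReLU(x, act_type='elu', slope=alpha)
reference = mx.nd.where(x > 0, x, alpha * mx.nd.expm1(x))

# The two paths should agree up to floating-point tolerance.
assert np.allclose(fused.asnumpy(), reference.asnumpy())
```

As a design note, routing through the fused operator should avoid the intermediate `exp(x)` tensor and boolean mask that the old `F.where` composition materialized, and `expm1(x)` is numerically more accurate than `exp(x) - 1` for small `x`, which is why the test reference was rewritten the same way.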