Add SiLU (apache#19497)
* add silu activation

* Update activations.md

* add activation images

* update activation docs

* add silu test

* add silu test
TFUsers authored and Vidya Sagar Ravipati committed Nov 11, 2020
1 parent 4d1ab2f commit b84ebc0
Showing 7 changed files with 72 additions and 11 deletions.
2 changes: 2 additions & 0 deletions docs/python_docs/python/api/gluon/nn/index.rst
@@ -147,6 +147,8 @@ Advanced Activation Layers
nn.ELU
nn.SELU
nn.Swish
nn.SiLU
nn.GELU

API Reference
-------------
@@ -19,7 +19,7 @@

Deep neural networks are a way to express a nonlinear function with lots of parameters from input data to outputs. The nonlinearities that allow neural networks to capture complex patterns in data are referred to as activation functions. Over the course of the development of neural networks, several nonlinear activation functions have been introduced to make gradient-based deep learning tractable.

If you are looking to answer the question, 'which activation function should I use for my neural network model?', you should probably go with *ReLU*. Unless you're trying to implement something like a gating mechanism, like in LSTMs or GRU cells, then you should opt for sigmoid and/or tanh in those cells. However, if you have a working model architecture and you're trying to improve its performance by swapping out activation functions or treating the activation function as a hyperparameter, then you may want to try hand-designed activations like SELU or a function discovered by reinforcement learning and exhaustive search like Swish. This guide describes these activation functions and others implemented in MXNet in detail.
If you are looking to answer the question, 'which activation function should I use for my neural network model?', you should probably go with *ReLU*, unless you're trying to implement something like a gating mechanism, as in LSTM or GRU cells, in which case you should opt for sigmoid and/or tanh. However, if you have a working model architecture and you're trying to improve its performance by swapping out activation functions or treating the activation function as a hyperparameter, then you may want to try hand-designed activations like SELU, SiLU, or GELU. This guide describes these activation functions and others implemented in MXNet in detail.

## Visualizing Activations
In order to compare the various activation functions and to understand the nuances of their differences we have a snippet of code to plot the activation functions (used in the forward pass) and their gradients (used in the backward pass).
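The plotting helper itself is collapsed in this diff. A minimal sketch of such a helper, assuming `matplotlib` for plotting and `mx.autograd` to obtain the gradient (the tutorial's actual implementation may differ), could look like this:

```{.python .input}
import matplotlib.pyplot as plt
import numpy as np
import mxnet as mx

def visualize_activation(activation_fn):
    # Evaluate the activation on a grid and record the computation
    # so the gradient can be recovered with a backward pass.
    data = np.linspace(-10, 10, 501)
    x = mx.nd.array(data)
    x.attach_grad()
    with mx.autograd.record():
        y = activation_fn(x)
    y.backward()

    plt.plot(data, y.asnumpy(), label='activation')
    plt.plot(data, x.grad.asnumpy(), label='gradient')
    plt.title(type(activation_fn).__name__)
    plt.legend()
    plt.show()
```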
@@ -237,24 +237,40 @@ visualize_activation(mx.gluon.nn.SELU())
![selu activation and gradient](images/selu.png)


### Swish
Swish is an activation function that attempts to address the shortcomings of ReLU by combining ideas from ReLU and sigmoid. Swish was discovered by searching the space of activation functions using a combination of exhaustive and reinforcement learning-based search and was introduced in the paper by [Ramchandran et al](https://arxiv.org/pdf/1710.05941.pdf).
### SiLU
The SiLU is an activation function that attempts to address the shortcomings of ReLU by combining ideas from ReLU and sigmoid. The SiLU serves as a smooth approximation to the ReLU and was originally introduced in [Hendrycks et al](https://arxiv.org/abs/1606.08415).

The swish function is given as
The silu function is given as

$$ swish(x) = x\cdot\sigma(\beta x)$$
$$ silu(x) = x\cdot\sigma(x)$$

where $\sigma$ is the sigmoid activation function $\sigma(x) = \frac{1}{1 + e^{-x}}$ described above and $\beta$ is a hyperparameter set to 1 by default in MXNet.
where $\sigma$ is the sigmoid activation function $\sigma(x) = \frac{1}{1 + e^{-x}}$ described above.
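
Since the plots generated below show the gradient as well as the activation itself, it is worth noting that, by the product rule,

$$ \frac{d}{dx}\,silu(x) = \sigma(x) + x\,\sigma(x)\big(1 - \sigma(x)\big) = \sigma(x)\big(1 + x(1 - \sigma(x))\big). $$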


```{.python .input}
visualize_activation(mx.gluon.nn.Swish())
visualize_activation(mx.gluon.nn.SiLU())
```


![swish activation and gradient](images/swish.png)
![silu activation and gradient](images/silu.png)
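
As a quick sanity check on the identity above, the layer's output can be compared elementwise against `x * sigmoid(x)` (a minimal sketch; `mx` is assumed to be imported as in the earlier cells):

```{.python .input}
x = mx.nd.array([-2.0, -0.5, 0.0, 0.5, 2.0])
silu_layer = mx.gluon.nn.SiLU()
# The difference should be (numerically) zero at every point.
print(silu_layer(x) - x * mx.nd.sigmoid(x))
```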

### GELU
The GELU is a smooth approximation to the ReLU and was introduced in [Hendrycks et al](https://arxiv.org/abs/1606.08415). It is a common activation function in architectures such as Transformers, BERT, and GPT.

The gelu function is given as

$$ gelu(x) = x\cdot\Phi(x),$$

whereas the ReLU can be written as $x\cdot\mathbf{1}(x>0)$, so $\Phi(x)$ serves as a smooth approximation to the ReLU's indicator function.

Note $\Phi(x) = \int_{-\infty}^{x} \frac{1}{\sqrt{2 \pi}} \exp\left\{-\frac{t^2}{2}\right\} dt$ is the standard normal cumulative distribution function.
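
Its gradient again follows from the product rule, where $\phi(x) = \frac{1}{\sqrt{2 \pi}} \exp\left\{-\frac{x^2}{2}\right\}$ denotes the standard normal density:

$$ \frac{d}{dx}\,gelu(x) = \Phi(x) + x\,\phi(x). $$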


```{.python .input}
visualize_activation(mx.gluon.nn.GELU())
```

![gelu activation and gradient](images/gelu.png)
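
As with the SiLU above, a quick numeric check can relate the layer to its closed form. The sketch below assumes MXNet's `nn.GELU` implements the exact erf-based form rather than the tanh approximation:

```{.python .input}
x = mx.nd.array([-2.0, -0.5, 0.0, 0.5, 2.0])
gelu_layer = mx.gluon.nn.GELU()
# 0.5 * x * (1 + erf(x / sqrt(2))); the difference should be close to zero.
reference = 0.5 * x * (1 + mx.nd.erf(x / 2 ** 0.5))
print(gelu_layer(x) - reference)
```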

## Summary

@@ -263,7 +279,7 @@ visualize_activation(mx.gluon.nn.Swish())
* Sigmoids like the logistic (sigmoid) function and tanh were the first kinds of activation functions used in neural networks. They have since fallen out of use because of their tendency to saturate and their vanishing gradients.
* Rectifiers like ReLU do not saturate like the Sigmoids and so address the vanishing gradient problem, making them the de facto activation functions. ReLU, however, is still plagued by the dying ReLU problem.
* LeakyReLU and PReLU are two similar approaches to improve ReLU and address the dying ReLU problem by introducing a parameter $\alpha$ (learned in PReLU) that lets a small gradient through for negative inputs.
* MXNet also implements custom state-of-the-art activations like ELU, SELU and Swish.
* MXNet also implements custom state-of-the-art activations like ELU, SELU, SiLU, and GELU.



(Two binary image files are added in this commit: the activation/gradient plots referenced above as `images/silu.png` and `images/gelu.png`. They cannot be displayed in the diff view.)
33 changes: 31 additions & 2 deletions python/mxnet/gluon/nn/activations.py
@@ -18,7 +18,7 @@
# coding: utf-8
# pylint: disable= arguments-differ
"""Basic neural network layers."""
__all__ = ['Activation', 'LeakyReLU', 'PReLU', 'ELU', 'SELU', 'Swish', 'GELU']
__all__ = ['Activation', 'LeakyReLU', 'PReLU', 'ELU', 'SELU', 'Swish', 'GELU', 'SiLU']

from ... import initializer
from ..block import HybridBlock
@@ -215,7 +215,7 @@ def hybrid_forward(self, F, x):

class Swish(HybridBlock):
    r"""
    Swish Activation function
    Swish Activation function (SiLU with a hyperparameter)
        https://arxiv.org/pdf/1710.05941.pdf
    Parameters
@@ -240,3 +240,32 @@ def hybrid_forward(self, F, x):
            return x * F.npx.sigmoid(self._beta * x)
        else:
            return x * F.sigmoid(self._beta * x, name='fwd')


class SiLU(HybridBlock):
    r"""
    Sigmoid Linear Unit (SiLU) activation function

        silu(x) = x * sigmoid(x)

    Originally proposed in "Gaussian Error Linear Units (GELUs)",
    Hendrycks et al., 2016, https://arxiv.org/abs/1606.08415

    Inputs:
        - **data**: input tensor with arbitrary shape.

    Outputs:
        - **out**: output tensor with the same shape as `data`.
    """

    def __init__(self, **kwargs):
        super(SiLU, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        # silu(x) = x * sigmoid(x); use the NumPy-style op when numpy
        # semantics are active, otherwise the legacy NDArray/Symbol op.
        if is_np_array():
            return x * F.npx.sigmoid(x)
        else:
            return x * F.sigmoid(x, name='fwd')
7 changes: 7 additions & 0 deletions tests/python/unittest/test_gluon.py
@@ -1302,6 +1302,13 @@ def swish_test(x):
    for test_point, ref_point in zip(swish_test(point_to_validate), swish(point_to_validate)):
        assert test_point == ref_point

    silu = mx.gluon.nn.SiLU()
    def silu_test(x):
        return x * mx.nd.sigmoid(x)

    for test_point, ref_point in zip(silu_test(point_to_validate), silu(point_to_validate)):
        assert test_point == ref_point

    elu = mx.gluon.nn.ELU()
    def elu_test(x):
        def elu(x):
7 changes: 7 additions & 0 deletions tests/python/unittest/test_numpy_gluon.py
@@ -483,6 +483,13 @@ def test_activations_swish():
    out = act_layer(mx.np.random.uniform(size=(10,)))
    out.asnumpy()


@use_np
def test_activations_silu():
    act_layer = nn.SiLU()
    out = act_layer(mx.np.random.uniform(size=(10,)))
    out.asnumpy()

@use_np
def test_concatenate():
    model = nn.HybridConcatenate(axis=1)