Fix GELU backward possible NaN (#14782)
* fix gelu with erf functions

* fix possible NaN in GELU backward
haojin2 authored and eric-haibin-lin committed Apr 25, 2019
1 parent 587d480 commit 22377ed
Showing 2 changed files with 10 additions and 21 deletions.
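
In outline, the fix works as follows (restating what the diff below implements): the old GELU kernel used the tanh approximation and wrote its backward pass in terms of the forward output b = gelu(a), which introduces a b/a term. At a = 0 that term is 0/0 and yields NaN, even though the true derivative there is 0.5. The new kernels use the exact erf formulation, whose derivative involves no division by a:

\mathrm{gelu}_{\mathrm{old}}(a) = \tfrac{1}{2}\,a\,\bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(a + 0.044715\,a^{3})\bigr)\bigr), \qquad
\frac{d\,\mathrm{gelu}_{\mathrm{old}}}{da} = \frac{b}{a} + b\,\bigl(1 - \tanh(\cdot)\bigr)\,\sqrt{2/\pi}\,\bigl(1 + 3 \cdot 0.044715\,a^{2}\bigr), \quad b = \mathrm{gelu}_{\mathrm{old}}(a)

\mathrm{gelu}_{\mathrm{new}}(a) = \tfrac{1}{2}\,a\,\bigl(1 + \operatorname{erf}(a/\sqrt{2})\bigr), \qquad
\frac{d\,\mathrm{gelu}_{\mathrm{new}}}{da} = \tfrac{1}{2}\,\bigl(1 + \operatorname{erf}(a/\sqrt{2})\bigr) + \frac{a}{\sqrt{2\pi}}\,e^{-a^{2}/2}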
29 changes: 9 additions & 20 deletions src/operator/mshadow_op.h
@@ -45,14 +45,12 @@ namespace mshadow_op {
 __constant__ const float PI = 3.14159265358979323846;
 __constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
 __constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
-__constant__ const float GELU_CUBIC_CONSTANT = 0.044715;
-__constant__ const float GELU_ROOT_2_OVER_PI = 0.7978845608028654;
+__constant__ const float SQRT_2 = 1.4142135623730950488016887242096;
 #else
 const float PI = 3.14159265358979323846;
 const float SELU_ALPHA = 1.6732632423543772848170429916717;
 const float SELU_LAMBDA = 1.0507009873554804934193349852946;
-const float GELU_CUBIC_CONSTANT = 0.044715;
-const float GELU_ROOT_2_OVER_PI = 0.7978845608028654;
+const float SQRT_2 = 1.4142135623730950488016887242096;
 using std::isnan;
 #endif
 using std::enable_if;
@@ -131,21 +129,6 @@ MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

 MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));

-#define MXNET_GELU_GX(a) \
-  a * (DType(1.0f) + DType(GELU_CUBIC_CONSTANT) * a * a)
-
-#define MXNET_GELU_GX_GRAD(a) \
-  (DType(1.0f) + DType(3.0f * GELU_CUBIC_CONSTANT) * a * a)
-
-#define MXNET_GELU_TANH(a) \
-  math::tanh(DType(GELU_ROOT_2_OVER_PI) * MXNET_GELU_GX(a))
-
-MXNET_UNARY_MATH_OP(gelu, DType(0.5f) * a * (DType(1.0f) + MXNET_GELU_TANH(a)));
-
-MXNET_BINARY_MATH_OP_NC(gelu_grad,
-                        b / a + b * (DType(1.0f) - MXNET_GELU_TANH(a)) *
-                        DType(GELU_ROOT_2_OVER_PI) * MXNET_GELU_GX_GRAD(a));
-
 MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
                        (a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));

@@ -191,6 +174,13 @@ MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a)));

 MXNET_SIMPLE_UNARY_MATH_OP(erf);

+MXNET_UNARY_MATH_OP(gelu,
+                    DType(0.5f * static_cast<float>(a) * (1.0f + math::erf(static_cast<float>(a) / SQRT_2))));
+
+MXNET_BINARY_MATH_OP_NC(gelu_grad,
+                        DType(0.5f * (1.0f + math::erf(static_cast<float>(a) / SQRT_2) +
+                        static_cast<float>(a) * erf_grad::Map(static_cast<float>(a) / SQRT_2) / SQRT_2)));
+
 MXNET_SIMPLE_UNARY_MATH_OP(exp);

 MXNET_SIMPLE_UNARY_MATH_OP(expm1);
@@ -355,7 +345,6 @@ MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));
 MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));

 MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
-
 MXNET_UNARY_MATH_OP(reciprocal_square_root, 1.0f / math::sqrt(a));

 MXNET_UNARY_MATH_OP(reciprocal_square_root_grad, -0.5f / (math::sqrt(a) * math::id(a)));
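
To see the NaN concretely, here is a small standalone NumPy/SciPy sketch (not part of the commit; the helper names gelu_tanh, gelu_tanh_grad and gelu_erf_grad are illustrative) that evaluates both backward formulas, including at x = 0:

import numpy as np
from scipy.special import erf

GELU_CUBIC_CONSTANT = 0.044715
GELU_ROOT_2_OVER_PI = 0.7978845608028654   # sqrt(2 / pi)
SQRT_2 = 1.4142135623730951

def gelu_tanh(x):
    # Old forward: tanh approximation of GELU.
    return 0.5 * x * (1.0 + np.tanh(GELU_ROOT_2_OVER_PI * x * (1.0 + GELU_CUBIC_CONSTANT * x * x)))

def gelu_tanh_grad(x):
    # Old backward, written in terms of the output y = gelu_tanh(x); y / x is 0/0 at x = 0.
    y = gelu_tanh(x)
    t = np.tanh(GELU_ROOT_2_OVER_PI * x * (1.0 + GELU_CUBIC_CONSTANT * x * x))
    return y / x + y * (1.0 - t) * GELU_ROOT_2_OVER_PI * (1.0 + 3.0 * GELU_CUBIC_CONSTANT * x * x)

def erf_grad(x):
    # d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2), as in mshadow_op::erf_grad.
    return 2.0 / np.sqrt(np.pi) * np.exp(-(x * x))

def gelu_erf_grad(x):
    # New backward: no division by x, so it stays finite (0.5 at x = 0).
    return 0.5 * (1.0 + erf(x / SQRT_2) + x * erf_grad(x / SQRT_2) / SQRT_2)

x = np.array([-1.0, 0.0, 1.0])
with np.errstate(invalid="ignore"):      # silence the 0/0 warning from the old formula
    print(gelu_tanh_grad(x))             # approx [-0.08  nan  1.08]  -- NaN at x = 0
print(gelu_erf_grad(x))                  # approx [-0.08  0.5  1.08]  -- finite at x = 0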
2 changes: 1 addition & 1 deletion tests/python/unittest/test_operator.py
@@ -886,7 +886,7 @@ def fgelu_grad(grad, x, y):
     y = mx.sym.LeakyReLU(data=x, act_type="gelu")
     for dtype in [np.float16, np.float32, np.float64]:
         xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype)
-        eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4)
+        eps, rtol, atol = (7.5e-4, 2e-2, 1e-3) if dtype is np.float16 else (1e-4, 1e-3, 1e-5)
         if dtype is np.float16:
             xa /= 10.0
         xa[abs(xa) < eps] = 0.01
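
The tolerances above feed MXNet's numeric-gradient check. Conceptually, that check compares the analytic erf-based gradient against a central finite difference; the sketch below mimics it under the updated float64 tolerances (a standalone illustration, not the actual check_numeric_gradient code):

import numpy as np
from scipy.special import erf

SQRT_2 = np.sqrt(2.0)

def gelu(x):
    # Exact GELU, matching the new kernel.
    return 0.5 * x * (1.0 + erf(x / SQRT_2))

def gelu_grad(x):
    # Analytic derivative of the exact GELU.
    return 0.5 * (1.0 + erf(x / SQRT_2)) + x * np.exp(-0.5 * x * x) / np.sqrt(2.0 * np.pi)

def numeric_grad(f, x, eps):
    # Central finite difference, the estimate the analytic gradient is compared against.
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

eps, rtol, atol = 1e-4, 1e-3, 1e-5                 # float64 values from the updated test line
xa = np.random.uniform(low=-0.1, high=0.1, size=(3, 4))
xa[np.abs(xa) < eps] = 0.01                        # keep samples away from the step size, as the test does
np.testing.assert_allclose(numeric_grad(gelu, xa, eps), gelu_grad(xa), rtol=rtol, atol=atol)
print("numeric and analytic GELU gradients agree within rtol=%g, atol=%g" % (rtol, atol))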
