This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix GELU backward possible NaN #14782

Merged · 2 commits · Apr 25, 2019
29 changes: 9 additions & 20 deletions src/operator/mshadow_op.h
@@ -45,14 +45,12 @@ namespace mshadow_op {
__constant__ const float PI = 3.14159265358979323846;
__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
-__constant__ const float GELU_CUBIC_CONSTANT = 0.044715;
-__constant__ const float GELU_ROOT_2_OVER_PI = 0.7978845608028654;
+__constant__ const float SQRT_2 = 1.4142135623730950488016887242096;
#else
const float PI = 3.14159265358979323846;
const float SELU_ALPHA = 1.6732632423543772848170429916717;
const float SELU_LAMBDA = 1.0507009873554804934193349852946;
-const float GELU_CUBIC_CONSTANT = 0.044715;
-const float GELU_ROOT_2_OVER_PI = 0.7978845608028654;
+const float SQRT_2 = 1.4142135623730950488016887242096;
using std::isnan;
#endif
using std::enable_if;
@@ -131,21 +129,6 @@ MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));

-#define MXNET_GELU_GX(a) \
-  a * (DType(1.0f) + DType(GELU_CUBIC_CONSTANT) * a * a)
-
-#define MXNET_GELU_GX_GRAD(a) \
-  (DType(1.0f) + DType(3.0f * GELU_CUBIC_CONSTANT) * a * a)
-
-#define MXNET_GELU_TANH(a) \
-  math::tanh(DType(GELU_ROOT_2_OVER_PI) * MXNET_GELU_GX(a))
-
-MXNET_UNARY_MATH_OP(gelu, DType(0.5f) * a * (DType(1.0f) + MXNET_GELU_TANH(a)));
-
-MXNET_BINARY_MATH_OP_NC(gelu_grad,
-                        b / a + b * (DType(1.0f) - MXNET_GELU_TANH(a)) *
-                        DType(GELU_ROOT_2_OVER_PI) * MXNET_GELU_GX_GRAD(a));
-
MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
                       (a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));

@@ -191,6 +174,13 @@ MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a)));

MXNET_SIMPLE_UNARY_MATH_OP(erf);

+MXNET_UNARY_MATH_OP(gelu,
+                    DType(0.5f * static_cast<float>(a) * (1.0f + math::erf(static_cast<float>(a) / SQRT_2))));
+
+MXNET_BINARY_MATH_OP_NC(gelu_grad,
+                        DType(0.5f * (1.0f + math::erf(static_cast<float>(a) / SQRT_2) +
+                                      static_cast<float>(a) * erf_grad::Map(static_cast<float>(a) / SQRT_2) / SQRT_2)));
+
MXNET_SIMPLE_UNARY_MATH_OP(exp);

MXNET_SIMPLE_UNARY_MATH_OP(expm1);
@@ -355,7 +345,6 @@ MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));
MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));

MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
-
MXNET_UNARY_MATH_OP(reciprocal_square_root, 1.0f / math::sqrt(a));

MXNET_UNARY_MATH_OP(reciprocal_square_root_grad, -0.5f / (math::sqrt(a) * math::id(a)));
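
For context on the fix (the sketch below is not part of the patch): the NaN in the old backward comes from the b / a term of the removed gelu_grad, i.e. the forward output divided by the input, which evaluates to 0 / 0 whenever an input element is exactly zero. The new gelu_grad differentiates the exact erf-based GELU directly and stays finite everywhere. A minimal NumPy sketch of the two formulas, with illustrative function names:

import math
import numpy as np

SQRT_2 = math.sqrt(2.0)
GELU_CUBIC_CONSTANT = 0.044715
GELU_ROOT_2_OVER_PI = math.sqrt(2.0 / math.pi)

def gelu_grad_new(x):
    """Local gradient of the exact GELU, mirroring the erf-based formula added above."""
    t = x / SQRT_2
    erf_grad = 2.0 / math.sqrt(math.pi) * math.exp(-t * t)   # d/dt erf(t)
    return 0.5 * (1.0 + math.erf(t) + x * erf_grad / SQRT_2)

def gelu_grad_old(x):
    """Local gradient of the tanh approximation, mirroring the removed macros.
    The y / x term is 0 / 0 at x == 0, which is where the NaN came from."""
    x = np.float64(x)
    gx = x * (1.0 + GELU_CUBIC_CONSTANT * x * x)
    tanh = math.tanh(GELU_ROOT_2_OVER_PI * gx)
    y = 0.5 * x * (1.0 + tanh)                               # forward output (the "b" argument)
    gx_grad = 1.0 + 3.0 * GELU_CUBIC_CONSTANT * x * x
    with np.errstate(invalid='ignore'):
        return y / x + y * (1.0 - tanh) * GELU_ROOT_2_OVER_PI * gx_grad

print(gelu_grad_old(0.0))                          # nan
print(gelu_grad_new(0.0))                          # 0.5
print(gelu_grad_old(0.05), gelu_grad_new(0.05))    # nearly identical away from zero
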
2 changes: 1 addition & 1 deletion tests/python/unittest/test_operator.py
@@ -886,7 +886,7 @@ def fgelu_grad(grad, x, y):
    y = mx.sym.LeakyReLU(data=x, act_type="gelu")
    for dtype in [np.float16, np.float32, np.float64]:
        xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype)
-        eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4)
+        eps, rtol, atol = (7.5e-4, 2e-2, 1e-3) if dtype is np.float16 else (1e-4, 1e-3, 1e-5)
        if dtype is np.float16:
            xa /= 10.0
        xa[abs(xa) < eps] = 0.01
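
As a rough, standalone illustration of what the tightened float32/float64 tolerances require (the real test goes through MXNet's gradient checkers rather than this hand-rolled loop), a central-difference check of the erf-based gradient at the eps used above:

import math

SQRT_2 = math.sqrt(2.0)

def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / SQRT_2))

def gelu_grad(x):
    t = x / SQRT_2
    return 0.5 * (1.0 + math.erf(t) + x * (2.0 / math.sqrt(math.pi)) * math.exp(-t * t) / SQRT_2)

eps, rtol, atol = 1e-4, 1e-3, 1e-5            # float32/float64 settings from the test above
for x in [-0.1, -0.01, 0.01, 0.1]:            # |x| >= eps, as the test enforces via xa[abs(xa) < eps] = 0.01
    numeric = (gelu(x + eps) - gelu(x - eps)) / (2.0 * eps)
    analytic = gelu_grad(x)
    assert abs(numeric - analytic) <= atol + rtol * abs(numeric), (x, numeric, analytic)
print("central differences match the analytic gradient within rtol=1e-3, atol=1e-5")
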