diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py
index fa8eee9d2989..c7dc83176e14 100644
--- a/python/mxnet/gluon/nn/activations.py
+++ b/python/mxnet/gluon/nn/activations.py
@@ -18,7 +18,7 @@
 # coding: utf-8
 # pylint: disable= arguments-differ
 """Basic neural network layers."""
-__all__ = ['Activation', 'LeakyReLU', 'PReLU', 'ELU', 'SELU', 'Swish']
+__all__ = ['Activation', 'LeakyReLU', 'PReLU', 'ELU', 'SELU', 'Swish', 'GELU']
 
 from ... import initializer
 from ..block import HybridBlock
@@ -180,6 +180,25 @@ def __init__(self, **kwargs):
     def hybrid_forward(self, F, x):
         return F.LeakyReLU(x, act_type='selu', name='fwd')
 
+class GELU(HybridBlock):
+    r"""
+    Gaussian Error Linear Unit (GELU)
+        "Gaussian Error Linear Units (GELUs)", Hendrycks et al, 2016
+        https://arxiv.org/abs/1606.08415
+
+
+    Inputs:
+        - **data**: input tensor with arbitrary shape.
+
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+    """
+    def __init__(self, **kwargs):
+        super(GELU, self).__init__(**kwargs)
+
+    def hybrid_forward(self, F, x):
+        return F.LeakyReLU(x, act_type='gelu', name='fwd')
+
 
 class Swish(HybridBlock):
     r"""
diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h
index c7fa3f0443ee..cfdd1064d6fb 100644
--- a/src/operator/leaky_relu-inl.h
+++ b/src/operator/leaky_relu-inl.h
@@ -47,7 +47,7 @@ namespace op {
 namespace leakyrelu {
 enum LeakyReLUOpInputs {kData, kGamma};
 enum LeakyReLUOpOutputs {kOut, kMask};
-enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU, kSELU};
+enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU, kSELU, kGELU};
 enum LeakyReLUOpResource {kRandom};
 }  // namespace leakyrelu
 
@@ -64,6 +64,7 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
     .add_enum("prelu", leakyrelu::kPReLU)
     .add_enum("elu", leakyrelu::kELU)
     .add_enum("selu", leakyrelu::kSELU)
+    .add_enum("gelu", leakyrelu::kGELU)
     .describe("Activation function to be applied.");
     DMLC_DECLARE_FIELD(slope).set_default(0.25f)
     .describe("Init slope for the activation. (For leaky and elu only)");
@@ -190,6 +191,13 @@ class LeakyReLUOp : public Operator {
         });
         break;
       }
+      case leakyrelu::kGELU: {
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::gelu, Req>, xpu>::Launch(
+            s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_);
+        });
+        break;
+      }
       default:
         LOG(FATAL) << "Not implmented";
     }
@@ -223,7 +231,7 @@ class LeakyReLUOp : public Operator {
     if (param_.act_type == leakyrelu::kRReLU) {
       mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 3, DType>(dshape, s);
     }
-    if (param_.act_type == leakyrelu::kPReLU) {
+    if (param_.act_type == leakyrelu::kPReLU || param_.act_type == leakyrelu::kGELU) {
       data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, DType>(dshape, s);
     }
     switch (param_.act_type) {
@@ -287,6 +295,15 @@ class LeakyReLUOp : public Operator {
         });
         break;
       }
+      case leakyrelu::kGELU: {
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<
+            mxnet_op::backward_grad_tuned<mshadow_op::gelu_grad>, Req>, xpu>::Launch(
+              s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
+              data.dptr_, output.dptr_);
+        });
+        break;
+      }
       default:
         LOG(FATAL) << "Not implmented";
     }
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 6c3ef3fb650b..c27a98ac1940 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -45,10 +45,14 @@ namespace mshadow_op {
 __constant__ const float PI = 3.14159265358979323846;
 __constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
 __constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
+__constant__ const float GELU_CUBIC_CONSTANT = 0.044715;
+__constant__ const float GELU_ROOT_2_OVER_PI = 0.7978845608028654;
 #else
 const float PI = 3.14159265358979323846;
 const float SELU_ALPHA = 1.6732632423543772848170429916717;
 const float SELU_LAMBDA = 1.0507009873554804934193349852946;
+const float GELU_CUBIC_CONSTANT = 0.044715;
+const float GELU_ROOT_2_OVER_PI = 0.7978845608028654;
 using std::isnan;
 #endif
 using std::enable_if;
@@ -127,6 +131,21 @@ MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));
 
 MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));
 
+#define MXNET_GELU_GX(a) \
+  a * (DType(1.0f) + DType(GELU_CUBIC_CONSTANT) * a * a)
+
+#define MXNET_GELU_GX_GRAD(a) \
+  (DType(1.0f) + DType(3.0f * GELU_CUBIC_CONSTANT) * a * a)
+
+#define MXNET_GELU_TANH(a) \
+  math::tanh(DType(GELU_ROOT_2_OVER_PI) * MXNET_GELU_GX(a))
+
+MXNET_UNARY_MATH_OP(gelu, DType(0.5f) * a * (DType(1.0f) + MXNET_GELU_TANH(a)));
+
+MXNET_BINARY_MATH_OP_NC(gelu_grad,
+                        b / a + b * (DType(1.0f) - MXNET_GELU_TANH(a)) *
+                        DType(GELU_ROOT_2_OVER_PI) * MXNET_GELU_GX_GRAD(a));
+
 MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
                        (a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));
 
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 56d35b23b369..98ce14e7bf05 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -219,6 +219,7 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::selu);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::selu_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gelu);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu);  // NOLINT()
@@ -328,6 +329,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::elu_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum);  // NOLINT()
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 6af7a5f948e2..9c69b126e3fb 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1070,10 +1070,10 @@ def elu(x):
     def selu_test(x):
         def selu(x):
             scale, alpha = 1.0507009873554804934193349852946, 1.6732632423543772848170429916717
-            return scale * x if x >= 0 else alpha * mx.nd.exp(x) - alpha
+            return scale * x if x >= 0 else scale * alpha * mx.nd.expm1(x)
         return [selu(x_i) for x_i in x]
 
-    for test_point, ref_point in zip(selu(point_to_validate), selu(point_to_validate)):
+    for test_point, ref_point in zip(selu_test(point_to_validate), selu(point_to_validate)):
         assert test_point == ref_point
 
     prelu = mx.gluon.nn.PReLU()
@@ -1081,6 +1081,21 @@ def selu(x):
     x = point_to_validate.reshape((1, 3, 2))
     assert_almost_equal(prelu(x).asnumpy(), mx.nd.where(x >= 0, x, 0.25 * x).asnumpy())
 
+    gelu = mx.gluon.nn.GELU()
+    def gelu_test(x):
+        CUBE_CONSTANT = 0.044715
+        ROOT_TWO_OVER_PI = 0.7978845608028654
+        def g(x):
+            return ROOT_TWO_OVER_PI * (x + CUBE_CONSTANT * x * x * x)
+        def f(x):
+            return 1.0 + mx.nd.tanh(g(x))
+        def gelu(x):
+            return 0.5 * x * f(x)
+        return [gelu(x_i) for x_i in x]
+    for test_point, ref_point in zip(gelu_test(point_to_validate), gelu(point_to_validate)):
+        assert test_point == ref_point
+
+
 @with_seed()
 def test_dropout():
     def get_slice(x, axis, idx):
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index c9498ecb0bd2..afb5a8e11b4f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -862,6 +862,39 @@ def fselu_grad(grad, x, y):
         check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype)
 
 
+@with_seed()
+def test_gelu():
+    CUBE_CONSTANT = 0.044715
+    ROOT_TWO_OVER_PI = 0.7978845608028654
+    def g(x):
+        return ROOT_TWO_OVER_PI * (x + CUBE_CONSTANT * np.power(x, 3))
+    def g_grad(x):
+        return ROOT_TWO_OVER_PI * (1.0 + 3.0 * CUBE_CONSTANT * np.power(x, 2))
+    def f(x):
+        return 1.0 + np.tanh(g(x))
+    def f_grad(x):
+        return (1.0 - np.tanh(g(x)) * np.tanh(g(x))) * g_grad(x)
+    def fgelu(x):
+        return 0.5 * x * f(x)
+    def fgelu_grad(grad, x, y):
+        return grad * (y / x + y * (1 - np.tanh(g(x))) * g_grad(x))
+
+    shape = (3, 4)
+    x = mx.sym.Variable("x")
+    y = mx.sym.LeakyReLU(data=x, act_type="gelu")
+    for dtype in [np.float16, np.float32, np.float64]:
+        xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype)
+        eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4)
+        if dtype is np.float16:
+            xa /= 10.0
+        xa[abs(xa) < eps] = 0.01
+        ya = fgelu(xa)
+        ga = fgelu_grad(np.ones(shape).astype(dtype), xa, ya)
+        check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
+        check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
+        check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype)
+
+
 @with_seed()
 def test_sigmoid():
     def fsigmoid(a):
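
Note (not part of the patch): a minimal usage sketch of what this diff exposes, assuming a build with the changes above applied. It calls the new Gluon GELU block and the imperative LeakyReLU operator with act_type='gelu', and compares both against the tanh approximation hard-coded in mshadow_op.h; the sample inputs and tolerances are illustrative only.

import numpy as np
import mxnet as mx

def gelu_ref(x):
    # Same tanh approximation as mshadow_op.h:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + np.tanh(0.7978845608028654 * (x + 0.044715 * x ** 3)))

x = mx.nd.array([-2.0, -0.5, 0.0, 0.5, 2.0])

act = mx.gluon.nn.GELU()                    # Gluon block added in activations.py
y_block = act(x)                            # parameterless block, no initialize() needed
y_op = mx.nd.LeakyReLU(x, act_type='gelu')  # same kernel via the LeakyReLU operator

np.testing.assert_allclose(y_block.asnumpy(), gelu_ref(x.asnumpy()), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(y_op.asnumpy(), y_block.asnumpy(), rtol=1e-5, atol=1e-6)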