From 65afc61bc58b8bc06d8cd288058da482c818bb45 Mon Sep 17 00:00:00 2001
From: RuRo
Date: Wed, 18 Mar 2020 07:42:53 +0300
Subject: [PATCH] Fix SoftReLU fused operator numerical stability (#17849)

* fix numerically unstable fused softrelu op

* implement test for softrelu numerical stability
---
 src/operator/fusion/fused_op-inl.h | 5 ++++-
 tests/python/gpu/test_fusion.py    | 3 +++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index c838d8523baf..0b10f821d8e1 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -566,7 +566,10 @@ __device__ inline DType sigmoid(const DType val) {
 
 template <typename DType>
 __device__ inline DType softrelu(const DType val) {
-  return logf(1 + expf(val));
+  // Avoid overflow of exp for large inputs.
+  // The threshold 20 is chosen such that softrelu(a) = a
+  // for a > 20 using floating precision.
+  return val > 20 ? val : logf(1 + expf(val));
 }
 
 template <typename DType>
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 1bbf5982f45f..1febf8d8e23c 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -138,6 +138,9 @@ def announce_check(op_name):
     for act_type in ['relu', 'sigmoid', 'tanh', 'softrelu', 'softsign']:
         announce_check("Activation(act_type='{}')".format(act_type))
         check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=arr)
+        if act_type == 'softrelu':
+            # Check that softrelu implementation doesn't overflow on large inputs
+            check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=1000 * arr)
 
     # Cast requires dtype
     for dtype in ['float16', 'float32', 'float64', 'int32']: