diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index c838d8523baf..0b10f821d8e1 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -566,7 +566,10 @@ __device__ inline DType sigmoid(const DType val) {
 
 template <typename DType>
 __device__ inline DType softrelu(const DType val) {
-  return logf(1 + expf(val));
+  // Avoid overflow of exp for large inputs.
+  // The threshold 20 is chosen such that softrelu(a) = a
+  // for a > 20 using floating precision.
+  return val > 20 ? val : logf(1 + expf(val));
 }
 
 template <typename DType>
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 61fba10913cc..a6be6c7d6629 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -136,6 +136,9 @@ def announce_check(op_name):
     for act_type in ['relu', 'sigmoid', 'tanh', 'softrelu', 'softsign']:
         announce_check("Activation(act_type='{}')".format(act_type))
         check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=arr)
+        if act_type == 'softrelu':
+            # Check that softrelu implementation doesn't overflow on large inputs
+            check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=1000 * arr)
 
     # Cast requires dtype
     for dtype in ['float16', 'float32', 'float64', 'int32']:
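
Note (not part of the patch): a minimal NumPy sketch of the two numerical facts the change relies on. Single-precision exp overflows to inf once the input exceeds roughly log(FLT_MAX) ~ 88.7, which is why the un-guarded logf(1 + expf(val)) blows up on the 1000 * arr test input, and for val > 20 the correction term log(1 + exp(-val)) is below float32 machine epsilon, so returning val directly is exact at that precision. This is only an illustrative check, assuming NumPy is available.

    import numpy as np

    # 1) Overflow: float32 exp saturates to inf for large inputs, so the
    #    original logf(1 + expf(val)) also returns inf (emits an overflow warning).
    big = np.float32(1000.0)
    print(np.log1p(np.exp(big)))  # inf

    # 2) Threshold: at val = 20 the double-precision reference value of
    #    log(1 + exp(val)) rounds to exactly 20.0 in float32, since the
    #    difference (~2.1e-9) is far below float32 epsilon (~1.2e-7).
    val = np.float32(20.0)
    reference = np.log1p(np.exp(np.float64(val)))
    print(np.float32(reference) == val)  # True

The same reasoning is why the added test only exercises the large-input path for softrelu: the other activation functions in the loop do not call expf on unbounded positive inputs.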