This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit ddab468

mseth10 and RuRo authored
Fix SoftReLU fused operator numerical stability (#17849) (#19390)
* fix numerically unstable fused softrelu op
* implement test for softrelu numerical stability

Co-authored-by: RuRo <[email protected]>
1 parent aba3aa3 commit ddab468

2 files changed: +7, -1 lines

src/operator/fusion/fused_op-inl.h (+4, -1)

@@ -566,7 +566,10 @@ __device__ inline DType sigmoid(const DType val) {
 
 template <typename DType>
 __device__ inline DType softrelu(const DType val) {
-  return logf(1 + expf(val));
+  // Avoid overflow of exp for large inputs.
+  // The threshold 20 is chosen such that softrelu(a) = a
+  // for a > 20 using floating precision.
+  return val > 20 ? val : logf(1 + expf(val));
 }
 
 template <typename DType>
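
For context, here is a minimal NumPy sketch, not part of the commit, of why the old formula blows up and why returning large inputs unchanged is safe: for val > 20, log(1 + exp(val)) differs from val by less than exp(-20) (about 2e-9), which is below float32 resolution at that magnitude.

import numpy as np

x = np.float32(1000.0)

# Old kernel body, translated to NumPy: exp(1000) overflows float32 to inf,
# so the whole expression evaluates to inf instead of ~1000.
naive = np.log(np.float32(1.0) + np.exp(x))

# Patched behaviour: large inputs are passed through unchanged, since
# log(1 + exp(x)) = x + log(1 + exp(-x)) and exp(-x) is negligible for x > 20.
stable = x if x > 20 else np.log1p(np.exp(x))

print(naive)   # inf (NumPy also emits an overflow RuntimeWarning)
print(stable)  # 1000.0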
tests/python/gpu/test_fusion.py (+3)

@@ -138,6 +138,9 @@ def announce_check(op_name):
     for act_type in ['relu', 'sigmoid', 'tanh', 'softrelu', 'softsign']:
         announce_check("Activation(act_type='{}')".format(act_type))
         check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=arr)
+        if act_type == 'softrelu':
+            # Check that softrelu implementation doesn't overflow on large inputs
+            check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=1000 * arr)
 
     # Cast requires dtype
     for dtype in ['float16', 'float32', 'float64', 'int32']: