Skip to content

Commit

Permalink
Fix SoftReLU fused operator numerical stability (apache#17849)
Browse files Browse the repository at this point in the history
* fix numerically unstable fused softrelu op

* implement test for softrelu numerical stability
  • Loading branch information
RuRo authored and mseth10 committed Oct 20, 2020
1 parent 4e4dfd2 commit e3d25ae
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/operator/fusion/fused_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,10 @@ __device__ inline DType sigmoid(const DType val) {

template <typename DType>
__device__ inline DType softrelu(const DType val) {
return logf(1 + expf(val));
// Avoid overflow of exp for large inputs.
// The threshold 20 is chosen such that softrelu(a) = a
// for a > 20 using floating precision.
return val > 20 ? val : logf(1 + expf(val));
}

template <typename DType>
Expand Down
3 changes: 3 additions & 0 deletions tests/python/gpu/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def announce_check(op_name):
for act_type in ['relu', 'sigmoid', 'tanh', 'softrelu', 'softsign']:
announce_check("Activation(act_type='{}')".format(act_type))
check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=arr)
if act_type == 'softrelu':
# Check that softrelu implementation doesn't overflow on large inputs
check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=1000 * arr)

# Cast requires dtype
for dtype in ['float16', 'float32', 'float64', 'int32']:
Expand Down

0 comments on commit e3d25ae

Please sign in to comment.