Skip to content

Commit a8b62af

Browse files
ArzaghiMattStephansoncbezaultStephanTLavavej
authored
<random>: Prevent distributions from returning NaN/Inf (#1228)
Co-authored-by: MattStephanson <[email protected]> Co-authored-by: Curtis J Bezault <[email protected]> Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent 974582f commit a8b62af

File tree

2 files changed

+67
-13
lines changed
  • stl/inc
  • tests/std/tests/GH_001017_discrete_distribution_out_of_range

2 files changed

+67
-13
lines changed

stl/inc/random

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2990,15 +2990,32 @@ private:
29902990
_Ty _Vx1;
29912991
_Ty _Vx2;
29922992
_Ty _Sx;
2993-
for (;;) { // reject bad values
2993+
for (;;) { // reject bad values to avoid generating NaN/Inf on the next calculations
29942994
_Vx1 = 2 * _NRAND(_Eng, _Ty) - 1;
29952995
_Vx2 = 2 * _NRAND(_Eng, _Ty) - 1;
29962996
_Sx = _Vx1 * _Vx1 + _Vx2 * _Vx2;
2997-
if (_Sx < 1) {
2997+
if (_Sx < _Ty{1} && _Vx1 != _Ty{0} && _Vx2 != _Ty{0}) {
2998+
// good values!
29982999
break;
29993000
}
30003001
}
3001-
const auto _Fx = static_cast<_Ty>(_CSTD sqrt(-2.0 * _CSTD log(_Sx) / _Sx));
3002+
3003+
_Ty _LogSx;
3004+
if (_Sx > _Ty{1e-4}) {
3005+
_LogSx = _STD log(_Sx);
3006+
} else {
3007+
// Bad _Sx value! Very small values will overflow log(_Sx) / _Sx.
3008+
// Generate a new value based on scaling method.
3009+
const _Ty _Ln2{_Ty{0.69314718055994530941723212145818}};
3010+
const _Ty _Maxabs{(_STD max)(_STD abs(_Vx1), _STD abs(_Vx2))};
3011+
const int _ExpMax{_STD ilogb(_Maxabs)};
3012+
_Vx1 = _STD scalbn(_Vx1, -_ExpMax);
3013+
_Vx2 = _STD scalbn(_Vx2, -_ExpMax);
3014+
_Sx = _Vx1 * _Vx1 + _Vx2 * _Vx2;
3015+
_LogSx = _STD log(_Sx) + static_cast<_Ty>(_ExpMax) * (_Ln2 * 2);
3016+
}
3017+
3018+
const auto _Fx = _Ty{_STD sqrt(_Ty{-2} * _LogSx / _Sx)};
30023019
if (_Keep) { // save second value for next call
30033020
_Xx2 = _Fx * _Vx2;
30043021
_Valid = true;
@@ -3922,19 +3939,29 @@ public:
39223939
_Px1 = _CSTD pow(_Px1, _Ty{1} / _Ax);
39233940
_Px2 = _CSTD pow(_Px2, _Ty{1} / _Bx);
39243941
_Wx = _Px1 + _Px2;
3925-
if (_Wx <= _Ty{1}) {
3942+
if (_Wx <= _Ty{1} && _Wx != _Ty{0}) {
39263943
break;
39273944
}
39283945
}
39293946
return _Px1 / _Wx;
39303947
} else { // use gamma distributions instead
39313948
_Ty _Px1;
39323949
_Ty _Px2;
3950+
_Ty _PSum;
39333951
gamma_distribution<_Ty> _Dist1(_Ax, 1);
39343952
gamma_distribution<_Ty> _Dist2(_Bx, 1);
3935-
_Px1 = _Dist1(_Eng);
3936-
_Px2 = _Dist2(_Eng);
3937-
return _Px1 / (_Px1 + _Px2);
3953+
3954+
for (;;) { // reject pairs whose sum is zero
3955+
_Px1 = _Dist1(_Eng);
3956+
_Px2 = _Dist2(_Eng);
3957+
_PSum = _Px1 + _Px2;
3958+
3959+
if (_PSum != _Ty{0}) {
3960+
break;
3961+
}
3962+
}
3963+
3964+
return _Px1 / _PSum;
39383965
}
39393966
}
39403967

@@ -4059,12 +4086,18 @@ private:
40594086
_Ty _Px;
40604087
_Ty _Vx1;
40614088
_Ty _Vx2;
4089+
const _Ty _Vx3{1};
40624090
_Vx1 = static_cast<_Ty>(_Par0._Mx) * static_cast<_Ty>(0.5);
40634091
_Vx2 = static_cast<_Ty>(_Par0._Nx) * static_cast<_Ty>(0.5);
40644092
_Beta_distribution<_Ty> _Dist(_Vx1, _Vx2);
4065-
_Px = _Dist(_Eng);
4093+
for (;;) { // reject bad values
4094+
_Px = _Dist(_Eng);
4095+
if (_Px != _Vx3) {
4096+
break;
4097+
}
4098+
}
40664099

4067-
return (_Vx2 / _Vx1) * (_Px / (_Ty{1} - _Px));
4100+
return (_Vx2 / _Vx1) * (_Px / (_Vx3 - _Px));
40684101
}
40694102

40704103
param_type _Par;
@@ -4195,13 +4228,13 @@ private:
41954228
_Vx2 = _Dist(_Eng);
41964229
_Rs = _Vx1 * _Vx1 + _Vx2 * _Vx2;
41974230

4198-
if (_Rs < _Ty{1}) {
4231+
// very small _Rs will overflow on pow(_Rx0, -_Ty{4} / _Par0._Nx)
4232+
if (_Rs < _Ty{1} && _Rs > _Ty{1e-12}) {
41994233
break;
42004234
}
42014235
}
4202-
_Rx0 = _CSTD sqrt(_Rs);
4203-
4204-
return _Vx1 * _CSTD sqrt(_Par0._Nx * (_CSTD pow(_Rx0, -_Ty{4} / _Par0._Nx) - _Ty{1}) / _Rs);
4236+
_Rx0 = _STD sqrt(_Rs);
4237+
return _Vx1 * _STD sqrt(_Par0._Nx * (_STD pow(_Rx0, -_Ty{4} / _Par0._Nx) - _Ty{1}) / _Rs);
42054238
}
42064239

42074240
param_type _Par;

tests/std/tests/GH_001017_discrete_distribution_out_of_range/test.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,23 @@
66

77
#include "bad_random_engine.hpp"
88

9+
template <class Distribution>
10+
void Test_for_NaN_Inf(Distribution&& distribution) {
11+
for (bad_random_generator rng; !rng.has_cycled_through();) {
12+
const auto rand_value = distribution(rng);
13+
assert(!isnan(rand_value) && !isinf(rand_value));
14+
}
15+
}
16+
17+
template <class T>
18+
void Test_distributions() {
19+
// Additionally test GH-1174 "<random>: Some random number distributions could return NaN"
20+
Test_for_NaN_Inf(std::normal_distribution<T>{});
21+
Test_for_NaN_Inf(std::lognormal_distribution<T>{});
22+
Test_for_NaN_Inf(std::fisher_f_distribution<T>{});
23+
Test_for_NaN_Inf(std::student_t_distribution<T>{});
24+
}
25+
926
int main() {
1027
std::discrete_distribution<int> dist{1, 1, 1, 1, 1, 1};
1128
bad_random_generator rng;
@@ -14,4 +31,8 @@ int main() {
1431
const auto rand_value = dist(rng);
1532
assert(0 <= rand_value && rand_value < 6);
1633
}
34+
35+
Test_distributions<float>();
36+
Test_distributions<double>();
37+
Test_distributions<long double>();
1738
}

0 commit comments

Comments
 (0)