diff --git a/stl/inc/random b/stl/inc/random index ebb66b92c05..79c11840bf9 100644 --- a/stl/inc/random +++ b/stl/inc/random @@ -2932,15 +2932,32 @@ private: _Ty _Vx1; _Ty _Vx2; _Ty _Sx; - for (;;) { // reject bad values + for (;;) { // reject bad values to avoid generating NaN/Inf on the next calculations _Vx1 = 2 * _NRAND(_Eng, _Ty) - 1; _Vx2 = 2 * _NRAND(_Eng, _Ty) - 1; _Sx = _Vx1 * _Vx1 + _Vx2 * _Vx2; - if (_Sx < 1) { + if (_Sx < _Ty{1} && _Vx1 != _Ty{0} && _Vx2 != _Ty{0}) { + // good values! break; } } - const auto _Fx = static_cast<_Ty>(_CSTD sqrt(-2.0 * _CSTD log(_Sx) / _Sx)); + + _Ty _LogSx; + if (_Sx > _Ty{1e-4}) { + _LogSx = _STD log(_Sx); + } else { + // Bad _Sx value! Very small values will overflow log(_Sx) / _Sx. + // Generate a new value based on scaling method. + const _Ty _Ln2{_Ty{0.69314718055994530941723212145818}}; + const _Ty _Maxabs{(_STD max)(_STD abs(_Vx1), _STD abs(_Vx2))}; + const int _ExpMax{_STD ilogb(_Maxabs)}; + _Vx1 = _STD scalbn(_Vx1, -_ExpMax); + _Vx2 = _STD scalbn(_Vx2, -_ExpMax); + _Sx = _Vx1 * _Vx1 + _Vx2 * _Vx2; + _LogSx = _STD log(_Sx) + static_cast<_Ty>(_ExpMax) * (_Ln2 * 2); + } + + const auto _Fx = _Ty{_STD sqrt(_Ty{-2} * _LogSx / _Sx)}; if (_Keep) { // save second value for next call _Xx2 = _Fx * _Vx2; _Valid = true; @@ -3864,7 +3881,7 @@ public: _Px1 = _CSTD pow(_Px1, _Ty{1} / _Ax); _Px2 = _CSTD pow(_Px2, _Ty{1} / _Bx); _Wx = _Px1 + _Px2; - if (_Wx <= _Ty{1}) { + if (_Wx <= _Ty{1} && _Wx != _Ty{0}) { break; } } @@ -3872,11 +3889,21 @@ public: } else { // use gamma distributions instead _Ty _Px1; _Ty _Px2; + _Ty _PSum; gamma_distribution<_Ty> _Dist1(_Ax, 1); gamma_distribution<_Ty> _Dist2(_Bx, 1); - _Px1 = _Dist1(_Eng); - _Px2 = _Dist2(_Eng); - return _Px1 / (_Px1 + _Px2); + + for (;;) { // reject pairs whose sum is zero + _Px1 = _Dist1(_Eng); + _Px2 = _Dist2(_Eng); + _PSum = _Px1 + _Px2; + + if (_PSum != _Ty{0}) { + break; + } + } + + return _Px1 / _PSum; } } @@ -4001,12 +4028,18 @@ private: _Ty _Px; _Ty _Vx1; _Ty _Vx2; + const _Ty _Vx3{1}; _Vx1 = static_cast<_Ty>(_Par0._Mx) * static_cast<_Ty>(0.5); _Vx2 = static_cast<_Ty>(_Par0._Nx) * static_cast<_Ty>(0.5); _Beta_distribution<_Ty> _Dist(_Vx1, _Vx2); - _Px = _Dist(_Eng); + for (;;) { // reject bad values + _Px = _Dist(_Eng); + if (_Px != _Vx3) { + break; + } + } - return (_Vx2 / _Vx1) * (_Px / (_Ty{1} - _Px)); + return (_Vx2 / _Vx1) * (_Px / (_Vx3 - _Px)); } param_type _Par; @@ -4137,13 +4170,13 @@ private: _Vx2 = _Dist(_Eng); _Rs = _Vx1 * _Vx1 + _Vx2 * _Vx2; - if (_Rs < _Ty{1}) { + // very small _Rs will overflow on pow(_Rx0, -_Ty{4} / _Par0._Nx) + if (_Rs < _Ty{1} && _Rs > _Ty{1e-12}) { break; } } - _Rx0 = _CSTD sqrt(_Rs); - - return _Vx1 * _CSTD sqrt(_Par0._Nx * (_CSTD pow(_Rx0, -_Ty{4} / _Par0._Nx) - _Ty{1}) / _Rs); + _Rx0 = _STD sqrt(_Rs); + return _Vx1 * _STD sqrt(_Par0._Nx * (_STD pow(_Rx0, -_Ty{4} / _Par0._Nx) - _Ty{1}) / _Rs); } param_type _Par; diff --git a/tests/std/tests/GH_001017_discrete_distribution_out_of_range/test.cpp b/tests/std/tests/GH_001017_discrete_distribution_out_of_range/test.cpp index 5a43af9b4df..d4ded00b997 100644 --- a/tests/std/tests/GH_001017_discrete_distribution_out_of_range/test.cpp +++ b/tests/std/tests/GH_001017_discrete_distribution_out_of_range/test.cpp @@ -6,6 +6,23 @@ #include "bad_random_engine.hpp" +template +void Test_for_NaN_Inf(Distribution&& distribution) { + for (bad_random_generator rng; !rng.has_cycled_through();) { + const auto rand_value = distribution(rng); + assert(!isnan(rand_value) && !isinf(rand_value)); + } +} + +template +void Test_distributions() { + // Additionally test GH-1174 ": Some random number distributions could return NaN" + Test_for_NaN_Inf(std::normal_distribution{}); + Test_for_NaN_Inf(std::lognormal_distribution{}); + Test_for_NaN_Inf(std::fisher_f_distribution{}); + Test_for_NaN_Inf(std::student_t_distribution{}); +} + int main() { std::discrete_distribution dist{1, 1, 1, 1, 1, 1}; bad_random_generator rng; @@ -14,4 +31,8 @@ int main() { const auto rand_value = dist(rng); assert(0 <= rand_value && rand_value < 6); } + + Test_distributions(); + Test_distributions(); + Test_distributions(); }