Optimize count() for vector<bool> iterators.

When both iterators are vector<bool> iterators and the value type is bool, count() dispatches to
_Count_vbool, which applies popcount to whole _Vbase words (masking the partial first and last
words) instead of comparing one element at a time. The tests cover empty, single-block and
multi-block ranges, both starting at cbegin() and starting two elements in.

diff --git a/stl/inc/vector b/stl/inc/vector
index 86b6216ecef..351fe88c0e9 100644
--- a/stl/inc/vector
+++ b/stl/inc/vector
@@ -2998,6 +2998,51 @@ _NODISCARD _CONSTEXPR20 _InIt _Find_vbool(_InIt _First, const _InIt _Last, const
 
     return _Last;
 }
+
+#ifdef __cpp_lib_bitops
+// Counts the elements equal to _Val in [_First, _Last) by applying popcount to whole _Vbase words,
+// masking off the bits that lie outside the range in the first and last words.
+// Guarded by the same macro as the call in count(): the body names _STD popcount, which presumably
+// is not declared when __cpp_lib_bitops is undefined.
+template <class _InIt, class _Ty>
+_NODISCARD _CONSTEXPR20 _Iter_diff_t<_InIt> _Count_vbool(_InIt _First, const _InIt _Last, const _Ty& _Val) {
+    _Iter_diff_t<_InIt> _Count = 0;
+
+    if (_First == _Last) {
+        return _Count;
+    }
+
+    const _Vbase* _VbFirst      = _First._Myptr;
+    const _Vbase* const _VbLast = _Last._Myptr;
+
+    const auto _FirstSourceMask = static_cast<_Vbase>(-1) << _First._Myoff;
+
+    if (_VbFirst == _VbLast) {
+        // We already excluded _First == _Last, so here _Last._Myoff > 0 and the shift is safe
+        const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
+        const auto _SourceMask     = _FirstSourceMask & _LastSourceMask;
+        const auto _SelectVal      = (_Val ? *_VbFirst : ~*_VbFirst) & _SourceMask;
+        return _STD popcount(_SelectVal);
+    }
+
+    const auto _FirstVal = (_Val ? *_VbFirst : ~*_VbFirst) & _FirstSourceMask;
+    _Count += _STD popcount(_FirstVal);
+    ++_VbFirst;
+
+    for (; _VbFirst != _VbLast; ++_VbFirst) {
+        const auto _SelectVal = _Val ? *_VbFirst : ~*_VbFirst;
+        _Count += _STD popcount(_SelectVal);
+    }
+
+    if (_Last._Myoff != 0) {
+        const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
+        const auto _LastVal        = (_Val ? *_VbFirst : ~*_VbFirst) & _LastSourceMask;
+        _Count += _STD popcount(_LastVal);
+    }
+
+    return _Count;
+}
+#endif // __cpp_lib_bitops
 #endif // _HAS_IF_CONSTEXPR
 
 _STD_END
diff --git a/stl/inc/xutility b/stl/inc/xutility
index fea2d3f274a..a70d3aa0129 100644
--- a/stl/inc/xutility
+++ b/stl/inc/xutility
@@ -5559,17 +5559,24 @@ template <class _InIt, class _Ty>
 _NODISCARD _CONSTEXPR20 _Iter_diff_t<_InIt> count(const _InIt _First, const _InIt _Last, const _Ty& _Val) {
     // count elements that match _Val
     _Adl_verify_range(_First, _Last);
-    auto _UFirst               = _Get_unwrapped(_First);
-    const auto _ULast          = _Get_unwrapped(_Last);
-    _Iter_diff_t<_InIt> _Count = 0;
+#ifdef __cpp_lib_bitops
+    if constexpr (_Is_vb_iterator<_InIt> && is_same_v<_Ty, bool>) {
+        return _Count_vbool(_First, _Last, _Val);
+    } else
+#endif // __cpp_lib_bitops
+    {
+        auto _UFirst               = _Get_unwrapped(_First);
+        const auto _ULast          = _Get_unwrapped(_Last);
+        _Iter_diff_t<_InIt> _Count = 0;
 
-    for (; _UFirst != _ULast; ++_UFirst) {
-        if (*_UFirst == _Val) {
-            ++_Count;
+        for (; _UFirst != _ULast; ++_UFirst) {
+            if (*_UFirst == _Val) {
+                ++_Count;
+            }
         }
-    }
 
-    return _Count;
+        return _Count;
+    }
 }
 
 #if _HAS_CXX17
diff --git a/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp b/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp
index fadd3fac3c2..5452f296ca6 100644
--- a/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp
+++ b/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp
@@ -148,7 +148,80 @@ bool test_find() {
     return true;
 }
 
+// Checks count() over `length` elements of a fixed period-8 pattern (5 true, 3 false per period),
+// once starting at cbegin() and once starting two elements in.
+void test_count_helper(const ptrdiff_t length) {
+    // This test data is not random, but irregular enough to ensure confidence in the tests
+    // clang-format off
+    vector<bool> source = { true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false,
+                            true, false, true, false, true, true, true, false
+    };
+    const int counts_true[8]  = { 0, 1, 1, 2, 2, 3, 4, 5 };
+    const int counts_false[8] = { 0, 0, 1, 1, 2, 2, 2, 2 };
+    // Starting two elements in, each period reads: true, false, true, true, true, false, true, false
+    const int counts_true_offset[8]  = { 0, 1, 1, 2, 3, 4, 4, 5 };
+    const int counts_false_offset[8] = { 0, 0, 1, 1, 1, 1, 2, 2 };
+    // clang-format on
+    const auto expected = div(int(length), 8);
+    // No offset
+    {
+        const auto result_true = static_cast<int>(count(source.cbegin(), next(source.cbegin(), length), true));
+        assert(result_true == expected.quot * 5 + counts_true[expected.rem]);
+
+        const auto result_false = static_cast<int>(count(source.cbegin(), next(source.cbegin(), length), false));
+        assert(result_false == expected.quot * 3 + counts_false[expected.rem]);
+    }
+
+    // With offset
+    {
+        const auto result_true =
+            static_cast<int>(count(next(source.cbegin(), 2), next(source.cbegin(), length + 2), true));
+        assert(result_true == expected.quot * 5 + counts_true_offset[expected.rem]);
+
+        const auto result_false =
+            static_cast<int>(count(next(source.cbegin(), 2), next(source.cbegin(), length + 2), false));
+        assert(result_false == expected.quot * 3 + counts_false_offset[expected.rem]);
+    }
+}
+
+bool test_count() {
+    // Empty range
+    test_count_helper(0);
+
+    // One block, ends within block
+    test_count_helper(15);
+
+    // One block, length whose remainder mod 8 distinguishes the offset expectations
+    test_count_helper(12);
+
+    // One block, ends at block boundary
+    test_count_helper(blockSize);
+
+    // Multiple blocks, within block
+    test_count_helper(3 * blockSize + 8);
+
+    // Multiple blocks, ends at block boundary
+    test_count_helper(4 * blockSize);
+    return true;
+}
+
 int main() {
     test_fill();
     test_find();
+    test_count();
 }