Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specialize find and count for vector<bool> #1131

Merged
merged 11 commits into from
Jun 12, 2021
50 changes: 2 additions & 48 deletions stl/inc/bit
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,6 @@ _NODISCARD constexpr _Ty rotr(const _Ty _Val, const int _Rotation) noexcept {
}
}

// Portable population count (number of set bits) built only from shifts, masks, and additions.
// It serves constant evaluation and hardware where the dedicated instruction cannot be relied upon.
template <class _Ty>
_NODISCARD constexpr int _Popcount_fallback(_Ty _Val) noexcept {
    constexpr int _Bit_width = numeric_limits<_Ty>::digits;

    // The 64-bit patterns are narrowed to _Ty so that every mask matches the operand width.
    constexpr _Ty _Pairs_mask   = static_cast<_Ty>(0x5555'5555'5555'5555ull);
    constexpr _Ty _Quads_mask   = static_cast<_Ty>(0x3333'3333'3333'3333ull);
    constexpr _Ty _Nibbles_mask = static_cast<_Ty>(0x0F0F'0F0F'0F0F'0F0Full);

    // Every 2-bit field now holds the number of bits that were set in it.
    _Val = static_cast<_Ty>(_Val - ((_Val >> 1) & _Pairs_mask));
    // Neighbouring 2-bit counts are merged into 4-bit counts.
    _Val = static_cast<_Ty>((_Val & _Quads_mask) + ((_Val >> 2) & _Quads_mask));
    // Neighbouring 4-bit counts are merged into one count per byte.
    _Val = static_cast<_Ty>((_Val + (_Val >> 4)) & _Nibbles_mask);

    // Fold the per-byte counts together, doubling the shift distance each round.
    int _Shift = 8;
    while (_Shift < _Bit_width) {
        _Val = static_cast<_Ty>(_Val + static_cast<_Ty>(_Val >> _Shift));
        _Shift <<= 1;
    }

    // Keep only the low field that is wide enough to represent every value up to _Bit_width.
    return static_cast<int>(_Val & static_cast<_Ty>(_Bit_width + _Bit_width - 1));
}

#if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))

extern "C" {
Expand Down Expand Up @@ -184,32 +167,8 @@ _NODISCARD int _Checked_x86_x64_countl_zero(const _Ty _Val) noexcept {
}
#endif // __AVX2__
}

// Runtime popcount for x86/x64 that prefers the hardware popcnt intrinsics.
// Unless the build already targets AVX, __isa_available is consulted first and
// _Popcount_fallback is used when the reported level is below __ISA_AVAILABLE_SSE42.
template <class _Ty>
_NODISCARD int _Checked_x86_x64_popcount(const _Ty _Val) noexcept {
#ifndef __AVX__
    if (__isa_available < __ISA_AVAILABLE_SSE42) {
        return _Popcount_fallback(_Val);
    }
#endif // !defined(__AVX__)

    constexpr int _Bit_width = numeric_limits<_Ty>::digits;
    if constexpr (_Bit_width == 32) {
        return static_cast<int>(__popcnt(_Val));
    } else if constexpr (_Bit_width <= 16) {
        return static_cast<int>(__popcnt16(_Val));
    } else {
#ifdef _M_IX86
        // _M_IX86 builds count the upper and lower 32 bits separately and add the results.
        const auto _Upper_half_count = __popcnt(_Val >> 32);
        const auto _Lower_half_count = __popcnt(static_cast<unsigned int>(_Val));
        return static_cast<int>(_Upper_half_count + _Lower_half_count);
#else // ^^^ _M_IX86 / !_M_IX86 vvv
        return static_cast<int>(__popcnt64(_Val));
#endif // _M_IX86
    }
}
#endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))


#if defined(_M_ARM) || defined(_M_ARM64)
#ifdef __clang__ // TRANSITION, GH-1586
_NODISCARD constexpr int _Clang_arm_arm64_countl_zero(const unsigned short _Val) {
Expand Down Expand Up @@ -283,14 +242,9 @@ _NODISCARD constexpr int countr_one(const _Ty _Val) noexcept {
return _Countr_zero(static_cast<_Ty>(~_Val));
}

// Returns the number of set bits in _Val.
// Participates in overload resolution only for the standard unsigned integer types.
// The choice between the x86/x64 intrinsic path and the portable fallback lives in _Popcount,
// so this public entry point only forwards to it.
template <class _Ty, enable_if_t<_Is_standard_unsigned_integer<_Ty>, int> = 0>
_NODISCARD constexpr int popcount(const _Ty _Val) noexcept {
    return _Popcount(_Val);
}

// Byte-order indicator. native is defined as an alias of little here.
enum class endian {
    little = 0,
    big    = 1,
    native = little,
};
Expand Down
56 changes: 55 additions & 1 deletion stl/inc/limits
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,23 @@ _NODISCARD constexpr int _Countr_zero_fallback(const _Ty _Val) noexcept {
return _Digits - _Countl_zero_fallback(static_cast<_Ty>(static_cast<_Ty>(~_Val) & static_cast<_Ty>(_Val - 1)));
}

// Portable population count (number of set bits) built only from shifts, masks, and additions.
// It serves constant evaluation and hardware where the dedicated instruction cannot be relied upon.
template <class _Ty>
_NODISCARD constexpr int _Popcount_fallback(_Ty _Val) noexcept {
    constexpr int _Bit_width = numeric_limits<_Ty>::digits;

    // The 64-bit patterns are narrowed to _Ty so that every mask matches the operand width.
    constexpr _Ty _Pairs_mask   = static_cast<_Ty>(0x5555'5555'5555'5555ull);
    constexpr _Ty _Quads_mask   = static_cast<_Ty>(0x3333'3333'3333'3333ull);
    constexpr _Ty _Nibbles_mask = static_cast<_Ty>(0x0F0F'0F0F'0F0F'0F0Full);

    // Every 2-bit field now holds the number of bits that were set in it.
    _Val = static_cast<_Ty>(_Val - ((_Val >> 1) & _Pairs_mask));
    // Neighbouring 2-bit counts are merged into 4-bit counts.
    _Val = static_cast<_Ty>((_Val & _Quads_mask) + ((_Val >> 2) & _Quads_mask));
    // Neighbouring 4-bit counts are merged into one count per byte.
    _Val = static_cast<_Ty>((_Val + (_Val >> 4)) & _Nibbles_mask);

    // Fold the per-byte counts together, doubling the shift distance each round.
    int _Shift = 8;
    while (_Shift < _Bit_width) {
        _Val = static_cast<_Ty>(_Val + static_cast<_Ty>(_Val >> _Shift));
        _Shift <<= 1;
    }

    // Keep only the low field that is wide enough to represent every value up to _Bit_width.
    return static_cast<int>(_Val & static_cast<_Ty>(_Bit_width + _Bit_width - 1));
}

#if defined(_M_IX86) || defined(_M_X64)
extern "C" {
extern int __isa_available;
Expand Down Expand Up @@ -1092,6 +1109,31 @@ _NODISCARD int _Checked_x86_x64_countr_zero(const _Ty _Val) noexcept {
#undef _TZCNT_U64
#endif // defined(_M_IX86) || defined(_M_X64)

#if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
// Runtime popcount for x86/x64 that prefers the hardware popcnt intrinsics.
// Unless the build already targets AVX, __isa_available is consulted first and
// _Popcount_fallback is used when the reported level is below __ISA_AVAILABLE_SSE42.
template <class _Ty>
_NODISCARD int _Checked_x86_x64_popcount(const _Ty _Val) noexcept {
#ifndef __AVX__
    if (__isa_available < __ISA_AVAILABLE_SSE42) {
        return _Popcount_fallback(_Val);
    }
#endif // !defined(__AVX__)

    constexpr int _Bit_width = numeric_limits<_Ty>::digits;
    if constexpr (_Bit_width == 32) {
        return static_cast<int>(__popcnt(_Val));
    } else if constexpr (_Bit_width <= 16) {
        return static_cast<int>(__popcnt16(_Val));
    } else {
#ifdef _M_IX86
        // _M_IX86 builds count the upper and lower 32 bits separately and add the results.
        const auto _Upper_half_count = __popcnt(_Val >> 32);
        const auto _Lower_half_count = __popcnt(static_cast<unsigned int>(_Val));
        return static_cast<int>(_Upper_half_count + _Lower_half_count);
#else // ^^^ _M_IX86 / !_M_IX86 vvv
        return static_cast<int>(__popcnt64(_Val));
#endif // _M_IX86
    }
}
#endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))

// True when _Ty, with const/volatile qualification removed, is one of the five unsigned integer types
// listed below (unsigned char through unsigned long long).
// NOTE(review): relies on _Is_any_of_v, presumably a type-list membership test defined elsewhere - confirm.
template <class _Ty>
constexpr bool _Is_standard_unsigned_integer =
_Is_any_of_v<remove_cv_t<_Ty>, unsigned char, unsigned short, unsigned int, unsigned long, unsigned long long>;
Expand All @@ -1103,12 +1145,24 @@ _NODISCARD constexpr int _Countr_zero(const _Ty _Val) noexcept {
if (!_STD is_constant_evaluated()) {
return _Checked_x86_x64_countr_zero(_Val);
}
#endif // defined(__cpp_lib_is_constant_evaluated)
#endif // __cpp_lib_is_constant_evaluated
#endif // defined(_M_IX86) || defined(_M_X64)
// C++17 constexpr gcd() calls this function, so it should be constexpr unless we detect runtime evaluation.
return _Countr_zero_fallback(_Val);
}

// Counts the set bits of _Val; enabled only for the standard unsigned integer types.
// On x86/x64 builds (ARM64EC excluded) where is_constant_evaluated is available, a run-time call is routed to
// _Checked_x86_x64_popcount. Every other case, including constant evaluation, uses _Popcount_fallback.
template <class _Ty, enable_if_t<_Is_standard_unsigned_integer<_Ty>, int> _Enabled = 0>
_NODISCARD constexpr int _Popcount(const _Ty _Val) noexcept {
#if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
#ifdef __cpp_lib_is_constant_evaluated
// _Checked_x86_x64_popcount is not constexpr, so it may only be reached at run time
if (!_STD is_constant_evaluated()) {
return _Checked_x86_x64_popcount(_Val);
}
#endif // __cpp_lib_is_constant_evaluated
#endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
// constexpr-friendly path, also taken when the guarded block above is compiled out
return _Popcount_fallback(_Val);
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
}

_STD_END
#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
Expand Down
92 changes: 90 additions & 2 deletions stl/inc/vector
Original file line number Diff line number Diff line change
Expand Up @@ -3123,8 +3123,8 @@ _INLINE_VAR constexpr bool _Is_vb_iterator<_Vb_iterator<_Alloc>, _RequiresMutabl
// Specialization that yields true for _Vb_const_iterator, selected only when the second template argument is false.
// NOTE(review): the second argument appears to state whether a mutable iterator is required - confirm against
// the primary template, which is not visible here.
template <class _Alloc>
_INLINE_VAR constexpr bool _Is_vb_iterator<_Vb_const_iterator<_Alloc>, false> = true;

template <class _FwdIt, class _Ty>
_CONSTEXPR20 void _Fill_vbool(_FwdIt _First, _FwdIt _Last, const _Ty& _Val) {
template <class _VbIt, class _Ty>
_CONSTEXPR20 void _Fill_vbool(_VbIt _First, _VbIt _Last, const _Ty& _Val) {
// Set [_First, _Last) to _Val
if (_First == _Last) {
return;
Expand Down Expand Up @@ -3172,6 +3172,94 @@ _CONSTEXPR20 void _Fill_vbool(_FwdIt _First, _FwdIt _Last, const _Ty& _Val) {
*_VbFirst = (*_VbFirst & _LastDestMask) | (_FillVal & _LastSourceMask);
}
}

// Word-at-a-time search over a vector<bool> iterator range.
// Returns an iterator to the first bit in [_First, _Last) that equals _Val, or _Last when no bit matches.
// Each storage word is turned into a mask of matching bits (the word itself when _Val is true, its complement
// otherwise), bits outside the range are cleared, and _Countr_zero locates the lowest remaining bit.
// The code treats a _Countr_zero result equal to _VBITS as meaning that the examined word holds no match.
template <class _VbIt, class _Ty>
_NODISCARD _CONSTEXPR20 _VbIt _Find_vbool(_VbIt _First, const _VbIt _Last, const _Ty& _Val) {
// Find _Val in [_First, _Last)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not a super helpful comment and could be removed (we generally remove these sorts of comments when we see them, as there are many in older code). It doesn't need to block this PR though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep it as we have it in all others

if (_First == _Last) {
return _First;
}

const _Vbase* _VbFirst = _First._Myptr;
const _Vbase* const _VbLast = _Last._Myptr;

// keeps bit positions at or above _First._Myoff, clearing the bits that precede the range in the first word
const auto _FirstSourceMask = static_cast<_Vbase>(-1) << _First._Myoff;
barcharcraz marked this conversation as resolved.
Show resolved Hide resolved

// both iterators point into the same word, so one masked read covers the whole range
if (_VbFirst == _VbLast) {
// We already excluded _First == _Last, so here _Last._Myoff > 0 and the shift is safe
const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
const auto _SourceMask = _FirstSourceMask & _LastSourceMask;
const auto _SelectVal = (_Val ? *_VbFirst : ~*_VbFirst) & _SourceMask;
const auto _Count = _Countr_zero(_SelectVal);
// _Count is a bit position inside the word; subtracting the starting offset gives the distance from _First
return _Count == _VBITS ? _Last : _First + static_cast<ptrdiff_t>(_Count - _First._Myoff);
}

// partial first word: only the bits from _First._Myoff upwards take part
const auto _FirstVal = (_Val ? *_VbFirst : ~*_VbFirst) & _FirstSourceMask;
miscco marked this conversation as resolved.
Show resolved Hide resolved
const auto _FirstCount = _Countr_zero(_FirstVal);
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
if (_FirstCount != _VBITS) {
return _First + static_cast<ptrdiff_t>(_FirstCount - _First._Myoff);
}
++_VbFirst;

// distance from _First to the start of the word now pointed to by _VbFirst; grows by _VBITS per skipped word
_Iter_diff_t<_VbIt> _TotalCount = static_cast<ptrdiff_t>(_VBITS - _First._Myoff);
// words strictly between the first and the last word are examined without any masking
for (; _VbFirst != _VbLast; ++_VbFirst, _TotalCount += _VBITS) {
const auto _SelectVal = _Val ? *_VbFirst : ~*_VbFirst;
const auto _Count = _Countr_zero(_SelectVal);
if (_Count != _VBITS) {
return _First + (_TotalCount + _Count);
}
}

// the last word is read only when _Last has a nonzero offset; the mask keeps the bits below _Last._Myoff
if (_Last._Myoff != 0) {
const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
const auto _LastVal = (_Val ? *_VbFirst : ~*_VbFirst) & _LastSourceMask;
const auto _Count = _Countr_zero(_LastVal);
if (_Count != _VBITS) {
return _First + (_TotalCount + _Count);
}
}

return _Last;
}

// Word-at-a-time counting over a vector<bool> iterator range.
// Returns how many bits in [_First, _Last) equal _Val.
// Each storage word is turned into a mask of matching bits (the word itself when _Val is true, its complement
// otherwise), bits outside the range are cleared, and _Popcount adds up what remains.
template <class _VbIt, class _Ty>
_NODISCARD _CONSTEXPR20 _Iter_diff_t<_VbIt> _Count_vbool(_VbIt _First, const _VbIt _Last, const _Ty& _Val) {
_Iter_diff_t<_VbIt> _Count = 0;

if (_First == _Last) {
return _Count;
}

const _Vbase* _VbFirst = _First._Myptr;
const _Vbase* const _VbLast = _Last._Myptr;

// keeps bit positions at or above _First._Myoff, clearing the bits that precede the range in the first word
const auto _FirstSourceMask = static_cast<_Vbase>(-1) << _First._Myoff;

// both iterators point into the same word, so one masked read covers the whole range
if (_VbFirst == _VbLast) {
// We already excluded _First == _Last, so here _Last._Myoff > 0 and the shift is safe
const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
const auto _SourceMask = _FirstSourceMask & _LastSourceMask;
const auto _SelectVal = (_Val ? *_VbFirst : ~*_VbFirst) & _SourceMask;
return _Popcount(_SelectVal);
}

// partial first word: only the bits from _First._Myoff upwards are counted
const auto _FirstVal = (_Val ? *_VbFirst : ~*_VbFirst) & _FirstSourceMask;
_Count += _Popcount(_FirstVal);
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
++_VbFirst;

// words strictly between the first and the last word are counted without any masking
for (; _VbFirst != _VbLast; ++_VbFirst) {
const auto _SelectVal = _Val ? *_VbFirst : ~*_VbFirst;
_Count += _Popcount(_SelectVal);
}

// the last word is read only when _Last has a nonzero offset; the mask keeps the bits below _Last._Myoff
if (_Last._Myoff != 0) {
const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
const auto _LastVal = (_Val ? *_VbFirst : ~*_VbFirst) & _LastSourceMask;
_Count += _Popcount(_LastVal);
}

return _Count;
}
_STD_END

#pragma pop_macro("new")
Expand Down
28 changes: 18 additions & 10 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -5345,8 +5345,12 @@ _NODISCARD _CONSTEXPR20 _InIt _Find_unchecked(const _InIt _First, const _InIt _L
// Returns an iterator to the first element in [_First, _Last) that compares equal to _Val.
// The range is validated with _Adl_verify_range before any element is read.
template <class _InIt, class _Ty>
_NODISCARD _CONSTEXPR20 _InIt find(_InIt _First, const _InIt _Last, const _Ty& _Val) { // find first matching _Val
    _Adl_verify_range(_First, _Last);
    if constexpr (_Is_vb_iterator<_InIt> && is_same_v<_Ty, bool>) {
        // vector<bool> iterators searched for a bool are scanned a storage word at a time
        return _Find_vbool(_First, _Last, _Val);
    } else {
        // generic path: search the unwrapped range, then move _First to the position that was found
        _Seek_wrapped(_First, _Find_unchecked(_Get_unwrapped(_First), _Get_unwrapped(_Last), _Val));
        return _First;
    }
}

#if _HAS_CXX17
Expand Down Expand Up @@ -5432,17 +5436,21 @@ template <class _InIt, class _Ty>
_NODISCARD _CONSTEXPR20 _Iter_diff_t<_InIt> count(const _InIt _First, const _InIt _Last, const _Ty& _Val) {
    // count elements that match _Val
    // The range is validated with _Adl_verify_range before any element is read.
    _Adl_verify_range(_First, _Last);
    if constexpr (_Is_vb_iterator<_InIt> && is_same_v<_Ty, bool>) {
        // vector<bool> iterators counted against a bool are processed a storage word at a time
        return _Count_vbool(_First, _Last, _Val);
    } else {
        // generic path: walk the unwrapped range once and tally every element equal to _Val
        auto _UFirst               = _Get_unwrapped(_First);
        const auto _ULast          = _Get_unwrapped(_Last);
        _Iter_diff_t<_InIt> _Count = 0;

        for (; _UFirst != _ULast; ++_UFirst) {
            if (*_UFirst == _Val) {
                ++_Count;
            }
        }

        return _Count;
    }
}

#if _HAS_CXX17
Expand Down
Loading