Skip to content

Commit

Permalink
[vector] Add specialization of find
Browse files Browse the repository at this point in the history
Adresses #625
  • Loading branch information
miscco committed Nov 28, 2020
1 parent 19c683d commit d88f306
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 2 deletions.
49 changes: 49 additions & 0 deletions stl/inc/vector
Original file line number Diff line number Diff line change
Expand Up @@ -2949,6 +2949,55 @@ _CONSTEXPR20 void _Fill_vbool(_FwdIt _First, _FwdIt _Last, const _Ty& _Val) {
*_VbFirst = (*_VbFirst & _LastDestMask) | (_FillVal & _LastSourceMask);
}
}

template <class _InIt, class _Ty>
_NODISCARD _CONSTEXPR20 _InIt _Find_vbool(_InIt _First, const _InIt _Last, const _Ty& _Val) {
// Find _Val in [_First, _Last)
if (_First == _Last) {
return _First;
}

_Vbase* _VbFirst = const_cast<_Vbase*>(_First._Myptr);
const _Vbase* const _VbLast = _Last._Myptr;

const auto _FirstSourceMask = static_cast<_Vbase>(-1) << _First._Myoff;

if (_VbFirst == _VbLast) {
// We already excluded _First == _Last, so here _Last._Myoff > 0 and the shift is safe
const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
const auto _SourceMask = _FirstSourceMask & _LastSourceMask;
const auto _SelectVal = (_Val ? *_VbFirst : ~*_VbFirst) & _SourceMask;
const auto _Count = _Countr_zero(_SelectVal);
return _Count == _VBITS ? _Last : _First + static_cast<ptrdiff_t>(_Count - _First._Myoff);
}

const auto _FirstVal = (_Val ? *_VbFirst : ~*_VbFirst) & _FirstSourceMask;
const auto _FirstCount = _Countr_zero(_FirstVal);
if (_FirstCount != _VBITS) {
return _First + static_cast<ptrdiff_t>(_FirstCount - _First._Myoff);
}
++_VbFirst;

_Iter_diff_t<_InIt> _TotalCount = static_cast<ptrdiff_t>(_VBITS - _First._Myoff);
for (; _VbFirst != _VbLast; ++_VbFirst, _TotalCount += _VBITS) {
const auto _SelectVal = _Val ? *_VbFirst : ~*_VbFirst;
const auto _Count = _Countr_zero(_SelectVal);
if (_Count != _VBITS) {
return _First + (_TotalCount + _Count);
}
}

if (_Last._Myoff != 0) {
const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
const auto _LastVal = (_Val ? *_VbFirst : ~*_VbFirst) & _LastSourceMask;
const auto _Count = _Countr_zero(_LastVal);
if (_Count != _VBITS) {
return _First + (_TotalCount + _Count);
}
}

return _Last;
}
#endif // _HAS_IF_CONSTEXPR
_STD_END

Expand Down
12 changes: 10 additions & 2 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -5537,8 +5537,16 @@ _NODISCARD _CONSTEXPR20 _InIt _Find_unchecked(const _InIt _First, const _InIt _L
template <class _InIt, class _Ty>
_NODISCARD _CONSTEXPR20 _InIt find(_InIt _First, const _InIt _Last, const _Ty& _Val) { // find first matching _Val
_Adl_verify_range(_First, _Last);
_Seek_wrapped(_First, _Find_unchecked(_Get_unwrapped(_First), _Get_unwrapped(_Last), _Val));
return _First;
#if _HAS_IF_CONSTEXPR
if constexpr (_Is_vb_iterator<_InIt> && is_same_v<_Ty, bool>) {
return _Find_vbool(_First, _Last, _Val);
} else {
#endif // _HAS_IF_CONSTEXPR
_Seek_wrapped(_First, _Find_unchecked(_Get_unwrapped(_First), _Get_unwrapped(_Last), _Val));
return _First;
#if _HAS_IF_CONSTEXPR
}
#endif // _HAS_IF_CONSTEXPR
}

#if _HAS_CXX17
Expand Down
53 changes: 53 additions & 0 deletions tests/std/tests/GH_000625_vector_bool_optimization/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,59 @@ bool test_fill() {
return true;
}

void test_find_helper(const size_t length) {
// No offset
{
vector<bool> input_true(length + 3, false);
input_true.resize(length + 6, true);
input_true[length + 1].flip();
const auto result_true = find(input_true.cbegin(), prev(input_true.cend(), 3), true);
assert(result_true == next(input_true.cbegin(), static_cast<ptrdiff_t>(length + 1)));

vector<bool> input_false(length + 3, true);
input_false.resize(length + 6, false);
input_false[length + 1].flip();
const auto result_false = find(input_false.cbegin(), prev(input_false.cend(), 3), false);
assert(result_false == next(input_false.cbegin(), static_cast<ptrdiff_t>(length + 1)));
}

// With offset
{
vector<bool> input_true(length + 3, false);
input_true.resize(length + 6, true);
input_true[length + 1].flip();
input_true[0].flip();
const auto result_true = find(next(input_true.cbegin()), prev(input_true.cend(), 3), true);
assert(result_true == next(input_true.cbegin(), static_cast<ptrdiff_t>(length + 1)));

vector<bool> input_false(length + 3, true);
input_false.resize(length + 6, false);
input_false[length + 1].flip();
input_false[0].flip();
const auto result_false = find(next(input_false.cbegin()), prev(input_false.cend(), 3), false);
assert(result_false == next(input_false.cbegin(), static_cast<ptrdiff_t>(length + 1)));
}
}

bool test_find() {
// Empty range
test_find_helper(0);

// One block, ends within block
test_find_helper(15);

// One block, ends at block boundary
test_find_helper(blockSize);

// Multiple blocks, within block
test_find_helper(3 * blockSize + 5);

// Multiple blocks, ends at block boundary
test_find_helper(4 * blockSize);
return true;
}

int main() {
test_fill();
test_find();
}

0 comments on commit d88f306

Please sign in to comment.