Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize ranges::find_last #3925

Merged
merged 14 commits into from
Oct 20, 2023
43 changes: 43 additions & 0 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ _Min_max_element_t __stdcall __std_minmax_element_1(const void* _First, const vo
_Min_max_element_t __stdcall __std_minmax_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_8(const void* _First, const void* _Last, bool _Signed) noexcept;

const void* __stdcall __std_find_last_trivial_1(const void* _First, const void* _Last, uint8_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_2(const void* _First, const void* _Last, uint16_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;
_END_EXTERN_C

template <class _Ty>
Expand All @@ -76,6 +81,27 @@ _STD pair<_Ty*, _Ty*> __std_minmax_element(_Ty* _First, _Ty* _Last) noexcept {

return {const_cast<_Ty*>(static_cast<const _Ty*>(_Res._Min)), const_cast<_Ty*>(static_cast<const _Ty*>(_Res._Max))};
}

template <class _Ty, class _TVal>
_Ty* __std_find_last_trivial(_Ty* _First, _Ty* _Last, const _TVal _Val) noexcept {
if constexpr (_STD is_pointer_v<_TVal> || _STD is_null_pointer_v<_TVal>) {
return __std_find_last_trivial(_First, _Last, reinterpret_cast<uintptr_t>(_Val));
} else if constexpr (sizeof(_Ty) == 1) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(__std_find_last_trivial_1(_First, _Last, static_cast<uint8_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 2) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(__std_find_last_trivial_2(_First, _Last, static_cast<uint16_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 4) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(__std_find_last_trivial_4(_First, _Last, static_cast<uint32_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 8) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(__std_find_last_trivial_8(_First, _Last, static_cast<uint64_t>(_Val))));
} else {
static_assert(_STD _Always_false<_Ty>, "Unexpected size");
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

_STD_BEGIN
Expand Down Expand Up @@ -2834,6 +2860,23 @@ namespace ranges {
_STL_INTERNAL_STATIC_ASSERT(sentinel_for<_Se, _It>);
_STL_INTERNAL_STATIC_ASSERT(indirect_binary_predicate<ranges::equal_to, projected<_It, _Pj>, const _Ty*>);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (is_same_v<_Pj, identity> && _Vector_alg_in_find_is_safe<_It, _Ty>
&& sized_sentinel_for<_Se, _It>) {
if (!_STD is_constant_evaluated()) {
const auto _First_ptr = _To_address(_First);
const auto _Last_ptr = _First_ptr + (_Last - _First);

const auto _Result = __std_find_last_trivial(_First_ptr, _Last_ptr, _Value);
if constexpr (is_pointer_v<_It>) {
return {_Result, _STD move(_Last)};
} else {
return {_STD move(_First) + (_Result - _First_ptr), _STD move(_Last)};
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

if constexpr (_Bidi_common<_It, _Se>) {
for (auto _Result = _Last; _Result != _First;) {
if (_STD invoke(_Proj, *--_Result) == _Value) {
Expand Down
100 changes: 98 additions & 2 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,21 @@ namespace {
return static_cast<const unsigned char*>(_Last) - static_cast<const unsigned char*>(_First);
}

void _Advance_bytes(void*& _Target, ptrdiff_t _Offset) noexcept {
void _Rewind_bytes(void*& _Target, size_t _Offset) noexcept {
_Target = static_cast<unsigned char*>(_Target) - _Offset;
}

void _Rewind_bytes(const void*& _Target, size_t _Offset) noexcept {
_Target = static_cast<const unsigned char*>(_Target) - _Offset;
}

template <class _Integral>
void _Advance_bytes(void*& _Target, _Integral _Offset) noexcept {
_Target = static_cast<unsigned char*>(_Target) + _Offset;
}

void _Advance_bytes(const void*& _Target, ptrdiff_t _Offset) noexcept {
template <class _Integral>
void _Advance_bytes(const void*& _Target, _Integral _Offset) noexcept {
_Target = static_cast<const unsigned char*>(_Target) + _Offset;
}
} // unnamed namespace
Expand Down Expand Up @@ -1143,6 +1153,20 @@ namespace {
return _Ptr;
}

template <class _Ty>
const void* _Find_trivial_last_tail(const void* _First, const void* _Last, const void* _Real_last, _Ty _Val) {
auto _Ptr = static_cast<const _Ty*>(_Last);
for (;;) {
if (_Ptr == _First) {
return _Real_last;
}
--_Ptr;
if (*_Ptr == _Val) {
return _Ptr;
}
}
}

template <class _Ty>
__declspec(noalias) size_t _Count_trivial_tail(const void* _First, const void* _Last, size_t _Current, _Ty _Val) {
auto _Ptr = static_cast<const _Ty*>(_First);
Expand Down Expand Up @@ -1396,6 +1420,58 @@ namespace {
return _Find_trivial_tail(_First, _Last, _Val);
}

template <class _Traits, class _Ty>
const void* __stdcall __std_find_last_trivial(const void* _First, const void* _Last, _Ty _Val) noexcept {
const void* const _Real_last = _Last;
#ifndef _M_ARM64EC
size_t _Size_bytes = _Byte_length(_First, _Last);

const size_t _Avx_size = _Size_bytes & ~size_t{0x1F};
if (_Avx_size != 0 && _Use_avx2()) {
_Zeroupper_on_exit _Guard; // TRANSITION, DevCom-10331414

const __m256i _Comparand = _Traits::_Set_avx(_Val);
const void* _Stop_at = _Last;
_Rewind_bytes(_Stop_at, _Avx_size);
do {
_Rewind_bytes(_Last, 32);
const __m256i _Data = _mm256_loadu_si256(static_cast<const __m256i*>(_Last));
const int _Bingo = _mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand));

if (_Bingo != 0) {
const unsigned long _Offset = _lzcnt_u32(_Bingo);
_Advance_bytes(_Last, (31 - _Offset) - (sizeof(_Ty) - 1));
return _Last;
}

} while (_Last != _Stop_at);
_Size_bytes &= 0x1F;
}

const size_t _Sse_size = _Size_bytes & ~size_t{0xF};
if (_Sse_size != 0 && _Traits::_Sse_available()) {
const __m128i _Comparand = _Traits::_Set_sse(_Val);
const void* _Stop_at = _Last;
_Rewind_bytes(_Stop_at, _Sse_size);
do {
_Rewind_bytes(_Last, 16);
const __m128i _Data = _mm_loadu_si128(static_cast<const __m128i*>(_Last));
const int _Bingo = _mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand));

if (_Bingo != 0) {
unsigned long _Offset;
_BitScanReverse(&_Offset, _Bingo); // lgtm [cpp/conditionallyuninitializedvariable]
_Advance_bytes(_Last, _Offset - (sizeof(_Ty) - 1));
return _Last;
}

} while (_Last != _Stop_at);
}
#endif // !_M_ARM64EC

return _Find_trivial_last_tail(_First, _Last, _Real_last, _Val);
}

template <class _Traits, class _Ty>
__declspec(noalias) size_t
__stdcall __std_count_trivial(const void* _First, const void* const _Last, const _Ty _Val) noexcept {
Expand Down Expand Up @@ -1476,6 +1552,26 @@ const void* __stdcall __std_find_trivial_8(
return __std_find_trivial<_Find_traits_8>(_First, _Last, _Val);
}

const void* __stdcall __std_find_last_trivial_1(
const void* const _First, const void* const _Last, const uint8_t _Val) noexcept {
return __std_find_last_trivial<_Find_traits_1>(_First, _Last, _Val);
}

const void* __stdcall __std_find_last_trivial_2(
const void* const _First, const void* const _Last, const uint16_t _Val) noexcept {
return __std_find_last_trivial<_Find_traits_2>(_First, _Last, _Val);
}

const void* __stdcall __std_find_last_trivial_4(
const void* const _First, const void* const _Last, const uint32_t _Val) noexcept {
return __std_find_last_trivial<_Find_traits_4>(_First, _Last, _Val);
}

const void* __stdcall __std_find_last_trivial_8(
const void* const _First, const void* const _Last, const uint64_t _Val) noexcept {
return __std_find_last_trivial<_Find_traits_8>(_First, _Last, _Val);
}

__declspec(noalias) size_t
__stdcall __std_count_trivial_1(const void* const _First, const void* const _Last, const uint8_t _Val) noexcept {
return __std_count_trivial<_Find_traits_1>(_First, _Last, _Val);
Expand Down
119 changes: 119 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@ inline auto last_known_good_find(FwdIt first, FwdIt last, T v) {
return first;
}

template <class FwdIt, class T>
inline auto last_known_good_find_last(FwdIt first, FwdIt last, T v) {
FwdIt last_save = last;
for (;;) {
if (last == first) {
return last_save;
}
--last;
if (*last == v) {
return last;
}
}
}

template <class T>
void test_case_find(const vector<T>& input, T v) {
auto expected = last_known_good_find(input.begin(), input.end(), v);
Expand All @@ -123,6 +137,30 @@ void test_find(mt19937_64& gen) {
}
}

#if _HAS_CXX23 && defined(__cpp_lib_concepts)
template <class T>
void test_case_find_last(const vector<T>& input, T v) {
auto expected = last_known_good_find_last(input.begin(), input.end(), v);
auto range = ranges::find_last(input.begin(), input.end(), v);
auto actual = range.begin();
assert(expected == actual);
assert(range.end() == input.end());
}

template <class T>
void test_find_last(mt19937_64& gen) {
using TD = conditional_t<sizeof(T) == 1, int, T>;
binomial_distribution<TD> dis(10);
vector<T> input;
input.reserve(dataCount);
test_case_find_last(input, static_cast<T>(dis(gen)));
for (size_t attempts = 0; attempts < dataCount; ++attempts) {
input.push_back(static_cast<T>(dis(gen)));
test_case_find_last(input, static_cast<T>(dis(gen)));
}
}
#endif // _HAS_CXX23 && defined(__cpp_lib_concepts)

template <class T>
void test_min_max_element(mt19937_64& gen) {
using Limits = numeric_limits<T>;
Expand Down Expand Up @@ -305,6 +343,18 @@ void test_vector_algorithms(mt19937_64& gen) {
test_find<long long>(gen);
test_find<unsigned long long>(gen);

#if _HAS_CXX23 && defined(__cpp_lib_concepts)
test_find_last<char>(gen);
test_find_last<signed char>(gen);
test_find_last<unsigned char>(gen);
test_find_last<short>(gen);
test_find_last<unsigned short>(gen);
test_find_last<int>(gen);
test_find_last<unsigned int>(gen);
test_find_last<long long>(gen);
test_find_last<unsigned long long>(gen);
#endif // _HAS_CXX23 && defined(__cpp_lib_concepts)

test_min_max_element<char>(gen);
test_min_max_element<signed char>(gen);
test_min_max_element<unsigned char>(gen);
Expand Down Expand Up @@ -399,7 +449,76 @@ void test_various_containers() {
test_one_container<list<int>>(); // bidi, not vectorizable
}

#if _HAS_CXX20
constexpr bool test_constexpr() {
const int a[] = {20, 10, 30, 30, 30, 30, 40, 60, 50};

assert(count(begin(a), end(a), 30) == 4);
#ifdef __cpp_lib_concepts
assert(ranges::count(a, 30) == 4);
#endif // defined(__cpp_lib_concepts)

assert(find(begin(a), end(a), 30) == begin(a) + 2);
#ifdef __cpp_lib_concepts
assert(ranges::find(a, 30) == begin(a) + 2);
#endif // defined(__cpp_lib_concepts)

#if defined(__cpp_lib_concepts) && _HAS_CXX23
assert(begin(ranges::find_last(a, 30)) == begin(a) + 5);
assert(end(ranges::find_last(a, 30)) == end(a));
#endif // defined(__cpp_lib_concepts) && _HAS_CXX23

assert(min_element(begin(a), end(a)) == begin(a) + 1);
assert(max_element(begin(a), end(a)) == end(a) - 2);
assert(get<0>(minmax_element(begin(a), end(a))) == begin(a) + 1);
assert(get<1>(minmax_element(begin(a), end(a))) == end(a) - 2);

#ifdef __cpp_lib_concepts
assert(ranges::min_element(a) == begin(a) + 1);
assert(ranges::max_element(a) == end(a) - 2);
assert(ranges::minmax_element(a).min == begin(a) + 1);
assert(ranges::minmax_element(a).max == end(a) - 2);
#endif // defined(__cpp_lib_concepts)

int b[size(a)];
reverse_copy(begin(a), end(a), begin(b));
assert(equal(rbegin(a), rend(a), begin(b)));

int c[size(a)];
#ifdef __cpp_lib_concepts
ranges::reverse_copy(a, c);
assert(equal(rbegin(a), rend(a), begin(c)));
#else // ^^^ defined(__cpp_lib_concepts) / !defined(__cpp_lib_concepts) vvv
reverse_copy(begin(a), end(a), begin(c)); // for swap_ranges test below
#endif // ^^^ !defined(__cpp_lib_concepts) ^^^

reverse(begin(b), end(b));
assert(equal(begin(a), end(a), begin(b)));

swap_ranges(begin(b), end(b), begin(c));
assert(equal(rbegin(a), rend(a), begin(b)));
assert(equal(begin(a), end(a), begin(c)));

#ifdef __cpp_lib_concepts
ranges::swap_ranges(b, c);
assert(equal(begin(a), end(a), begin(b)));
assert(equal(rbegin(a), rend(a), begin(c)));

ranges::reverse(c);
assert(equal(begin(a), end(a), begin(c)));
#endif // defined(__cpp_lib_concepts)

return true;
}

static_assert(test_constexpr());
#endif // _HAS_CXX20

int main() {
#if _HAS_CXX20
assert(test_constexpr());
#endif // _HAS_CXX20

mt19937_64 gen;
initialize_randomness(gen);

Expand Down