Skip to content

Commit

Permalink
Vectorize std::search of 1 and 2 bytes elements with pcmpestri (#…
Browse files Browse the repository at this point in the history
…4745)

Co-authored-by: Stephan T. Lavavej <[email protected]>
  • Loading branch information
AlexGuteniev and StephanTLavavej authored Sep 9, 2024
1 parent c7c5ca7 commit e931261
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 14 deletions.
46 changes: 32 additions & 14 deletions benchmarks/src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <array>
#include <benchmark/benchmark.h>
#include <cstdint>
#include <cstring>
#include <functional>
#include <string>
#include <string_view>
#include <vector>
using namespace std::string_view_literals;

const char src_haystack[] =
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nullam mollis imperdiet massa, at dapibus elit interdum "
Expand Down Expand Up @@ -40,9 +43,14 @@ const char src_haystack[] =
"euismod eros, ut posuere ligula ullamcorper id. Nullam aliquet malesuada est at dignissim. Pellentesque finibus "
"sagittis libero nec bibendum. Phasellus dolor ipsum, finibus quis turpis quis, mollis interdum felis.";

const char src_needle[] = "aliquet";
constexpr std::array patterns = {
"aliquet"sv,
"aliquet malesuada"sv,
};

void c_strstr(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];

const std::string haystack(std::begin(src_haystack), std::end(src_haystack));
const std::string needle(std::begin(src_needle), std::end(src_needle));

Expand All @@ -56,6 +64,8 @@ void c_strstr(benchmark::State& state) {

template <class T>
void classic_search(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];

const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));

Expand All @@ -69,6 +79,8 @@ void classic_search(benchmark::State& state) {

template <class T>
void ranges_search(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];

const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));

Expand All @@ -82,6 +94,8 @@ void ranges_search(benchmark::State& state) {

template <class T>
void search_default_searcher(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];

const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));

Expand All @@ -93,22 +107,26 @@ void search_default_searcher(benchmark::State& state) {
}
}

BENCHMARK(c_strstr);
void common_args(auto bm) {
bm->Range(0, patterns.size() - 1);
}

BENCHMARK(c_strstr)->Apply(common_args);

BENCHMARK(classic_search<std::uint8_t>);
BENCHMARK(classic_search<std::uint16_t>);
BENCHMARK(classic_search<std::uint32_t>);
BENCHMARK(classic_search<std::uint64_t>);
BENCHMARK(classic_search<std::uint8_t>)->Apply(common_args);
BENCHMARK(classic_search<std::uint16_t>)->Apply(common_args);
BENCHMARK(classic_search<std::uint32_t>)->Apply(common_args);
BENCHMARK(classic_search<std::uint64_t>)->Apply(common_args);

BENCHMARK(ranges_search<std::uint8_t>);
BENCHMARK(ranges_search<std::uint16_t>);
BENCHMARK(ranges_search<std::uint32_t>);
BENCHMARK(ranges_search<std::uint64_t>);
BENCHMARK(ranges_search<std::uint8_t>)->Apply(common_args);
BENCHMARK(ranges_search<std::uint16_t>)->Apply(common_args);
BENCHMARK(ranges_search<std::uint32_t>)->Apply(common_args);
BENCHMARK(ranges_search<std::uint64_t>)->Apply(common_args);

BENCHMARK(search_default_searcher<std::uint8_t>);
BENCHMARK(search_default_searcher<std::uint16_t>);
BENCHMARK(search_default_searcher<std::uint32_t>);
BENCHMARK(search_default_searcher<std::uint64_t>);
BENCHMARK(search_default_searcher<std::uint8_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint16_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint32_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint64_t>)->Apply(common_args);


BENCHMARK_MAIN();
20 changes: 20 additions & 0 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,26 @@ _NODISCARD _CONSTEXPR20 _FwdItHaystack search(_FwdItHaystack _First1, _FwdItHays
if constexpr (_Is_ranges_random_iter_v<_FwdItHaystack> && _Is_ranges_random_iter_v<_FwdItPat>) {
const _Iter_diff_t<_FwdItPat> _Count2 = _ULast2 - _UFirst2;
if (_ULast1 - _UFirst1 >= _Count2) {
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
if (!_STD _Is_constant_evaluated()) {
const auto _Ptr1 = _STD _To_address(_UFirst1);

const auto _Ptr_res1 = _STD _Search_vectorized(
_Ptr1, _STD _To_address(_ULast1), _STD _To_address(_UFirst2), static_cast<size_t>(_Count2));

if constexpr (is_pointer_v<decltype(_UFirst1)>) {
_UFirst1 = _Ptr_res1;
} else {
_UFirst1 += _Ptr_res1 - _Ptr1;
}

_STD _Seek_wrapped(_Last1, _UFirst1);
return _Last1;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

const auto _Last_possible = _ULast1 - static_cast<_Iter_diff_t<_FwdItHaystack>>(_Count2);
for (;; ++_UFirst1) {
if (_STD _Equal_rev_pred_unchecked(_UFirst1, _UFirst2, _ULast2, _STD _Pass_fn(_Pred))) {
Expand Down
23 changes: 23 additions & 0 deletions stl/inc/functional
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,29 @@ _CONSTEXPR20 pair<_FwdItHaystack, _FwdItHaystack> _Search_pair_unchecked(
_Iter_diff_t<_FwdItHaystack> _Count1 = _Last1 - _First1;
_Iter_diff_t<_FwdItPat> _Count2 = _Last2 - _First2;

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_search_is_safe<_FwdItHaystack, _FwdItPat, _Pred_eq>) {
if (!_STD _Is_constant_evaluated()) {
const auto _Ptr1 = _STD _To_address(_First1);

const auto _Ptr_res1 = _STD _Search_vectorized(
_Ptr1, _STD _To_address(_Last1), _STD _To_address(_First2), static_cast<size_t>(_Count2));

if constexpr (is_pointer_v<_FwdItHaystack>) {
_First1 = _Ptr_res1;
} else {
_First1 += _Ptr_res1 - _Ptr1;
}

if (_First1 != _Last1) {
return {_First1, _First1 + static_cast<_Iter_diff_t<_FwdItHaystack>>(_Count2)};
} else {
return {_Last1, _Last1};
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (; _Count2 <= _Count1; ++_First1, (void) --_Count1) { // room for match, try it
_FwdItHaystack _Mid1 = _First1;
for (_FwdItPat _Mid2 = _First2;; ++_Mid1, (void) ++_Mid2) {
Expand Down
57 changes: 57 additions & 0 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ const void* __stdcall __std_find_first_of_trivial_4(
const void* __stdcall __std_find_first_of_trivial_8(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;

const void* __stdcall __std_search_1(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;
const void* __stdcall __std_search_2(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;

const void* __stdcall __std_min_element_1(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
Expand Down Expand Up @@ -231,6 +236,18 @@ _Ty1* _Find_first_of_vectorized(
}
}

template <class _Ty1, class _Ty2>
_Ty1* _Search_vectorized(_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, const size_t _Count2) noexcept {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));
if constexpr (sizeof(_Ty1) == 1) {
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_search_1(_First1, _Last1, _First2, _Count2)));
} else if constexpr (sizeof(_Ty1) == 2) {
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_search_2(_First1, _Last1, _First2, _Count2)));
} else {
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}

template <class _Ty>
_Ty* _Min_element_vectorized(_Ty* const _First, _Ty* const _Last) noexcept {
constexpr bool _Signed = is_signed_v<_Ty>;
Expand Down Expand Up @@ -5411,6 +5428,11 @@ template <class _Iter1, class _Iter2, class _Pr>
constexpr bool _Equal_memcmp_is_safe =
_Equal_memcmp_is_safe_helper<remove_const_t<_Iter1>, remove_const_t<_Iter2>, remove_const_t<_Pr>>;

// Can we activate the vector algorithms for std::search?
template <class _It1, class _It2, class _Pr>
constexpr bool _Vector_alg_in_search_is_safe = _Equal_memcmp_is_safe<_It1, _It2, _Pr> // can search bitwise
&& sizeof(_Iter_value_t<_It1>) <= 2; // pcmpestri compatible element size

template <class _CtgIt1, class _CtgIt2>
_NODISCARD int _Memcmp_count(_CtgIt1 _First1, _CtgIt2 _First2, const size_t _Count) {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Iter_value_t<_CtgIt1>) == sizeof(_Iter_value_t<_CtgIt2>));
Expand Down Expand Up @@ -6788,6 +6810,41 @@ namespace ranges {
_STL_INTERNAL_CHECK(_RANGES distance(_First1, _Last1) == _Count1);
_STL_INTERNAL_CHECK(_RANGES distance(_First2, _Last2) == _Count2);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_search_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
&& is_same_v<_Pj2, identity>) {
if (!_STD is_constant_evaluated()) {
const auto _Ptr1 = _STD to_address(_First1);
const auto _Ptr2 = _STD to_address(_First2);
remove_const_t<decltype(_Ptr1)> _Ptr_last1;

if constexpr (is_same_v<_It1, _Se1>) {
_Ptr_last1 = _STD to_address(_Last1);
} else {
_Ptr_last1 = _Ptr1 + _Count1;
}

const auto _Ptr_res1 =
_STD _Search_vectorized(_Ptr1, _Ptr_last1, _Ptr2, static_cast<size_t>(_Count2));

if constexpr (is_pointer_v<_It1>) {
if (_Ptr_res1 != _Ptr_last1) {
return {_Ptr_res1, _Ptr_res1 + _Count2};
} else {
return {_Ptr_res1, _Ptr_res1};
}
} else {
_First1 += _Ptr_res1 - _Ptr1;
if (_First1 != _Last1) {
return {_First1, _First1 + static_cast<iter_difference_t<_It1>>(_Count2)};
} else {
return {_First1, _First1};
}
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (; _Count1 >= _Count2; ++_First1, (void) --_Count1) {
auto _Match_and_mid1 = _RANGES _Equal_rev_pred(_First1, _First2, _Last2, _Pred, _Proj1, _Proj2);
if (_Match_and_mid1.first) {
Expand Down
Loading

0 comments on commit e931261

Please sign in to comment.