Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize lexicographical_compare! #4552

Merged
27 changes: 21 additions & 6 deletions benchmarks/src/mismatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ using namespace std;

constexpr int64_t no_pos = -1;

template <class T>
// Selects which algorithm a bm<T, Op> instantiation measures.
enum class op {
    mismatch, // benchmark ranges::mismatch(a, b)
    lexi, // benchmark ranges::lexicographical_compare(a, b)
};

template <class T, op Op>
void bm(benchmark::State& state) {
vector<T> a(static_cast<size_t>(state.range(0)), T{'.'});
vector<T> b(static_cast<size_t>(state.range(0)), T{'.'});
Expand All @@ -22,15 +27,25 @@ void bm(benchmark::State& state) {
}

for (auto _ : state) {
benchmark::DoNotOptimize(ranges::mismatch(a, b));
if constexpr (Op == op::mismatch) {
benchmark::DoNotOptimize(ranges::mismatch(a, b));
} else if constexpr (Op == op::lexi) {
benchmark::DoNotOptimize(ranges::lexicographical_compare(a, b));
}
}
}

#define COMMON_ARGS Args({8, 3})->Args({24, 22})->Args({105, -1})->Args({4021, 3056})

BENCHMARK(bm<uint8_t>)->COMMON_ARGS;
BENCHMARK(bm<uint16_t>)->COMMON_ARGS;
BENCHMARK(bm<uint32_t>)->COMMON_ARGS;
BENCHMARK(bm<uint64_t>)->COMMON_ARGS;
BENCHMARK(bm<uint8_t, op::mismatch>)->COMMON_ARGS;
BENCHMARK(bm<uint16_t, op::mismatch>)->COMMON_ARGS;
BENCHMARK(bm<uint32_t, op::mismatch>)->COMMON_ARGS;
BENCHMARK(bm<uint64_t, op::mismatch>)->COMMON_ARGS;

BENCHMARK(bm<uint8_t, op::lexi>)->COMMON_ARGS;
BENCHMARK(bm<int8_t, op::lexi>)->COMMON_ARGS;
BENCHMARK(bm<uint16_t, op::lexi>)->COMMON_ARGS;
BENCHMARK(bm<uint32_t, op::lexi>)->COMMON_ARGS;
BENCHMARK(bm<uint64_t, op::lexi>)->COMMON_ARGS;

BENCHMARK_MAIN();
25 changes: 22 additions & 3 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -10868,7 +10868,8 @@ namespace ranges {
using _Memcmp_classification_pred = _Lex_compare_memcmp_classify<_It1, _It2, _Pr>;
constexpr bool _Is_sized1 = sized_sentinel_for<_Se1, _It1>;
constexpr bool _Is_sized2 = sized_sentinel_for<_Se2, _It2>;
if constexpr (!is_void_v<_Memcmp_classification_pred> && _Sized_or_unreachable_sentinel_for<_Se1, _It1>
if constexpr (!is_void_v<typename _Memcmp_classification_pred::_Pred>
&& _Sized_or_unreachable_sentinel_for<_Se1, _It1>
&& _Sized_or_unreachable_sentinel_for<_Se2, _It2> && same_as<_Pj1, identity>
&& same_as<_Pj2, identity> && (_Is_sized1 || _Is_sized2)) {
if (!_STD is_constant_evaluated()) {
Expand All @@ -10886,8 +10887,26 @@ namespace ranges {
_Num2 = SIZE_MAX;
}

const int _Ans = _STD _Memcmp_count(_First1, _First2, (_STD min)(_Num1, _Num2));
return _Memcmp_classification_pred{}(_Ans, 0) || (_Ans == 0 && _Num1 < _Num2);
const size_t _Num = (_STD min)(_Num1, _Num2);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Memcmp_classification_pred::_Opt == _Lex_cmp_opt::_Mismatch) {
const auto _First1_ptr = _STD to_address(_First1);
const auto _First2_ptr = _STD to_address(_First2);
const size_t _Pos = __std_mismatch<sizeof(*_First1_ptr)>(_First1_ptr, _First2_ptr, _Num);
if (_Pos == _Num2) {
return false;
} else if (_Pos == _Num1) {
return true;
} else {
return _STD invoke(_Pred, _First1_ptr[_Pos], _First2_ptr[_Pos]);
}
} else
#endif // _USE_STD_VECTOR_ALGORITHMS
{
const int _Ans = _STD _Memcmp_count(_First1, _First2, _Num);
return typename _Memcmp_classification_pred::_Pred{}(_Ans, 0) || (_Ans == 0 && _Num1 < _Num2);
}
}
}

Expand Down
5 changes: 3 additions & 2 deletions stl/inc/regex
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,9 @@ _INLINE_VAR constexpr bool _Can_memcmp_elements_with_pred<_Elem, _Elem, _Char_tr
// TRANSITION: This should not be activated for user-defined specializations of char_traits
template <class _Elem>
struct _Lex_compare_memcmp_classify_pred<_Elem, _Elem, _Char_traits_lt<char_traits<_Elem>>> {
using _UElem = make_unsigned_t<_Elem>;
using _Pred = conditional_t<_Lex_compare_memcmp_classify_elements<_UElem, _UElem>, less<int>, void>;
using _UElem = make_unsigned_t<_Elem>;
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_UElem, _UElem>;
using _Pred = conditional_t<_Opt != _Lex_cmp_opt::_None, less<int>, void>;
};

template <class _RxTraits>
Expand Down
137 changes: 102 additions & 35 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -5608,13 +5608,31 @@ namespace ranges {
} // namespace ranges
#endif // _HAS_CXX20

// Optimization strategy chosen for a lexicographical comparison of two element types.
enum class _Lex_cmp_opt {
    _None, // no optimized path; the classification's _Pred/_Comp is void when this is selected
    _Memcmp, // compare the common prefix with _Memcmp_count and interpret the int result
#if _USE_STD_VECTOR_ALGORITHMS
    _Mismatch, // find the first differing index with __std_mismatch, then compare that element pair
#endif // _USE_STD_VECTOR_ALGORITHMS
};

template <class _Elem1, class _Elem2>
_INLINE_VAR constexpr bool _Lex_compare_memcmp_classify_elements = conjunction_v<_Is_character_or_bool<_Elem1>,
_Is_character_or_bool<_Elem2>, is_unsigned<_Elem1>, is_unsigned<_Elem2>>;
_INLINE_VAR constexpr _Lex_cmp_opt _Lex_compare_memcmp_classify_elements =
conjunction_v<_Is_character_or_bool<_Elem1>, _Is_character_or_bool<_Elem2>, is_unsigned<_Elem1>,
is_unsigned<_Elem2>>
? _Lex_cmp_opt::_Memcmp
#if _USE_STD_VECTOR_ALGORITHMS
: ((is_integral_v<_Elem1> && is_integral_v<_Elem2> && sizeof(_Elem1) == sizeof(_Elem2)
&& is_unsigned_v<_Elem1> == is_unsigned_v<_Elem2>)
? _Lex_cmp_opt::_Mismatch
: _Lex_cmp_opt::_None);
#else // ^^^ _USE_STD_VECTOR_ALGORITHMS / !_USE_STD_VECTOR_ALGORITHMS vvv
: _Lex_cmp_opt::_None;
#endif // ^^^ !_USE_STD_VECTOR_ALGORITHMS ^^^

#ifdef __cpp_lib_byte
template <>
inline constexpr bool _Lex_compare_memcmp_classify_elements<byte, byte> = true;
inline constexpr _Lex_cmp_opt _Lex_compare_memcmp_classify_elements<byte, byte> = _Lex_cmp_opt::_Memcmp;
#endif // defined(__cpp_lib_byte)

template <class _Elem1, class _Elem2, class _Pr>
Expand All @@ -5624,46 +5642,55 @@ struct _Lex_compare_memcmp_classify_pred {

template <class _Elem1, class _Elem2, class _Elem3>
struct _Lex_compare_memcmp_classify_pred<_Elem1, _Elem2, less<_Elem3>> {
using _Pred = conditional_t<_Lex_compare_memcmp_classify_elements<_Elem3, _Elem3>
&& _Iter_copy_cat<_Elem1*, _Elem3*>::_Bitcopy_constructible
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_Elem3, _Elem3>;
using _Pred = conditional_t<_Opt != _Lex_cmp_opt::_None && _Iter_copy_cat<_Elem1*, _Elem3*>::_Bitcopy_constructible
&& _Iter_copy_cat<_Elem2*, _Elem3*>::_Bitcopy_constructible,
less<int>, void>;
};

template <class _Elem1, class _Elem2>
struct _Lex_compare_memcmp_classify_pred<_Elem1, _Elem2, less<>> {
using _Pred = conditional_t<_Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>, less<int>, void>;
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>;
using _Pred = conditional_t<_Opt != _Lex_cmp_opt::_None, less<int>, void>;
};

template <class _Elem1, class _Elem2, class _Elem3>
struct _Lex_compare_memcmp_classify_pred<_Elem1, _Elem2, greater<_Elem3>> {
using _Pred = conditional_t<_Lex_compare_memcmp_classify_elements<_Elem3, _Elem3>
&& _Iter_copy_cat<_Elem1*, _Elem3*>::_Bitcopy_constructible
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_Elem3, _Elem3>;
using _Pred = conditional_t<_Opt != _Lex_cmp_opt::_None && _Iter_copy_cat<_Elem1*, _Elem3*>::_Bitcopy_constructible
&& _Iter_copy_cat<_Elem2*, _Elem3*>::_Bitcopy_constructible,
greater<int>, void>;
};

template <class _Elem1, class _Elem2>
struct _Lex_compare_memcmp_classify_pred<_Elem1, _Elem2, greater<>> {
using _Pred = conditional_t<_Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>, greater<int>, void>;
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>;
using _Pred = conditional_t<_Opt != _Lex_cmp_opt::_None, greater<int>, void>;
};

#if _HAS_CXX20
template <class _Elem1, class _Elem2>
struct _Lex_compare_memcmp_classify_pred<_Elem1, _Elem2, _RANGES less> {
using _Pred = conditional_t<_Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>, less<int>, void>;
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>;
using _Pred = conditional_t<_Opt != _Lex_cmp_opt::_None, less<int>, void>;
};

template <class _Elem1, class _Elem2>
struct _Lex_compare_memcmp_classify_pred<_Elem1, _Elem2, _RANGES greater> {
using _Pred = conditional_t<_Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>, greater<int>, void>;
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>;
using _Pred = conditional_t<_Opt != _Lex_cmp_opt::_None, greater<int>, void>;
};
#endif // _HAS_CXX20

// Classification used by _Lex_compare_memcmp_classify when the iterators are not contiguous or are volatile:
// reports that no optimized comparison applies (_Pred is void, _Opt is _None).
struct _Lex_compare_memcmp_disable {
    using _Pred = void;
    static constexpr _Lex_cmp_opt _Opt = _Lex_cmp_opt::_None;
};

template <class _It1, class _It2, class _Pr>
using _Lex_compare_memcmp_classify =
conditional_t<_Iterators_are_contiguous<_It1, _It2> && !_Iterator_is_volatile<_It1> && !_Iterator_is_volatile<_It2>,
typename _Lex_compare_memcmp_classify_pred<_Iter_value_t<_It1>, _Iter_value_t<_It2>, _Pr>::_Pred, void>;
_Lex_compare_memcmp_classify_pred<_Iter_value_t<_It1>, _Iter_value_t<_It2>, _Pr>, _Lex_compare_memcmp_disable>;

_EXPORT_STD template <class _InIt1, class _InIt2, class _Pr>
_NODISCARD _CONSTEXPR20 bool lexicographical_compare(
Expand All @@ -5677,15 +5704,32 @@ _NODISCARD _CONSTEXPR20 bool lexicographical_compare(
const auto _ULast2 = _STD _Get_unwrapped(_Last2);

using _Memcmp_pred = _Lex_compare_memcmp_classify<decltype(_UFirst1), decltype(_UFirst2), _Pr>;
if constexpr (!is_void_v<_Memcmp_pred>) {
if constexpr (!is_void_v<typename _Memcmp_pred::_Pred>) {
#if _HAS_CXX20
if (!_STD is_constant_evaluated())
#endif // _HAS_CXX20
{
const auto _Num1 = static_cast<size_t>(_ULast1 - _UFirst1);
const auto _Num2 = static_cast<size_t>(_ULast2 - _UFirst2);
const int _Ans = _STD _Memcmp_count(_UFirst1, _UFirst2, (_STD min)(_Num1, _Num2));
return _Memcmp_pred{}(_Ans, 0) || (_Ans == 0 && _Num1 < _Num2);
const auto _Num1 = static_cast<size_t>(_ULast1 - _UFirst1);
const auto _Num2 = static_cast<size_t>(_ULast2 - _UFirst2);
const size_t _Num = (_STD min)(_Num1, _Num2);
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Memcmp_pred::_Opt == _Lex_cmp_opt::_Mismatch) {
const auto _First1_ptr = _STD _To_address(_UFirst1);
const auto _First2_ptr = _STD _To_address(_UFirst2);
const size_t _Pos = __std_mismatch<sizeof(*_First1_ptr)>(_First1_ptr, _First2_ptr, _Num);
if (_Pos == _Num2) {
return false;
} else if (_Pos == _Num1) {
return true;
} else {
return _Pred(_First1_ptr[_Pos], _First2_ptr[_Pos]);
}
} else
#endif // _USE_STD_VECTOR_ALGORITHMS
{
const int _Ans = _STD _Memcmp_count(_UFirst1, _UFirst2, _Num);
return typename _Memcmp_pred::_Pred{}(_Ans, 0) || (_Ans == 0 && _Num1 < _Num2);
}
}
}

Expand Down Expand Up @@ -5737,37 +5781,42 @@ struct _Lex_compare_three_way_memcmp_classify_comp {

template <class _Elem1, class _Elem2>
struct _Lex_compare_three_way_memcmp_classify_comp<_Elem1, _Elem2, compare_three_way> {
using _Comp = conditional_t<_Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>
&& three_way_comparable_with<const _Elem1&, const _Elem2&>,
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>;
using _Comp = conditional_t<_Opt != _Lex_cmp_opt::_None && three_way_comparable_with<const _Elem1&, const _Elem2&>,
compare_three_way, void>;
};

template <class _Elem1, class _Elem2>
struct _Lex_compare_three_way_memcmp_classify_comp<_Elem1, _Elem2, _Strong_order::_Cpo> {
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>;
using _Comp =
conditional_t<_Lex_compare_memcmp_classify_elements<_Elem1, _Elem2> && _Can_strong_order<_Elem1, _Elem2>,
_Strong_order::_Cpo, void>;
conditional_t<_Opt != _Lex_cmp_opt::_None && _Can_strong_order<_Elem1, _Elem2>, _Strong_order::_Cpo, void>;
};

template <class _Elem1, class _Elem2>
struct _Lex_compare_three_way_memcmp_classify_comp<_Elem1, _Elem2, _Weak_order::_Cpo> {
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>;
using _Comp =
conditional_t<_Lex_compare_memcmp_classify_elements<_Elem1, _Elem2> && _Can_weak_order<_Elem1, _Elem2>,
_Weak_order::_Cpo, void>;
conditional_t<_Opt != _Lex_cmp_opt::_None && _Can_weak_order<_Elem1, _Elem2>, _Weak_order::_Cpo, void>;
};

template <class _Elem1, class _Elem2>
struct _Lex_compare_three_way_memcmp_classify_comp<_Elem1, _Elem2, _Partial_order::_Cpo> {
static constexpr _Lex_cmp_opt _Opt = _Lex_compare_memcmp_classify_elements<_Elem1, _Elem2>;
using _Comp =
conditional_t<_Lex_compare_memcmp_classify_elements<_Elem1, _Elem2> && _Can_partial_order<_Elem1, _Elem2>,
_Partial_order::_Cpo, void>;
conditional_t<_Opt != _Lex_cmp_opt::_None && _Can_partial_order<_Elem1, _Elem2>, _Partial_order::_Cpo, void>;
};

// Classification used by _Lex_compare_three_way_memcmp_classify when the iterators are not contiguous or are
// volatile: reports that no optimized comparison applies (_Comp is void, _Opt is _None).
struct _Lex_compare_three_way_memcmp_disable {
    using _Comp = void;
    static constexpr _Lex_cmp_opt _Opt = _Lex_cmp_opt::_None;
};

template <class _It1, class _It2, class _Cmp>
using _Lex_compare_three_way_memcmp_classify =
conditional_t<_Iterators_are_contiguous<_It1, _It2> && !_Iterator_is_volatile<_It1> && !_Iterator_is_volatile<_It2>,
typename _Lex_compare_three_way_memcmp_classify_comp<_Iter_value_t<_It1>, _Iter_value_t<_It2>, _Cmp>::_Comp,
void>;
_Lex_compare_three_way_memcmp_classify_comp<_Iter_value_t<_It1>, _Iter_value_t<_It2>, _Cmp>,
_Lex_compare_three_way_memcmp_disable>;

_EXPORT_STD template <class _InIt1, class _InIt2, class _Cmp>
_NODISCARD constexpr auto lexicographical_compare_three_way(const _InIt1 _First1, const _InIt1 _Last1,
Expand All @@ -5780,15 +5829,33 @@ _NODISCARD constexpr auto lexicographical_compare_three_way(const _InIt1 _First1
const auto _ULast2 = _STD _Get_unwrapped(_Last2);

using _Memcmp_pred = _Lex_compare_three_way_memcmp_classify<decltype(_UFirst1), decltype(_UFirst2), _Cmp>;
if constexpr (!is_void_v<_Memcmp_pred>) {
if constexpr (!is_void_v<typename _Memcmp_pred::_Comp>) {
if (!_STD is_constant_evaluated()) {
const auto _Num1 = static_cast<size_t>(_ULast1 - _UFirst1);
const auto _Num2 = static_cast<size_t>(_ULast2 - _UFirst2);
const int _Ans = _STD _Memcmp_count(_UFirst1, _UFirst2, (_STD min)(_Num1, _Num2));
if (_Ans == 0) {
return _Num1 <=> _Num2;
} else {
return _Memcmp_pred{}(_Ans, 0);
const auto _Num1 = static_cast<size_t>(_ULast1 - _UFirst1);
const auto _Num2 = static_cast<size_t>(_ULast2 - _UFirst2);
const size_t _Num = (_STD min)(_Num1, _Num2);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Memcmp_pred::_Opt == _Lex_cmp_opt::_Mismatch) {
const auto _First1_ptr = _STD to_address(_UFirst1);
const auto _First2_ptr = _STD to_address(_UFirst2);
const size_t _Pos = __std_mismatch<sizeof(*_First1_ptr)>(_First1_ptr, _First2_ptr, _Num);
if (_Pos == _Num1) {
return _Pos == _Num2 ? strong_ordering::equal : strong_ordering::less;
} else if (_Pos == _Num2) {
return strong_ordering::greater;
} else {
return _Comp(_First1_ptr[_Pos], _First2_ptr[_Pos]);
}
} else
#endif // _USE_STD_VECTOR_ALGORITHMS
{
const int _Ans = _STD _Memcmp_count(_UFirst1, _UFirst2, _Num);
if (_Ans == 0) {
return _Num1 <=> _Num2;
} else {
return typename _Memcmp_pred::_Comp{}(_Ans, 0);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

RUNALL_INCLUDE ..\char8_t_matrix.lst
RUNALL_CROSSLIST
* PM_CL="" # Test memcmp and manual vectorization
* PM_CL="/D_USE_STD_VECTOR_ALGORITHMS=0" # Test memcmp only
Loading