Skip to content

Commit

Permalink
vectorize min/max_element using SSE4.1 for floats (#3928)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephan T. Lavavej <[email protected]>
  • Loading branch information
AlexGuteniev and StephanTLavavej authored Feb 6, 2024
1 parent f49ffd2 commit 192a840
Show file tree
Hide file tree
Showing 6 changed files with 506 additions and 61 deletions.
1 change: 1 addition & 0 deletions benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ endfunction()

add_benchmark(bitset_to_string src/bitset_to_string.cpp)
add_benchmark(locale_classic src/locale_classic.cpp)
add_benchmark(minmax_element src/minmax_element.cpp)
add_benchmark(path_lexically_normal src/path_lexically_normal.cpp)
add_benchmark(priority_queue_push_range src/priority_queue_push_range.cpp)
add_benchmark(random_integer_generation src/random_integer_generation.cpp)
Expand Down
86 changes: 86 additions & 0 deletions benchmarks/src/minmax_element.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <benchmark/benchmark.h>
#include <cstddef>
#include <cstdint>
#include <random>
#include <ranges>
#include <type_traits>

// Selects which algorithm a benchmark instantiation measures:
// Min -> ranges::min_element, Max -> ranges::max_element, Both -> ranges::minmax_element.
enum class Op { Min, Max, Both };

using namespace std;

// Benchmark body shared by every registration in this file.
//
// Fills a local array of Size elements of type T with pseudo-random values drawn from a
// fixed-seed Mersenne Twister, then, for every iteration of the benchmark state, runs the
// algorithm selected by Operation over the whole array and hands the result to
// benchmark::DoNotOptimize.
//
//   T         - element type of the array (integral or floating-point)
//   Size      - number of elements in the array
//   Operation - which of min_element / max_element / minmax_element is measured
//   state     - benchmark state object iterated by the measurement loop
template <class T, size_t Size, Op Operation>
void bm(benchmark::State& state) {
    T data[Size];

    std::mt19937 rng(84710); // constant seed: every run sees the same input sequence

    if constexpr (!std::is_floating_point_v<T>) {
        // For 1-byte T the distribution is instantiated with int instead of T,
        // and each draw is cast back to T.
        using DrawType = std::conditional_t<sizeof(T) == 1, int, T>;
        std::uniform_int_distribution<DrawType> dist(1, 20);
        for (T& elem : data) {
            elem = static_cast<T>(dist(rng));
        }
    } else {
        std::normal_distribution<T> dist(0, 10000.0);
        for (T& elem : data) {
            elem = dist(rng);
        }
    }

    for (auto _ : state) {
        if constexpr (Operation == Op::Min) {
            benchmark::DoNotOptimize(std::ranges::min_element(data));
        } else if constexpr (Operation == Op::Max) {
            benchmark::DoNotOptimize(std::ranges::max_element(data));
        } else if constexpr (Operation == Op::Both) {
            benchmark::DoNotOptimize(std::ranges::minmax_element(data));
        }
    }
}

// Benchmark registrations: for every element type below, bm is instantiated with an
// 8021-element array and registered once per operation (Min, Max, Both).
// NOTE(review): 8021 is presumably chosen so the length is not a multiple of the vector
// width and the scalar tail is exercised as well -- confirm with the author.

// Unsigned integer element types.
BENCHMARK(bm<uint8_t, 8021, Op::Min>);
BENCHMARK(bm<uint8_t, 8021, Op::Max>);
BENCHMARK(bm<uint8_t, 8021, Op::Both>);

BENCHMARK(bm<uint16_t, 8021, Op::Min>);
BENCHMARK(bm<uint16_t, 8021, Op::Max>);
BENCHMARK(bm<uint16_t, 8021, Op::Both>);

BENCHMARK(bm<uint32_t, 8021, Op::Min>);
BENCHMARK(bm<uint32_t, 8021, Op::Max>);
BENCHMARK(bm<uint32_t, 8021, Op::Both>);

BENCHMARK(bm<uint64_t, 8021, Op::Min>);
BENCHMARK(bm<uint64_t, 8021, Op::Max>);
BENCHMARK(bm<uint64_t, 8021, Op::Both>);

// Signed integer element types.
BENCHMARK(bm<int8_t, 8021, Op::Min>);
BENCHMARK(bm<int8_t, 8021, Op::Max>);
BENCHMARK(bm<int8_t, 8021, Op::Both>);

BENCHMARK(bm<int16_t, 8021, Op::Min>);
BENCHMARK(bm<int16_t, 8021, Op::Max>);
BENCHMARK(bm<int16_t, 8021, Op::Both>);

BENCHMARK(bm<int32_t, 8021, Op::Min>);
BENCHMARK(bm<int32_t, 8021, Op::Max>);
BENCHMARK(bm<int32_t, 8021, Op::Both>);

BENCHMARK(bm<int64_t, 8021, Op::Min>);
BENCHMARK(bm<int64_t, 8021, Op::Max>);
BENCHMARK(bm<int64_t, 8021, Op::Both>);

// Floating-point element types.
BENCHMARK(bm<float, 8021, Op::Min>);
BENCHMARK(bm<float, 8021, Op::Max>);
BENCHMARK(bm<float, 8021, Op::Both>);

BENCHMARK(bm<double, 8021, Op::Min>);
BENCHMARK(bm<double, 8021, Op::Max>);
BENCHMARK(bm<double, 8021, Op::Both>);


// Google Benchmark macro that supplies the program's main().
BENCHMARK_MAIN();
8 changes: 7 additions & 1 deletion stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ _Min_max_element_t __stdcall __std_minmax_element_1(const void* _First, const vo
_Min_max_element_t __stdcall __std_minmax_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_8(const void* _First, const void* _Last, bool _Signed) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_f(const void* _First, const void* _Last, bool _Unused) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_d(const void* _First, const void* _Last, bool _Unused) noexcept;

const void* __stdcall __std_find_last_trivial_1(const void* _First, const void* _Last, uint8_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_2(const void* _First, const void* _Last, uint16_t _Val) noexcept;
Expand All @@ -68,7 +70,11 @@ _STD pair<_Ty*, _Ty*> __std_minmax_element(_Ty* _First, _Ty* _Last) noexcept {

_Min_max_element_t _Res;

if constexpr (sizeof(_Ty) == 1) {
if constexpr (_STD is_same_v<_STD remove_const_t<_Ty>, float>) {
_Res = ::__std_minmax_element_f(_First, _Last, false);
} else if constexpr (_STD _Is_any_of_v<_STD remove_const_t<_Ty>, double, long double>) {
_Res = ::__std_minmax_element_d(_First, _Last, false);
} else if constexpr (sizeof(_Ty) == 1) {
_Res = ::__std_minmax_element_1(_First, _Last, _Signed);
} else if constexpr (sizeof(_Ty) == 2) {
_Res = ::__std_minmax_element_2(_First, _Last, _Signed);
Expand Down
38 changes: 35 additions & 3 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ _STL_DISABLE_CLANG_WARNINGS
#endif // ^^^ _USE_STD_VECTOR_ALGORITHMS != 0 ^^^
#endif // ^^^ no support for vector algorithms ^^^

// _USE_STD_VECTOR_FLOATING_ALGORITHMS configuration.
// If the macro is not already defined, it defaults to 1 only when _USE_STD_VECTOR_ALGORITHMS
// is nonzero and _M_FP_EXCEPT is not defined; otherwise it defaults to 0.
// If the macro is already defined, a nonzero value combined with _USE_STD_VECTOR_ALGORITHMS
// being 0 is rejected with a compile-time #error.
// NOTE(review): _M_FP_EXCEPT is presumably the MSVC predefined macro set by /fp:except or
// /fp:strict, so the "fast math" wording below may be imprecise -- confirm.
#ifndef _USE_STD_VECTOR_FLOATING_ALGORITHMS
#if _USE_STD_VECTOR_ALGORITHMS && !defined(_M_FP_EXCEPT)
#define _USE_STD_VECTOR_FLOATING_ALGORITHMS 1
#else // ^^^ use vector algorithms and fast math / not use vector algorithms or not use fast math vvv
#define _USE_STD_VECTOR_FLOATING_ALGORITHMS 0
#endif // ^^^ not use vector algorithms or not use fast math ^^^
#else // ^^^ !defined(_USE_STD_VECTOR_FLOATING_ALGORITHMS) / defined(_USE_STD_VECTOR_FLOATING_ALGORITHMS) vvv
#if _USE_STD_VECTOR_FLOATING_ALGORITHMS && !_USE_STD_VECTOR_ALGORITHMS
#error _USE_STD_VECTOR_FLOATING_ALGORITHMS must imply _USE_STD_VECTOR_ALGORITHMS.
#endif // _USE_STD_VECTOR_FLOATING_ALGORITHMS && !_USE_STD_VECTOR_ALGORITHMS
#endif // ^^^ defined(_USE_STD_VECTOR_FLOATING_ALGORITHMS) ^^^

#if _USE_STD_VECTOR_ALGORITHMS
extern "C" {
// The "noalias" attribute tells the compiler optimizer that pointers going into these hand-vectorized algorithms
Expand Down Expand Up @@ -87,11 +99,15 @@ const void* __stdcall __std_min_element_1(const void* _First, const void* _Last,
const void* __stdcall __std_min_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_8(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_f(const void* _First, const void* _Last, bool _Unused) noexcept;
const void* __stdcall __std_min_element_d(const void* _First, const void* _Last, bool _Unused) noexcept;

const void* __stdcall __std_max_element_1(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_max_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_max_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_max_element_8(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_max_element_f(const void* _First, const void* _Last, bool _Unused) noexcept;
const void* __stdcall __std_max_element_d(const void* _First, const void* _Last, bool _Unused) noexcept;
} // extern "C"

_STD_BEGIN
Expand Down Expand Up @@ -158,7 +174,11 @@ template <class _Ty>
_Ty* __std_min_element(_Ty* _First, _Ty* _Last) noexcept {
constexpr bool _Signed = _STD is_signed_v<_Ty>;

if constexpr (sizeof(_Ty) == 1) {
if constexpr (_STD is_same_v<_STD remove_const_t<_Ty>, float>) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_min_element_f(_First, _Last, false)));
} else if constexpr (_STD _Is_any_of_v<_STD remove_const_t<_Ty>, double, long double>) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_min_element_d(_First, _Last, false)));
} else if constexpr (sizeof(_Ty) == 1) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_min_element_1(_First, _Last, _Signed)));
} else if constexpr (sizeof(_Ty) == 2) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_min_element_2(_First, _Last, _Signed)));
Expand All @@ -175,7 +195,11 @@ template <class _Ty>
_Ty* __std_max_element(_Ty* _First, _Ty* _Last) noexcept {
constexpr bool _Signed = _STD is_signed_v<_Ty>;

if constexpr (sizeof(_Ty) == 1) {
if constexpr (_STD is_same_v<_STD remove_const_t<_Ty>, float>) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_max_element_f(_First, _Last, false)));
} else if constexpr (_STD _Is_any_of_v<_STD remove_const_t<_Ty>, double, long double>) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_max_element_d(_First, _Last, false)));
} else if constexpr (sizeof(_Ty) == 1) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_max_element_1(_First, _Last, _Signed)));
} else if constexpr (sizeof(_Ty) == 2) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_max_element_2(_First, _Last, _Signed)));
Expand Down Expand Up @@ -6607,7 +6631,15 @@ template <class _Iter, class _Pr, class _Elem = _Iter_value_t<_Iter>>
_INLINE_VAR constexpr bool _Is_min_max_optimization_safe = // Activate the vector algorithms for min_/max_element?
_Iterator_is_contiguous<_Iter> // The iterator must be contiguous so we can get raw pointers.
&& !_Iterator_is_volatile<_Iter> // The iterator must not be volatile.
&& conjunction_v<disjunction<is_integral<_Elem>, is_pointer<_Elem>>, // Element is of integral or pointer type.
&& conjunction_v<disjunction<
#if _USE_STD_VECTOR_FLOATING_ALGORITHMS
#if defined(__LDBL_DIG__) && __LDBL_DIG__ == 18
is_same<_Elem, float>, is_same<_Elem, double>,
#else // ^^^ 80-bit long double (not supported by MSVC in general, see GH-1316) / 64-bit long double vvv
is_floating_point<_Elem>, // Element is floating point or...
#endif // ^^^ 64-bit long double ^^^
#endif // _USE_STD_VECTOR_FLOATING_ALGORITHMS
is_integral<_Elem>, is_pointer<_Elem>>, // ... integral or pointer type.
disjunction< // And either of the following:
#if _HAS_CXX20
is_same<_Pr, _RANGES less>, // predicate is ranges::less
Expand Down
Loading

0 comments on commit 192a840

Please sign in to comment.