Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vectorize min/max_element using SSE4.1 for floats #3928

Merged
merged 43 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
cd7420a
vectorize `min/max_element` using SSE4.1/AVX for floats
AlexGuteniev Aug 6, 2023
80b561c
`const`
AlexGuteniev Aug 6, 2023
86b3100
reverse init
AlexGuteniev Aug 6, 2023
55c5add
format
AlexGuteniev Aug 6, 2023
c88fe9a
more interesting values
AlexGuteniev Aug 6, 2023
8f70406
fix x68 build
AlexGuteniev Aug 6, 2023
52c8191
format
AlexGuteniev Aug 6, 2023
af9c9b5
coverage
AlexGuteniev Aug 6, 2023
ff97c5b
ouch
AlexGuteniev Aug 6, 2023
90b7999
format
AlexGuteniev Aug 6, 2023
568a793
include `/fp:strict` / `/fp:precise`
AlexGuteniev Aug 6, 2023
7f2a635
-extra casts
AlexGuteniev Aug 6, 2023
f6b24a9
copypaste error
AlexGuteniev Aug 6, 2023
91ed8a3
more interesting input
AlexGuteniev Aug 7, 2023
46baf22
Unsupport 80-bit long double
AlexGuteniev Aug 7, 2023
83d208f
+benchmark
AlexGuteniev Aug 7, 2023
1be219f
include order
AlexGuteniev Aug 7, 2023
99b1746
-copy
AlexGuteniev Aug 7, 2023
58fc6b9
simplify benchmark
AlexGuteniev Aug 7, 2023
eb388ad
fix build
AlexGuteniev Aug 7, 2023
5efca15
Merge branch 'main' into guess_whos_back
StephanTLavavej Oct 20, 2023
6e90c6c
Merge remote-tracking branch 'upstream/main' into guess_whos_back
AlexGuteniev Oct 26, 2023
8249aa5
load noexcept
AlexGuteniev Oct 26, 2023
7404970
fix copypasta during merge
AlexGuteniev Oct 26, 2023
b8c61d2
Merge branch 'main' into guess_whos_back
StephanTLavavej Nov 7, 2023
e6cf685
Merge remote-tracking branch 'upstream/main' into guess_whos_back
AlexGuteniev Dec 25, 2023
33834c3
Merge branch 'guess_whos_back' of https://github.com/AlexGuteniev/STL…
AlexGuteniev Dec 25, 2023
b0867f8
ADL-wary
AlexGuteniev Dec 25, 2023
ce7cdc1
ADL-wary
AlexGuteniev Dec 25, 2023
9e87e91
ADL-wary
AlexGuteniev Dec 25, 2023
99fb9c7
Merge branch 'main' into guess_whos_back
StephanTLavavej Jan 31, 2024
43f7d92
Use `_Is_any_of_v`.
StephanTLavavej Jan 31, 2024
83035c0
Comment nitpicks.
StephanTLavavej Jan 31, 2024
30b8748
Fix `#error` message, use "must imply" phrasing.
StephanTLavavej Jan 31, 2024
d2b3320
Style: Unnamed `const bool` => `bool`
StephanTLavavej Jan 31, 2024
eef60ce
Style: Add newline.
StephanTLavavej Jan 31, 2024
00ba973
`test_min_max_element_f` => `test_min_max_element_floating`
StephanTLavavej Jan 31, 2024
e9a76e4
Test ordinary negative values too.
StephanTLavavej Jan 31, 2024
3e17d05
Drop `static_cast<T>` as `input_of_input` is `vector<T>`.
StephanTLavavej Jan 31, 2024
40ca00b
Enable warnings when building the benchmarks.
StephanTLavavej Jan 31, 2024
af4df71
Fix truncation warnings in benchmarks.
StephanTLavavej Jan 31, 2024
bf79787
Fix x86 size_t truncation warnings in the vector.bool benchmarks.
StephanTLavavej Jan 31, 2024
358dd22
Revert enabling warnings for benchmarks.
StephanTLavavej Jan 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ endfunction()

add_benchmark(bitset_to_string src/bitset_to_string.cpp)
add_benchmark(locale_classic src/locale_classic.cpp)
add_benchmark(minmax_element src/minmax_element.cpp)
add_benchmark(path_lexically_normal src/path_lexically_normal.cpp)
add_benchmark(priority_queue_push_range src/priority_queue_push_range.cpp)
add_benchmark(random_integer_generation src/random_integer_generation.cpp)
Expand Down
86 changes: 86 additions & 0 deletions benchmarks/src/minmax_element.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <benchmark/benchmark.h>
#include <cstddef>
#include <cstdint>
#include <random>
#include <ranges>
#include <type_traits>

enum class Op {
Min,
Max,
Both,
};

using namespace std;

template <class T, size_t Size, Op Operation>
void bm(benchmark::State& state) {
T a[Size];

mt19937 gen(84710);

if constexpr (is_floating_point_v<T>) {
normal_distribution<T> dis(0, 10000.0);
ranges::generate(a, [&] { return dis(gen); });
} else {
uniform_int_distribution<conditional_t<sizeof(T) != 1, T, int>> dis(1, 20);
ranges::generate(a, [&] { return static_cast<T>(dis(gen)); });
}

for (auto _ : state) {
if constexpr (Operation == Op::Min) {
benchmark::DoNotOptimize(ranges::min_element(a));
} else if constexpr (Operation == Op::Max) {
benchmark::DoNotOptimize(ranges::max_element(a));
} else if constexpr (Operation == Op::Both) {
benchmark::DoNotOptimize(ranges::minmax_element(a));
}
}
}

BENCHMARK(bm<uint8_t, 8021, Op::Min>);
BENCHMARK(bm<uint8_t, 8021, Op::Max>);
BENCHMARK(bm<uint8_t, 8021, Op::Both>);

BENCHMARK(bm<uint16_t, 8021, Op::Min>);
BENCHMARK(bm<uint16_t, 8021, Op::Max>);
BENCHMARK(bm<uint16_t, 8021, Op::Both>);

BENCHMARK(bm<uint32_t, 8021, Op::Min>);
BENCHMARK(bm<uint32_t, 8021, Op::Max>);
BENCHMARK(bm<uint32_t, 8021, Op::Both>);

BENCHMARK(bm<uint64_t, 8021, Op::Min>);
BENCHMARK(bm<uint64_t, 8021, Op::Max>);
BENCHMARK(bm<uint64_t, 8021, Op::Both>);

BENCHMARK(bm<int8_t, 8021, Op::Min>);
BENCHMARK(bm<int8_t, 8021, Op::Max>);
BENCHMARK(bm<int8_t, 8021, Op::Both>);

BENCHMARK(bm<int16_t, 8021, Op::Min>);
BENCHMARK(bm<int16_t, 8021, Op::Max>);
BENCHMARK(bm<int16_t, 8021, Op::Both>);

BENCHMARK(bm<int32_t, 8021, Op::Min>);
BENCHMARK(bm<int32_t, 8021, Op::Max>);
BENCHMARK(bm<int32_t, 8021, Op::Both>);

BENCHMARK(bm<int64_t, 8021, Op::Min>);
BENCHMARK(bm<int64_t, 8021, Op::Max>);
BENCHMARK(bm<int64_t, 8021, Op::Both>);

BENCHMARK(bm<float, 8021, Op::Min>);
BENCHMARK(bm<float, 8021, Op::Max>);
BENCHMARK(bm<float, 8021, Op::Both>);

BENCHMARK(bm<double, 8021, Op::Min>);
BENCHMARK(bm<double, 8021, Op::Max>);
BENCHMARK(bm<double, 8021, Op::Both>);


BENCHMARK_MAIN();
8 changes: 7 additions & 1 deletion stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ _Min_max_element_t __stdcall __std_minmax_element_1(const void* _First, const vo
_Min_max_element_t __stdcall __std_minmax_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_8(const void* _First, const void* _Last, bool _Signed) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_f(const void* _First, const void* _Last, bool _Unused) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_d(const void* _First, const void* _Last, bool _Unused) noexcept;

const void* __stdcall __std_find_last_trivial_1(const void* _First, const void* _Last, uint8_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_2(const void* _First, const void* _Last, uint16_t _Val) noexcept;
Expand All @@ -68,7 +70,11 @@ _STD pair<_Ty*, _Ty*> __std_minmax_element(_Ty* _First, _Ty* _Last) noexcept {

_Min_max_element_t _Res;

if constexpr (sizeof(_Ty) == 1) {
if constexpr (_STD is_same_v<_STD remove_const_t<_Ty>, float>) {
_Res = ::__std_minmax_element_f(_First, _Last, false);
} else if constexpr (_STD _Is_any_of_v<_STD remove_const_t<_Ty>, double, long double>) {
_Res = ::__std_minmax_element_d(_First, _Last, false);
} else if constexpr (sizeof(_Ty) == 1) {
_Res = ::__std_minmax_element_1(_First, _Last, _Signed);
} else if constexpr (sizeof(_Ty) == 2) {
_Res = ::__std_minmax_element_2(_First, _Last, _Signed);
Expand Down
38 changes: 35 additions & 3 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ _STL_DISABLE_CLANG_WARNINGS
#endif // ^^^ _USE_STD_VECTOR_ALGORITHMS != 0 ^^^
#endif // ^^^ no support for vector algorithms ^^^

#ifndef _USE_STD_VECTOR_FLOATING_ALGORITHMS
#if _USE_STD_VECTOR_ALGORITHMS && !defined(_M_FP_EXCEPT)
#define _USE_STD_VECTOR_FLOATING_ALGORITHMS 1
#else // ^^^ use vector algorithms and fast math / not use vector algorithms or not use fast math vvv
#define _USE_STD_VECTOR_FLOATING_ALGORITHMS 0
#endif // ^^^ not use vector algorithms or not use fast math ^^^
#else // ^^^ !defined(_USE_STD_VECTOR_FLOATING_ALGORITHMS) / defined(_USE_STD_VECTOR_FLOATING_ALGORITHMS) vvv
#if _USE_STD_VECTOR_FLOATING_ALGORITHMS && !_USE_STD_VECTOR_ALGORITHMS
#error _USE_STD_VECTOR_FLOATING_ALGORITHMS must imply _USE_STD_VECTOR_ALGORITHMS.
#endif // _USE_STD_VECTOR_FLOATING_ALGORITHMS && !_USE_STD_VECTOR_ALGORITHMS
#endif // ^^^ defined(_USE_STD_VECTOR_FLOATING_ALGORITHMS) ^^^

#if _USE_STD_VECTOR_ALGORITHMS
extern "C" {
// The "noalias" attribute tells the compiler optimizer that pointers going into these hand-vectorized algorithms
Expand Down Expand Up @@ -87,11 +99,15 @@ const void* __stdcall __std_min_element_1(const void* _First, const void* _Last,
const void* __stdcall __std_min_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_8(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_f(const void* _First, const void* _Last, bool _Unused) noexcept;
const void* __stdcall __std_min_element_d(const void* _First, const void* _Last, bool _Unused) noexcept;

const void* __stdcall __std_max_element_1(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_max_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_max_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_max_element_8(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_max_element_f(const void* _First, const void* _Last, bool _Unused) noexcept;
const void* __stdcall __std_max_element_d(const void* _First, const void* _Last, bool _Unused) noexcept;
} // extern "C"

_STD_BEGIN
Expand Down Expand Up @@ -158,7 +174,11 @@ template <class _Ty>
_Ty* __std_min_element(_Ty* _First, _Ty* _Last) noexcept {
constexpr bool _Signed = _STD is_signed_v<_Ty>;

if constexpr (sizeof(_Ty) == 1) {
if constexpr (_STD is_same_v<_STD remove_const_t<_Ty>, float>) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_min_element_f(_First, _Last, false)));
} else if constexpr (_STD _Is_any_of_v<_STD remove_const_t<_Ty>, double, long double>) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_min_element_d(_First, _Last, false)));
} else if constexpr (sizeof(_Ty) == 1) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_min_element_1(_First, _Last, _Signed)));
} else if constexpr (sizeof(_Ty) == 2) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_min_element_2(_First, _Last, _Signed)));
Expand All @@ -175,7 +195,11 @@ template <class _Ty>
_Ty* __std_max_element(_Ty* _First, _Ty* _Last) noexcept {
constexpr bool _Signed = _STD is_signed_v<_Ty>;

if constexpr (sizeof(_Ty) == 1) {
if constexpr (_STD is_same_v<_STD remove_const_t<_Ty>, float>) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_max_element_f(_First, _Last, false)));
} else if constexpr (_STD _Is_any_of_v<_STD remove_const_t<_Ty>, double, long double>) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_max_element_d(_First, _Last, false)));
} else if constexpr (sizeof(_Ty) == 1) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_max_element_1(_First, _Last, _Signed)));
} else if constexpr (sizeof(_Ty) == 2) {
return const_cast<_Ty*>(static_cast<const _Ty*>(::__std_max_element_2(_First, _Last, _Signed)));
Expand Down Expand Up @@ -6586,7 +6610,15 @@ template <class _Iter, class _Pr, class _Elem = _Iter_value_t<_Iter>>
_INLINE_VAR constexpr bool _Is_min_max_optimization_safe = // Activate the vector algorithms for min_/max_element?
_Iterator_is_contiguous<_Iter> // The iterator must be contiguous so we can get raw pointers.
&& !_Iterator_is_volatile<_Iter> // The iterator must not be volatile.
&& conjunction_v<disjunction<is_integral<_Elem>, is_pointer<_Elem>>, // Element is of integral or pointer type.
&& conjunction_v<disjunction<
#if _USE_STD_VECTOR_FLOATING_ALGORITHMS
#if defined(__LDBL_DIG__) && __LDBL_DIG__ == 18
is_same<_Elem, float>, is_same<_Elem, double>,
#else // ^^^ 80-bit long double (not supported by MSVC in general, see GH-1316) / 64-bit long double vvv
AlexGuteniev marked this conversation as resolved.
Show resolved Hide resolved
is_floating_point<_Elem>, // Element is floating point or...
#endif // ^^^ 64-bit long double ^^^
#endif // _USE_STD_VECTOR_FLOATING_ALGORITHMS
is_integral<_Elem>, is_pointer<_Elem>>, // ... integral or pointer type.
disjunction< // And either of the following:
#ifdef __cpp_lib_concepts
is_same<_Pr, _RANGES less>, // predicate is ranges::less
Expand Down
Loading