Skip to content

Commit

Permalink
vectorize replace (#4554)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephan T. Lavavej <[email protected]>
  • Loading branch information
AlexGuteniev and StephanTLavavej authored Apr 9, 2024
1 parent f5bccb8 commit bed2673
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 19 deletions.
15 changes: 15 additions & 0 deletions benchmarks/src/replace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ const char src[] =
"euismod eros, ut posuere ligula ullamcorper id. Nullam aliquam malesuada est at dignissim. Pellentesque finibus "
"sagittis libero nec bibendum. Phasellus dolor ipsum, finibus quis turpis quis, mollis interdum felis.";

template <class T>
void r(benchmark::State& state) {
const std::vector<T> a(std::begin(src), std::end(src));
std::vector<T> b(std::size(src));

for (auto _ : state) {
b = a;
std::replace(std::begin(b), std::end(b), T{'m'}, T{'w'});
}
}

template <class T>
void rc(benchmark::State& state) {
const std::vector<T> a(std::begin(src), std::end(src));
Expand All @@ -58,6 +69,10 @@ void rc_if(benchmark::State& state) {
}
}

// replace() is vectorized for 4 and 8 bytes only.
BENCHMARK(r<std::uint32_t>);
BENCHMARK(r<std::uint64_t>);

BENCHMARK(rc<std::uint8_t>);
BENCHMARK(rc<std::uint16_t>);
BENCHMARK(rc<std::uint32_t>);
Expand Down
68 changes: 68 additions & 0 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ __declspec(noalias) _Min_max_8i __stdcall __std_minmax_8i(const void* _First, co
__declspec(noalias) _Min_max_8u __stdcall __std_minmax_8u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_f __stdcall __std_minmax_f(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_d __stdcall __std_minmax_d(const void* _First, const void* _Last) noexcept;

// TRANSITION, DevCom-10610477
__declspec(noalias) void __stdcall __std_replace_4(
void* _First, void* _Last, uint32_t _Old_val, uint32_t _New_val) noexcept;
__declspec(noalias) void __stdcall __std_replace_8(
void* _First, void* _Last, uint64_t _Old_val, uint64_t _New_val) noexcept;
} // extern "C"

_STD_BEGIN
Expand Down Expand Up @@ -201,6 +207,24 @@ _Ty1* _Find_first_of_vectorized(
}
}

template <class _Ty, class _TVal1, class _TVal2>
__declspec(noalias) void _Replace_vectorized(
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
if constexpr (is_pointer_v<_Ty>) {
#ifdef _WIN64
::__std_replace_8(_First, _Last, reinterpret_cast<uint64_t>(_Old_val), reinterpret_cast<uint64_t>(_New_val));
#else // ^^^ defined(_WIN64) / !defined(_WIN64) vvv
::__std_replace_4(_First, _Last, reinterpret_cast<uint32_t>(_Old_val), reinterpret_cast<uint32_t>(_New_val));
#endif // ^^^ !defined(_WIN64) ^^^
} else if constexpr (sizeof(_Ty) == 4) {
::__std_replace_4(_First, _Last, static_cast<uint32_t>(_Old_val), static_cast<uint32_t>(_New_val));
} else if constexpr (sizeof(_Ty) == 8) {
::__std_replace_8(_First, _Last, static_cast<uint64_t>(_Old_val), static_cast<uint64_t>(_New_val));
} else {
static_assert(_Always_false<_Ty>, "Unexpected size");
}
}

// find_first_of vectorization is likely to be a win after this size (in elements)
_INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

Expand All @@ -209,6 +233,17 @@ template <class _It1, class _It2, class _Pr>
constexpr bool _Vector_alg_in_find_first_of_is_safe =
_Equal_memcmp_is_safe<_It1, _It2, _Pr> // can replace value comparison with bitwise comparison
&& sizeof(_Iter_value_t<_It1>) <= 2; // pcmpestri compatible size

// Can we activate the vector algorithms for replace?
template <class _Iter, class _Ty1>
constexpr bool _Vector_alg_in_replace_is_safe = _Vector_alg_in_find_is_safe<_Iter, _Ty1> // can search for the value
&& sizeof(_Iter_value_t<_Iter>) >= 4; // avx masked op compatible size

// Can we activate the vector algorithms for ranges::replace?
template <class _Iter, class _Ty1, class _Ty2>
constexpr bool _Vector_alg_in_ranges_replace_is_safe =
_Vector_alg_in_replace_is_safe<_Iter, _Ty1> // can search and replace
&& _Vector_alg_in_find_is_safe_elem<_Ty2, _Iter_value_t<_Iter>>; // replacement fits
_STD_END
#endif // _USE_STD_VECTOR_ALGORITHMS

Expand Down Expand Up @@ -3828,6 +3863,22 @@ _CONSTEXPR20 void replace(const _FwdIt _First, const _FwdIt _Last, const _Ty& _O
_STD _Adl_verify_range(_First, _Last);
auto _UFirst = _STD _Get_unwrapped(_First);
const auto _ULast = _STD _Get_unwrapped(_Last);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_replace_is_safe<decltype(_UFirst), _Ty>) {
#if _HAS_CXX20
if (!_STD is_constant_evaluated())
#endif // _HAS_CXX20
{
if (_STD _Could_compare_equal_to_value_type<decltype(_UFirst)>(_Oldval)) {
_STD _Replace_vectorized(_STD _To_address(_UFirst), _STD _To_address(_ULast), _Oldval, _Newval);
}

return;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (; _UFirst != _ULast; ++_UFirst) {
if (*_UFirst == _Oldval) {
*_UFirst = _Newval;
Expand Down Expand Up @@ -3881,6 +3932,23 @@ namespace ranges {
_STL_INTERNAL_STATIC_ASSERT(indirectly_writable<_It, const _Ty2&>);
_STL_INTERNAL_STATIC_ASSERT(indirect_binary_predicate<ranges::equal_to, projected<_It, _Pj>, const _Ty1*>);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (is_same_v<_Pj, identity> && sized_sentinel_for<_Se, _It>
&& _Vector_alg_in_ranges_replace_is_safe<_It, _Ty1, _Ty2>) {
if (!_STD is_constant_evaluated()) {
const auto _Count = _Last - _First;

if (_STD _Could_compare_equal_to_value_type<_It>(_Oldval)) {
const auto _First_ptr = _STD to_address(_First);
const auto _Last_ptr = _First_ptr + _Count;
_STD _Replace_vectorized(_First_ptr, _Last_ptr, _Oldval, _Newval);
}

return _First + _Count;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (; _First != _Last; ++_First) {
if (_STD invoke(_Proj, *_First) == _Oldval) {
*_First = _Newval;
Expand Down
42 changes: 23 additions & 19 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -5850,30 +5850,34 @@ struct _Vector_alg_in_find_is_safe_object_pointers<_Ty1*, _Ty2*>
// either _Ty1 is the same as _Ty2 (ignoring cv-qualifiers), or one of the two is void
disjunction<is_same<remove_cv_t<_Ty1>, remove_cv_t<_Ty2>>, is_void<_Ty1>, is_void<_Ty2>>> {};

// Can we activate the vector algorithms to find a value in a range of elements?
template <class _Ty, class _Elem>
constexpr bool _Vector_alg_in_find_is_safe_elem = disjunction_v<
#ifdef __cpp_lib_byte
// We're finding a std::byte in a range of std::byte.
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>,
#endif // defined(__cpp_lib_byte)
// We're finding an integer in a range of integers.
// This case is the one that requires careful runtime handling in _Could_compare_equal_to_value_type.
conjunction<is_integral<_Ty>, is_integral<_Elem>>,
// We're finding an (object or function) pointer in a range of pointers of the same type.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>,
// We're finding a nullptr in a range of (object or function) pointers.
conjunction<is_same<_Ty, nullptr_t>, is_pointer<_Elem>>,
// We're finding an object pointer in a range of object pointers, and:
// - One of the pointer types is a cv void*.
// - One of the pointer types is a cv1 U* and the other is a cv2 U*.
_Vector_alg_in_find_is_safe_object_pointers<_Ty, _Elem>>;

// Can we activate the vector algorithms for find/count?
template <class _Iter, class _Ty, class _Elem = _Iter_value_t<_Iter>>
_INLINE_VAR constexpr bool _Vector_alg_in_find_is_safe =
template <class _Iter, class _Ty>
constexpr bool _Vector_alg_in_find_is_safe =
// The iterator must be contiguous so we can get raw pointers.
_Iterator_is_contiguous<_Iter>
// The iterator must not be volatile.
&& !_Iterator_is_volatile<_Iter>
// And one of the following conditions must be met:
&& disjunction_v<
#ifdef __cpp_lib_byte
// We're finding a std::byte in a range of std::byte.
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>,
#endif // defined(__cpp_lib_byte)
// We're finding an integer in a range of integers.
// This case is the one that requires careful runtime handling in _Could_compare_equal_to_value_type.
conjunction<is_integral<_Ty>, is_integral<_Elem>>,
// We're finding an (object or function) pointer in a range of pointers of the same type.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>,
// We're finding a nullptr in a range of (object or function) pointers.
conjunction<is_same<_Ty, nullptr_t>, is_pointer<_Elem>>,
// We're finding an object pointer in a range of object pointers, and:
// - One of the pointer types is a cv void*.
// - One of the pointer types is a cv1 U* and the other is a cv2 U*.
_Vector_alg_in_find_is_safe_object_pointers<_Ty, _Elem>>;
// The type of the value to find must be compatible with the type of the elements.
&& _Vector_alg_in_find_is_safe_elem<_Ty, _Iter_value_t<_Iter>>;

template <class _InIt, class _Ty>
_NODISCARD constexpr bool _Could_compare_equal_to_value_type(const _Ty& _Val) {
Expand Down
77 changes: 77 additions & 0 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2248,6 +2248,83 @@ __declspec(noalias) size_t
return __std_mismatch_impl<_Find_traits_8, uint64_t>(_First1, _First2, _Count);
}

__declspec(noalias) void __stdcall __std_replace_4(
void* _First, void* const _Last, const uint32_t _Old_val, const uint32_t _New_val) noexcept {
#ifndef _M_ARM64EC
if (_Use_avx2()) {
const __m256i _Comparand = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(_Old_val));
const __m256i _Replacement = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(_New_val));
const size_t _Full_length = _Byte_length(_First, _Last);

void* _Stop_at = _First;
_Advance_bytes(_Stop_at, _Full_length & ~size_t{0x1F});

while (_First != _Stop_at) {
const __m256i _Data = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First));
const __m256i _Mask = _mm256_cmpeq_epi32(_Comparand, _Data);
_mm256_maskstore_epi32(reinterpret_cast<int*>(_First), _Mask, _Replacement);

_Advance_bytes(_First, 32);
}

if (const size_t _Tail_length = _Full_length & 0x1C; _Tail_length != 0) {
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Tail_length >> 2);
const __m256i _Data = _mm256_maskload_epi32(reinterpret_cast<const int*>(_First), _Tail_mask);
const __m256i _Mask = _mm256_and_si256(_mm256_cmpeq_epi32(_Comparand, _Data), _Tail_mask);
_mm256_maskstore_epi32(reinterpret_cast<int*>(_First), _Mask, _Replacement);
}
} else
#endif // !defined(_M_ARM64EC)
{
for (auto _Cur = reinterpret_cast<uint32_t*>(_First); _Cur != _Last; ++_Cur) {
if (*_Cur == _Old_val) {
*_Cur = _New_val;
}
}
}
}

__declspec(noalias) void __stdcall __std_replace_8(
void* _First, void* const _Last, const uint64_t _Old_val, const uint64_t _New_val) noexcept {
#ifndef _M_ARM64EC
if (_Use_avx2()) {
#ifdef _WIN64
const __m256i _Comparand = _mm256_broadcastq_epi64(_mm_cvtsi64_si128(_Old_val));
const __m256i _Replacement = _mm256_broadcastq_epi64(_mm_cvtsi64_si128(_New_val));
#else // ^^^ defined(_WIN64) / !defined(_WIN64), workaround, _mm_cvtsi64_si128 does not compile vvv
const __m256i _Comparand = _mm256_set1_epi64x(_Old_val);
const __m256i _Replacement = _mm256_set1_epi64x(_New_val);
#endif // ^^^ !defined(_WIN64) ^^^
const size_t _Full_length = _Byte_length(_First, _Last);

void* _Stop_at = _First;
_Advance_bytes(_Stop_at, _Full_length & ~size_t{0x1F});

while (_First != _Stop_at) {
const __m256i _Data = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First));
const __m256i _Mask = _mm256_cmpeq_epi64(_Comparand, _Data);
_mm256_maskstore_epi64(reinterpret_cast<long long*>(_First), _Mask, _Replacement);

_Advance_bytes(_First, 32);
}

if (const size_t _Tail_length = _Full_length & 0x18; _Tail_length != 0) {
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Tail_length >> 2);
const __m256i _Data = _mm256_maskload_epi64(reinterpret_cast<const long long*>(_First), _Tail_mask);
const __m256i _Mask = _mm256_and_si256(_mm256_cmpeq_epi64(_Comparand, _Data), _Tail_mask);
_mm256_maskstore_epi64(reinterpret_cast<long long*>(_First), _Mask, _Replacement);
}
} else
#endif // !defined(_M_ARM64EC)
{
for (auto _Cur = reinterpret_cast<uint64_t*>(_First); _Cur != _Last; ++_Cur) {
if (*_Cur == _Old_val) {
*_Cur = _New_val;
}
}
}
}

} // extern "C"

#ifndef _M_ARM64EC
Expand Down
52 changes: 52 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,52 @@ namespace test_mismatch_sizes_and_alignments {
}
} // namespace test_mismatch_sizes_and_alignments

template <class FwdIt, class T>
void last_known_good_replace(FwdIt first, FwdIt last, const T old_val, const T new_val) {
for (; first != last; ++first) {
if (*first == old_val) {
*first = new_val;
}
}
}

template <class T>
void test_case_replace(const vector<T>& input, T old_val, T new_val) {
vector<T> replaced_actual(input);
vector<T> replaced_expected(input);
replace(replaced_actual.begin(), replaced_actual.end(), old_val, new_val);
last_known_good_replace(replaced_expected.begin(), replaced_expected.end(), old_val, new_val);
assert(replaced_expected == replaced_actual);

#if _HAS_CXX20
vector<T> replaced_actual_r(input);
ranges::replace(replaced_actual_r, old_val, new_val);
assert(replaced_expected == replaced_actual_r);
#endif // _HAS_CXX20
}

template <class T>
void test_replace(mt19937_64& gen) {
using TD = conditional_t<sizeof(T) == 1, int, T>;
uniform_int_distribution<TD> dis(0, 9);
vector<T> input;

input.reserve(dataCount);

{
const T old_val = static_cast<T>(dis(gen));
const T new_val = static_cast<T>(dis(gen));
test_case_replace(input, old_val, new_val);
}

for (size_t i = 0; i != dataCount; ++i) {
input.push_back(static_cast<T>(dis(gen)));
const T old_val = static_cast<T>(dis(gen));
const T new_val = static_cast<T>(dis(gen));
test_case_replace(input, old_val, new_val);
}
}

template <class BidIt>
void last_known_good_reverse(BidIt first, BidIt last) {
for (; first != last && first != --last; ++first) {
Expand Down Expand Up @@ -728,6 +774,12 @@ void test_vector_algorithms(mt19937_64& gen) {
test_mismatch_sizes_and_alignments::test<int>();
test_mismatch_sizes_and_alignments::test<long long>();

// replace() is vectorized for 4 and 8 bytes only.
test_replace<int>(gen);
test_replace<unsigned int>(gen);
test_replace<long long>(gen);
test_replace<unsigned long long>(gen);

test_reverse<char>(gen);
test_reverse<signed char>(gen);
test_reverse<unsigned char>(gen);
Expand Down

0 comments on commit bed2673

Please sign in to comment.