[libc++] Vectorize std::find (#156431)
```
Apple M4:
-----------------------------------------------------------------------------
Benchmark                                                  old            new
-----------------------------------------------------------------------------
std::find(vector<char>) (bail 25%)/8                   1.43 ns        1.44 ns
std::find(vector<char>) (bail 25%)/1024                5.54 ns        5.59 ns
std::find(vector<char>) (bail 25%)/8192                38.4 ns        39.1 ns
std::find(vector<char>) (bail 25%)/32768                134 ns         136 ns
std::find(vector<int>) (bail 25%)/8                    1.56 ns        1.57 ns
std::find(vector<int>) (bail 25%)/1024                 65.3 ns        65.4 ns
std::find(vector<int>) (bail 25%)/8192                  465 ns         464 ns
std::find(vector<int>) (bail 25%)/32768                1832 ns        1832 ns
std::find(vector<long long>) (bail 25%)/8             0.920 ns        1.20 ns
std::find(vector<long long>) (bail 25%)/1024           65.2 ns        31.2 ns
std::find(vector<long long>) (bail 25%)/8192            464 ns         255 ns
std::find(vector<long long>) (bail 25%)/32768          1833 ns         992 ns
std::find(vector<char>) (process all)/8                1.21 ns        1.22 ns
std::find(vector<char>) (process all)/50               1.92 ns        1.93 ns
std::find(vector<char>) (process all)/1024             16.6 ns        16.9 ns
std::find(vector<char>) (process all)/8192              134 ns         136 ns
std::find(vector<char>) (process all)/32768             488 ns         503 ns
std::find(vector<int>) (process all)/8                 2.45 ns        2.48 ns
std::find(vector<int>) (process all)/50                12.7 ns        12.7 ns
std::find(vector<int>) (process all)/1024               236 ns         236 ns
std::find(vector<int>) (process all)/8192              1830 ns        1834 ns
std::find(vector<int>) (process all)/32768             7351 ns        7346 ns
std::find(vector<long long>) (process all)/8           2.02 ns        1.45 ns
std::find(vector<long long>) (process all)/50          12.0 ns        6.12 ns
std::find(vector<long long>) (process all)/1024         235 ns         123 ns
std::find(vector<long long>) (process all)/8192        1830 ns         983 ns
std::find(vector<long long>) (process all)/32768       7306 ns        3969 ns
std::find(vector<bool>) (process all)/8                1.14 ns        1.15 ns
std::find(vector<bool>) (process all)/50               1.16 ns        1.17 ns
std::find(vector<bool>) (process all)/1024             4.51 ns        4.53 ns
std::find(vector<bool>) (process all)/8192             33.6 ns        33.5 ns
std::find(vector<bool>) (process all)/1048576          3660 ns        3660 ns
```
This commit is contained in:
parent
8d57211d6f
commit
97367d1046
@ -64,6 +64,8 @@ Improvements and New Features
|
||||
- Multiple internal types have been refactored to use ``[[no_unique_address]]``, resulting in faster compile times and
|
||||
reduced debug information.
|
||||
|
||||
- The performance of ``std::find`` has been improved by up to 2x for integral types.
|
||||
|
||||
Deprecations and Removals
|
||||
-------------------------
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
|
||||
#include <__algorithm/find_segment_if.h>
|
||||
#include <__algorithm/min.h>
|
||||
#include <__algorithm/simd_utils.h>
|
||||
#include <__algorithm/unwrap_iter.h>
|
||||
#include <__bit/countr.h>
|
||||
#include <__bit/invert_if.h>
|
||||
@ -44,39 +45,102 @@ _LIBCPP_BEGIN_NAMESPACE_STD
|
||||
// generic implementation
|
||||
template <class _Iter, class _Sent, class _Tp, class _Proj>
|
||||
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Iter
|
||||
__find(_Iter __first, _Sent __last, const _Tp& __value, _Proj& __proj) {
|
||||
__find_loop(_Iter __first, _Sent __last, const _Tp& __value, _Proj& __proj) {
|
||||
for (; __first != __last; ++__first)
|
||||
if (std::__invoke(__proj, *__first) == __value)
|
||||
break;
|
||||
return __first;
|
||||
}
|
||||
|
||||
// trivially equality comparable implementations
|
||||
template <class _Tp,
|
||||
class _Up,
|
||||
class _Proj,
|
||||
__enable_if_t<__is_identity<_Proj>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value &&
|
||||
sizeof(_Tp) == 1,
|
||||
int> = 0>
|
||||
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp* __find(_Tp* __first, _Tp* __last, const _Up& __value, _Proj&) {
|
||||
if (auto __ret = std::__constexpr_memchr(__first, __value, __last - __first))
|
||||
return __ret;
|
||||
return __last;
|
||||
// Generic overload of __find for arbitrary iterator/sentinel pairs.
// It forwards directly to std::__find_loop, moving the iterator and the sentinel into the call and passing the
// value and the projection through unchanged; whatever std::__find_loop returns is returned to the caller.
template <class _Iter, class _Sent, class _Tp, class _Proj>
|
||||
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Iter
|
||||
__find(_Iter __first, _Sent __last, const _Tp& __value, _Proj& __proj) {
|
||||
return std::__find_loop(std::move(__first), std::move(__last), __value, __proj);
|
||||
}
|
||||
|
||||
#if _LIBCPP_HAS_WIDE_CHARACTERS
|
||||
template <class _Tp,
|
||||
class _Up,
|
||||
class _Proj,
|
||||
__enable_if_t<__is_identity<_Proj>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value &&
|
||||
sizeof(_Tp) == sizeof(wchar_t) && _LIBCPP_ALIGNOF(_Tp) >= _LIBCPP_ALIGNOF(wchar_t),
|
||||
int> = 0>
|
||||
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp* __find(_Tp* __first, _Tp* __last, const _Up& __value, _Proj&) {
|
||||
if (auto __ret = std::__constexpr_wmemchr(__first, __value, __last - __first))
|
||||
return __ret;
|
||||
return __last;
|
||||
#if _LIBCPP_VECTORIZE_ALGORITHMS
|
||||
template <class _Tp, class _Up>
|
||||
[[__nodiscard__]] _LIBCPP_HIDE_FROM_ABI
|
||||
_LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp* __find_vectorized(_Tp* __first, _Tp* __last, _Up __value) {
|
||||
if (!__libcpp_is_constant_evaluated()) {
|
||||
constexpr size_t __unroll_count = 4;
|
||||
constexpr size_t __vec_size = __native_vector_size<_Tp>;
|
||||
using __vec = __simd_vector<_Tp, __vec_size>;
|
||||
|
||||
auto __orig_first = __first;
|
||||
|
||||
auto __values = static_cast<__simd_vector<_Up, __vec_size>>(__value); // broadcast the value
|
||||
while (static_cast<size_t>(__last - __first) >= __unroll_count * __vec_size) [[__unlikely__]] {
|
||||
__vec __lhs[__unroll_count];
|
||||
|
||||
for (size_t __i = 0; __i != __unroll_count; ++__i)
|
||||
__lhs[__i] = std::__load_vector<__vec>(__first + __i * __vec_size);
|
||||
|
||||
for (size_t __i = 0; __i != __unroll_count; ++__i) {
|
||||
if (auto __cmp_res = __lhs[__i] == __values; std::__any_of(__cmp_res)) {
|
||||
auto __offset = __i * __vec_size + std::__find_first_set(__cmp_res);
|
||||
return __first + __offset;
|
||||
}
|
||||
}
|
||||
|
||||
__first += __unroll_count * __vec_size;
|
||||
}
|
||||
|
||||
// check the remaining 0-3 vectors
|
||||
while (static_cast<size_t>(__last - __first) >= __vec_size) {
|
||||
if (auto __cmp_res = std::__load_vector<__vec>(__first) == __values; std::__any_of(__cmp_res)) {
|
||||
return __first + std::__find_first_set(__cmp_res);
|
||||
}
|
||||
__first += __vec_size;
|
||||
}
|
||||
|
||||
if (__last - __first == 0)
|
||||
return __first;
|
||||
|
||||
// Check if we can load elements in front of the current pointer. If that's the case load a vector at
|
||||
// (last - vector_size) to check the remaining elements
|
||||
if (static_cast<size_t>(__first - __orig_first) >= __vec_size) {
|
||||
__first = __last - __vec_size;
|
||||
return __first + std::__find_first_set(std::__load_vector<__vec>(__first) == __values);
|
||||
}
|
||||
}
|
||||
|
||||
__identity __proj;
|
||||
return std::__find_loop(__first, __last, __value, __proj);
|
||||
}
|
||||
#endif // _LIBCPP_HAS_WIDE_CHARACTERS
|
||||
#endif
|
||||
|
||||
#ifndef _LIBCPP_CXX03_LANG
|
||||
// trivially equality comparable implementations
|
||||
template <
|
||||
class _Tp,
|
||||
class _Up,
|
||||
class _Proj,
|
||||
__enable_if_t<__is_identity<_Proj>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value, int> = 0>
|
||||
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp* __find(_Tp* __first, _Tp* __last, const _Up& __value, _Proj&) {
|
||||
if constexpr (sizeof(_Tp) == 1) {
|
||||
if (auto __ret = std::__constexpr_memchr(__first, __value, __last - __first))
|
||||
return __ret;
|
||||
return __last;
|
||||
}
|
||||
# if _LIBCPP_HAS_WIDE_CHARACTERS
|
||||
else if constexpr (sizeof(_Tp) == sizeof(wchar_t) && _LIBCPP_ALIGNOF(_Tp) >= _LIBCPP_ALIGNOF(wchar_t)) {
|
||||
if (auto __ret = std::__constexpr_wmemchr(__first, __value, __last - __first))
|
||||
return __ret;
|
||||
return __last;
|
||||
}
|
||||
# endif
|
||||
# if _LIBCPP_VECTORIZE_ALGORITHMS
|
||||
else if constexpr (is_integral<_Tp>::value) {
|
||||
return std::__find_vectorized(__first, __last, __value);
|
||||
}
|
||||
# endif
|
||||
else {
|
||||
__identity __proj;
|
||||
return std::__find_loop(__first, __last, __value, __proj);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: This should also be possible to get right with different signedness
|
||||
// cast integral types to allow vectorization
|
||||
|
||||
@ -114,6 +114,11 @@ template <class _VecT, class _Iter>
|
||||
}(make_index_sequence<__simd_vector_size_v<_VecT>>{});
|
||||
}
|
||||
|
||||
template <class _Tp, size_t _Np>
|
||||
[[__nodiscard__]] _LIBCPP_HIDE_FROM_ABI bool __any_of(__simd_vector<_Tp, _Np> __vec) noexcept {
|
||||
return __builtin_reduce_or(__builtin_convertvector(__vec, __simd_vector<bool, _Np>));
|
||||
}
|
||||
|
||||
template <class _Tp, size_t _Np>
|
||||
[[__nodiscard__]] _LIBCPP_HIDE_FROM_ABI bool __all_of(__simd_vector<_Tp, _Np> __vec) noexcept {
|
||||
return __builtin_reduce_and(__builtin_convertvector(__vec, __simd_vector<bool, _Np>));
|
||||
|
||||
@ -1225,6 +1225,7 @@ module std [system] {
|
||||
header "deque"
|
||||
export *
|
||||
export std.iterator.reverse_iterator
|
||||
export std.algorithm.simd_utils // This is a workaround for https://llvm.org/PR120108.
|
||||
}
|
||||
|
||||
module exception {
|
||||
@ -2238,6 +2239,7 @@ module std [system] {
|
||||
header "vector"
|
||||
export std.iterator.reverse_iterator
|
||||
export *
|
||||
export std.algorithm.simd_utils // This is a workaround for https://llvm.org/PR120108.
|
||||
}
|
||||
|
||||
// Experimental C++ Standard Library interfaces
|
||||
|
||||
@ -51,6 +51,7 @@ int main(int argc, char** argv) {
|
||||
// find
|
||||
bm.template operator()<std::vector<char>>("std::find(vector<char>) (" + comment + ")", std_find);
|
||||
bm.template operator()<std::vector<int>>("std::find(vector<int>) (" + comment + ")", std_find);
|
||||
bm.template operator()<std::vector<long long>>("std::find(vector<long long>) (" + comment + ")", std_find);
|
||||
bm.template operator()<std::deque<int>>("std::find(deque<int>) (" + comment + ")", std_find);
|
||||
bm.template operator()<std::list<int>>("std::find(list<int>) (" + comment + ")", std_find);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user