Implement vector version of memchr, and dispatch to same (#177711)
As in the description. This implementation shares quite a bit of code with the wide-read versions of string_length.
This commit is contained in:
parent
742af32b67
commit
ecee70e210
@ -17,11 +17,12 @@ namespace LIBC_NAMESPACE_DECL {
|
||||
namespace clang_vector {
|
||||
|
||||
// Exploit the underlying integer representation to do a variable shift.
|
||||
LIBC_INLINE constexpr cpp::simd_mask<char> shift_mask(cpp::simd_mask<char> m,
|
||||
size_t shift) {
|
||||
template <typename byte_ty>
|
||||
LIBC_INLINE constexpr cpp::simd_mask<byte_ty> shift_mask(cpp::simd_mask<char> m,
|
||||
size_t shift) {
|
||||
using bitmask_ty = cpp::internal::get_as_integer_type_t<cpp::simd_mask<char>>;
|
||||
bitmask_ty r = cpp::bit_cast<bitmask_ty>(m) >> shift;
|
||||
return cpp::bit_cast<cpp::simd_mask<char>>(r);
|
||||
return cpp::bit_cast<cpp::simd_mask<byte_ty>>(r);
|
||||
}
|
||||
|
||||
LIBC_NO_SANITIZE_OOB_ACCESS LIBC_INLINE size_t string_length(const char *src) {
|
||||
@ -34,8 +35,8 @@ LIBC_NO_SANITIZE_OOB_ACCESS LIBC_INLINE size_t string_length(const char *src) {
|
||||
cpp::simd<char> chars = cpp::load<cpp::simd<char>>(aligned, /*aligned=*/true);
|
||||
cpp::simd_mask<char> mask = chars == null_byte;
|
||||
size_t offset = src - reinterpret_cast<const char *>(aligned);
|
||||
if (cpp::any_of(shift_mask(mask, offset)))
|
||||
return cpp::find_first_set(shift_mask(mask, offset));
|
||||
if (cpp::any_of(shift_mask<char>(mask, offset)))
|
||||
return cpp::find_first_set(shift_mask<char>(mask, offset));
|
||||
|
||||
for (;;) {
|
||||
cpp::simd<char> chars = cpp::load<cpp::simd<char>>(++aligned,
|
||||
@ -46,6 +47,46 @@ LIBC_NO_SANITIZE_OOB_ACCESS LIBC_INLINE size_t string_length(const char *src) {
|
||||
cpp::find_first_set(mask);
|
||||
}
|
||||
}
|
||||
|
||||
LIBC_INLINE static void *calculate_find_first_character_return(
|
||||
const char *src, cpp::simd_mask<char> c_mask, size_t n_left) {
|
||||
size_t c_offset = cpp::find_first_set(c_mask);
|
||||
if (n_left < c_offset)
|
||||
return nullptr;
|
||||
return const_cast<char *>(src) + c_offset;
|
||||
}
|
||||
|
||||
LIBC_NO_SANITIZE_OOB_ACCESS LIBC_INLINE static void *
|
||||
find_first_character(const unsigned char *s, unsigned char c, size_t n) {
|
||||
using Vector = cpp::simd<char>;
|
||||
using Mask = cpp::simd_mask<char>;
|
||||
Vector c_byte = c;
|
||||
|
||||
size_t alignment = alignof(Vector);
|
||||
const Vector *aligned =
|
||||
reinterpret_cast<const Vector *>(__builtin_align_down(s, alignment));
|
||||
|
||||
Vector chars = cpp::load<Vector>(aligned, /*aligned=*/true);
|
||||
Mask cmp_v = chars == c_byte;
|
||||
size_t offset = s - reinterpret_cast<const unsigned char *>(aligned);
|
||||
|
||||
cmp_v = shift_mask<unsigned char>(cmp_v, offset);
|
||||
if (cpp::any_of(cmp_v))
|
||||
return calculate_find_first_character_return(
|
||||
reinterpret_cast<const char *>(s), cmp_v, n);
|
||||
|
||||
for (size_t bytes_checked = sizeof(Vector) - offset; bytes_checked < n;
|
||||
bytes_checked += sizeof(Vector)) {
|
||||
aligned++;
|
||||
Vector chars = cpp::load<Vector>(aligned, /*aligned=*/true);
|
||||
Mask cmp_v = chars == c_byte;
|
||||
if (cpp::any_of(cmp_v))
|
||||
return calculate_find_first_character_return(
|
||||
reinterpret_cast<const char *>(aligned), cmp_v, n - bytes_checked);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace clang_vector
|
||||
|
||||
} // namespace LIBC_NAMESPACE_DECL
|
||||
|
||||
@ -17,10 +17,11 @@ namespace LIBC_NAMESPACE_DECL {
|
||||
|
||||
namespace internal::arch_vector {
|
||||
|
||||
// Return a bit-mask with the nth bit set if the nth-byte in block_ptr is zero.
|
||||
// Return a bit-mask with the nth bit set if the nth-byte in block_ptr matches
|
||||
// character c.
|
||||
template <typename Vector, typename Mask>
|
||||
LIBC_NO_SANITIZE_OOB_ACCESS LIBC_INLINE static Mask
|
||||
compare_and_mask(const Vector *block_ptr);
|
||||
compare_and_mask(const Vector *block_ptr, char c);
|
||||
|
||||
template <typename Vector, typename Mask,
|
||||
decltype(compare_and_mask<Vector, Mask>)>
|
||||
@ -30,13 +31,13 @@ string_length_vector(const char *src) {
|
||||
|
||||
const Vector *block_ptr =
|
||||
reinterpret_cast<const Vector *>(src - misalign_bytes);
|
||||
auto cmp = compare_and_mask<Vector, Mask>(block_ptr) >> misalign_bytes;
|
||||
auto cmp = compare_and_mask<Vector, Mask>(block_ptr, 0) >> misalign_bytes;
|
||||
if (cmp)
|
||||
return cpp::countr_zero(cmp);
|
||||
|
||||
while (true) {
|
||||
block_ptr++;
|
||||
cmp = compare_and_mask<Vector, Mask>(block_ptr);
|
||||
cmp = compare_and_mask<Vector, Mask>(block_ptr, 0);
|
||||
if (cmp)
|
||||
return static_cast<size_t>(reinterpret_cast<uintptr_t>(block_ptr) -
|
||||
reinterpret_cast<uintptr_t>(src) +
|
||||
@ -44,13 +45,50 @@ string_length_vector(const char *src) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Mask>
|
||||
LIBC_INLINE static void *
|
||||
calculate_find_first_character_return(const unsigned char *src, Mask c_mask,
|
||||
size_t n_left) {
|
||||
size_t c_offset = cpp::countr_zero(c_mask);
|
||||
if (n_left < c_offset)
|
||||
return nullptr;
|
||||
return const_cast<unsigned char *>(src) + c_offset;
|
||||
}
|
||||
|
||||
template <typename Vector, typename Mask,
|
||||
decltype(compare_and_mask<Vector, Mask>)>
|
||||
LIBC_NO_SANITIZE_OOB_ACCESS LIBC_INLINE static void *
|
||||
find_first_character_vector(const unsigned char *s, unsigned char c, size_t n) {
|
||||
uintptr_t misalign_bytes = reinterpret_cast<uintptr_t>(s) % sizeof(Vector);
|
||||
|
||||
const Vector *block_ptr =
|
||||
reinterpret_cast<const Vector *>(s - misalign_bytes);
|
||||
auto cmp_bytes =
|
||||
compare_and_mask<Vector, Mask>(block_ptr, c) >> misalign_bytes;
|
||||
if (cmp_bytes)
|
||||
return calculate_find_first_character_return<Mask>(
|
||||
reinterpret_cast<const unsigned char *>(block_ptr) + misalign_bytes,
|
||||
cmp_bytes, n);
|
||||
|
||||
for (size_t bytes_checked = sizeof(Vector) - misalign_bytes;
|
||||
bytes_checked < n; bytes_checked += sizeof(Vector)) {
|
||||
block_ptr++;
|
||||
cmp_bytes = compare_and_mask<Vector, Mask>(block_ptr, c);
|
||||
if (cmp_bytes)
|
||||
return calculate_find_first_character_return<Mask>(
|
||||
reinterpret_cast<const unsigned char *>(block_ptr), cmp_bytes,
|
||||
n - bytes_checked);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
LIBC_INLINE uint32_t
|
||||
compare_and_mask<__m128i, uint32_t>(const __m128i *block_ptr) {
|
||||
__m128i v = _mm_load_si128(block_ptr);
|
||||
__m128i z = _mm_setzero_si128();
|
||||
__m128i c = _mm_cmpeq_epi8(z, v);
|
||||
return _mm_movemask_epi8(c);
|
||||
compare_and_mask<__m128i, uint32_t>(const __m128i *block_ptr, char c) {
|
||||
__m128i b = _mm_load_si128(block_ptr);
|
||||
__m128i set = _mm_set1_epi8(c);
|
||||
__m128i cmp = _mm_cmpeq_epi8(b, set);
|
||||
return _mm_movemask_epi8(cmp);
|
||||
}
|
||||
|
||||
namespace sse2 {
|
||||
@ -58,16 +96,24 @@ namespace sse2 {
|
||||
return string_length_vector<__m128i, uint32_t,
|
||||
compare_and_mask<__m128i, uint32_t>>(src);
|
||||
}
|
||||
|
||||
[[maybe_unused]] LIBC_INLINE void *
|
||||
find_first_character(const unsigned char *s, unsigned char c, size_t n) {
|
||||
return find_first_character_vector<__m128i, uint32_t,
|
||||
compare_and_mask<__m128i, uint32_t>>(s, c,
|
||||
n);
|
||||
}
|
||||
|
||||
} // namespace sse2
|
||||
|
||||
#if defined(__AVX2__)
|
||||
template <>
|
||||
LIBC_INLINE uint32_t
|
||||
compare_and_mask<__m256i, uint32_t>(const __m256i *block_ptr) {
|
||||
__m256i v = _mm256_load_si256(block_ptr);
|
||||
__m256i z = _mm256_setzero_si256();
|
||||
__m256i c = _mm256_cmpeq_epi8(z, v);
|
||||
return _mm256_movemask_epi8(c);
|
||||
compare_and_mask<__m256i, uint32_t>(const __m256i *block_ptr, char c) {
|
||||
__m256i b = _mm256_load_si256(block_ptr);
|
||||
__m256i set = _mm256_set1_epi16(c);
|
||||
__m256i cmp = _mm256_cmpeq_epi8(b, set);
|
||||
return _mm256_movemask_epi8(cmp);
|
||||
}
|
||||
|
||||
namespace avx2 {
|
||||
@ -75,25 +121,45 @@ namespace avx2 {
|
||||
return string_length_vector<__m256i, uint32_t,
|
||||
compare_and_mask<__m256i, uint32_t>>(src);
|
||||
}
|
||||
|
||||
[[maybe_unused]] LIBC_INLINE void *
|
||||
find_first_character(const unsigned char *s, unsigned char c, size_t n) {
|
||||
return find_first_character_vector<__m256i, uint32_t,
|
||||
compare_and_mask<__m256i, uint32_t>>(s, c,
|
||||
n);
|
||||
}
|
||||
} // namespace avx2
|
||||
#endif
|
||||
|
||||
#if defined(__AVX512F__)
|
||||
template <>
|
||||
LIBC_INLINE __mmask64
|
||||
compare_and_mask<__m512i, __mmask64>(const __m512i *block_ptr) {
|
||||
compare_and_mask<__m512i, __mmask64>(const __m512i *block_ptr, char c) {
|
||||
__m512i v = _mm512_load_si512(block_ptr);
|
||||
__m512i z = _mm512_setzero_si512();
|
||||
return _mm512_cmp_epu8_mask(z, v, _MM_CMPINT_EQ);
|
||||
__m512i set = _mm512_set1_epi8(c);
|
||||
return _mm512_cmp_epu8_mask(set, v, _MM_CMPINT_EQ);
|
||||
}
|
||||
|
||||
namespace avx512 {
|
||||
[[maybe_unused]] LIBC_INLINE size_t string_length(const char *src) {
|
||||
return string_length_vector<__m512i, __mmask64,
|
||||
compare_and_mask<__m512i, __mmask64>>(src);
|
||||
}
|
||||
|
||||
[[maybe_unused]] LIBC_INLINE void *
|
||||
find_first_character(const unsigned char *s, unsigned char c, size_t n) {
|
||||
return find_first_character_vector<__m512i, __mmask64,
|
||||
compare_and_mask<__m512i, __mmask64>>(s, c,
|
||||
n);
|
||||
}
|
||||
|
||||
} // namespace avx512
|
||||
#endif
|
||||
|
||||
// We could directly use the various <function>_vector templates here, but this
|
||||
// indirection allows comparing the various implementations elsewhere by name,
|
||||
// without having to instantiate the templates by hand at those locations.
|
||||
|
||||
[[maybe_unused]] LIBC_INLINE size_t string_length(const char *src) {
|
||||
#if defined(__AVX512F__)
|
||||
return avx512::string_length(src);
|
||||
@ -104,6 +170,17 @@ namespace avx512 {
|
||||
#endif
|
||||
}
|
||||
|
||||
[[maybe_unused]] LIBC_INLINE void *
|
||||
find_first_character(const unsigned char *s, unsigned char c, size_t n) {
|
||||
#if defined(__AVX512F__)
|
||||
return avx512::find_first_character(s, c, n);
|
||||
#elif defined(__AVX2__)
|
||||
return avx2::find_first_character(s, c, n);
|
||||
#else
|
||||
return sse2::find_first_character(s, c, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace internal::arch_vector
|
||||
|
||||
} // namespace LIBC_NAMESPACE_DECL
|
||||
|
||||
@ -98,4 +98,25 @@ TEST_F(LlvmLibcWideAccessMemoryTest, StringLength) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(LlvmLibcWideAccessMemoryTest, FindFirstChar) {
|
||||
// 1.5 k long vector of a's.
|
||||
TwoKilobyteBuffer buf;
|
||||
inline_memset(buf.data(), 'a', buf.size());
|
||||
buf[buf.size() - 1] = 'b';
|
||||
this->TestMemoryAccess(buf, [this, buf](const char *test_data) {
|
||||
// Found case
|
||||
ASSERT_EQ(
|
||||
reinterpret_cast<const void *>(internal::find_first_character_impl(
|
||||
reinterpret_cast<const unsigned char *>(test_data), 'b',
|
||||
size_t(buf.size()))),
|
||||
reinterpret_cast<const void *>(test_data + size_t(buf.size()) - 1));
|
||||
// Not found case
|
||||
ASSERT_EQ(
|
||||
reinterpret_cast<const void *>(internal::find_first_character_impl(
|
||||
reinterpret_cast<const unsigned char *>(test_data), 'c',
|
||||
size_t(buf.size()))),
|
||||
nullptr);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace LIBC_NAMESPACE_DECL
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user