[ADT] Reinstate "Refactor Bitset to Be More Constexpr-Usable" (#189497)

Reland of #172062 (a71b1d2), which was reverted in b0234d1.

This patch makes essential Bitset member functions constexpr (`set()`,
`any()`, `none()`, `count()`, `operator==`, `operator!=`, `operator<`,
`operator~`) and adds a new `all()` method. It also introduces a
`maskLastWord()` helper that maintains the invariant that unused high bits
in the last word are always zero, which is required for correctness of
`operator~`, `set()`, `all()`, and comparisons on non-word-aligned sizes
(e.g., `Bitset<33>`).

Changes from the original reverted PR:
- Replaced `llvm::any_of` with an inline loop to avoid depending on
constexpr `any_of`/`none_of` from `STLExtras` (#172536), which was also
reverted due to a GCC 15.2.1 bootstrap miscompile.
- The patch is now fully self-contained with no prerequisite changes.

Motivation: This is a prerequisite for making `LaneBitmask` a wrapper
around `Bitset`, enabling scalable lane bitmasks beyond 64 bits
(https://discourse.llvm.org/t/rfc-out-of-lanebitmask-bits-again/88613).
This commit is contained in:
Jiachen Yuan 2026-04-02 02:50:10 -07:00 committed by GitHub
parent dc9be4ee30
commit d0bf354828
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 275 additions and 18 deletions

View File

@ -16,21 +16,24 @@
#ifndef LLVM_ADT_BITSET_H
#define LLVM_ADT_BITSET_H
#include "llvm/ADT/bit.h"
#include <array>
#include <climits>
#include <cstdint>
#include <llvm/ADT/STLExtras.h>
namespace llvm {
/// This is a constexpr reimplementation of a subset of std::bitset. It would be
/// nice to use std::bitset directly, but it doesn't support constant
/// initialization.
template <unsigned NumBits>
class Bitset {
template <unsigned NumBits> class Bitset {
using BitWord = uintptr_t;
static constexpr unsigned BitwordBits = sizeof(BitWord) * CHAR_BIT;
static constexpr unsigned RemainderNumBits = NumBits % BitwordBits;
static constexpr BitWord RemainderMask =
RemainderNumBits == 0 ? ~BitWord(0)
: ((BitWord(1) << RemainderNumBits) - 1);
static_assert(BitwordBits == 64 || BitwordBits == 32,
"Unsupported word size");
@ -38,9 +41,16 @@ class Bitset {
static constexpr unsigned NumWords =
(NumBits + BitwordBits - 1) / BitwordBits;
// Returns the 0-based index of the last storage word. That word may be only
// partially used, so code that writes it must mask it to keep the invariant
// that unused high bits are always zero.
// NOTE(review): NumWords is 0 when NumBits is 0, in which case this unsigned
// subtraction wraps around -- confirm Bitset<0> is never instantiated, or
// guard callers that use the result as an array index.
static constexpr unsigned getLastWordIndex() { return NumWords - 1; }
using StorageType = std::array<BitWord, NumWords>;
StorageType Bits{};
constexpr void maskLastWord() { Bits[getLastWordIndex()] &= RemainderMask; }
protected:
constexpr Bitset(const std::array<uint64_t, (NumBits + 63) / 64> &B) {
if constexpr (sizeof(BitWord) == sizeof(uint64_t)) {
@ -52,12 +62,13 @@ protected:
uint64_t Elt = B[I];
// On a 32-bit system the storage type will be 32-bit, so we may only
// need half of a uint64_t.
for (size_t offset = 0; offset != 2 && BitsToAssign; ++offset) {
Bits[2 * I + offset] = static_cast<uint32_t>(Elt >> (32 * offset));
for (size_t Offset = 0; Offset != 2 && BitsToAssign; ++Offset) {
Bits[2 * I + Offset] = static_cast<uint32_t>(Elt >> (32 * Offset));
BitsToAssign = BitsToAssign >= 32 ? BitsToAssign - 32 : 0;
}
}
}
maskLastWord();
}
public:
@ -67,8 +78,11 @@ public:
set(I);
}
Bitset &set() {
llvm::fill(Bits, -BitWord(0));
constexpr Bitset &set() {
constexpr const BitWord AllOnes = ~BitWord(0);
for (BitWord &B : Bits)
B = AllOnes;
maskLastWord();
return *this;
}
@ -96,14 +110,27 @@ public:
constexpr size_t size() const { return NumBits; }
bool any() const {
return llvm::any_of(Bits, [](BitWord I) { return I != 0; });
constexpr bool any() const {
for (BitWord B : Bits)
if (B != 0)
return true;
return false;
}
bool none() const { return !any(); }
size_t count() const {
constexpr bool none() const { return !any(); }
constexpr bool all() const {
constexpr const BitWord AllOnes = ~BitWord(0);
for (unsigned I = 0; I < getLastWordIndex(); ++I)
if (Bits[I] != AllOnes)
return false;
return Bits[getLastWordIndex()] == RemainderMask;
}
constexpr size_t count() const {
size_t Count = 0;
for (auto B : Bits)
Count += llvm::popcount(B);
for (BitWord Word : Bits)
Count += popcount(Word);
return Count;
}
@ -144,18 +171,22 @@ public:
constexpr Bitset operator~() const {
Bitset Result = *this;
for (auto &B : Result.Bits)
for (BitWord &B : Result.Bits)
B = ~B;
Result.maskLastWord();
return Result;
}
bool operator==(const Bitset &RHS) const {
return std::equal(std::begin(Bits), std::end(Bits), std::begin(RHS.Bits));
constexpr bool operator==(const Bitset &RHS) const {
for (unsigned I = 0; I < NumWords; ++I)
if (Bits[I] != RHS.Bits[I])
return false;
return true;
}
bool operator!=(const Bitset &RHS) const { return !(*this == RHS); }
constexpr bool operator!=(const Bitset &RHS) const { return !(*this == RHS); }
bool operator<(const Bitset &Other) const {
constexpr bool operator<(const Bitset &Other) const {
for (unsigned I = 0, E = size(); I != E; ++I) {
bool LHS = test(I), RHS = Other.test(I);
if (LHS != RHS)

View File

@ -68,4 +68,230 @@ TEST(BitsetTest, Construction) {
EXPECT_TRUE(Test33.verifyValue(TestSingleVal));
Test33.verifyStorageSize(1, 2);
}
TEST(BitsetTest, SetAndQuery) {
  // Argument-less set() turns on every bit; this is also checked for a width
  // (33) that is not a multiple of the storage word size.
  Bitset<64> Full;
  Full.set();
  EXPECT_TRUE(Full.all());
  EXPECT_TRUE(Full.any());
  EXPECT_FALSE(Full.none());
  static_assert(Bitset<64>().set().all());
  static_assert(Bitset<33>().set().all());

  // set(I) turns on exactly the requested bit.
  Bitset<64> Sparse;
  Sparse.set(10);
  Sparse.set(20);
  EXPECT_TRUE(Sparse.test(10));
  EXPECT_TRUE(Sparse.test(20));
  EXPECT_FALSE(Sparse.test(15));
  static_assert(Bitset<64>().set(10).test(10));
  static_assert(Bitset<64>().set(0).set(63).test(0));
  static_assert(Bitset<64>().set(0).set(63).test(63));
  static_assert(Bitset<33>().set(32).test(32));
  static_assert(Bitset<128>().set(64).set(127).test(64));
  static_assert(Bitset<128>().set(64).set(127).test(127));

  // reset(I) clears one bit and leaves its neighbours alone.
  Bitset<64> Cleared({10, 20, 30});
  Cleared.reset(20);
  EXPECT_TRUE(Cleared.test(10));
  EXPECT_FALSE(Cleared.test(20));
  EXPECT_TRUE(Cleared.test(30));
  static_assert(!Bitset<64>({10, 20}).reset(10).test(10));
  static_assert(Bitset<64>({10, 20}).reset(10).test(20));
  static_assert(!Bitset<96>({31, 32, 63}).reset(32).test(32));
  static_assert(Bitset<33>({0, 32}).reset(0).test(32));

  // flip(I) toggles one bit: a set bit becomes clear and vice versa.
  Bitset<64> Toggled({10, 20});
  Toggled.flip(10);
  Toggled.flip(30);
  EXPECT_FALSE(Toggled.test(10));
  EXPECT_TRUE(Toggled.test(20));
  EXPECT_TRUE(Toggled.test(30));
  static_assert(!Bitset<64>({10, 20}).flip(10).test(10));
  static_assert(Bitset<64>({10, 20}).flip(30).test(30));
  static_assert(Bitset<100>({50, 99}).flip(50).test(99));
  static_assert(!Bitset<100>({50, 99}).flip(50).test(50));
  static_assert(Bitset<33>().flip(32).test(32));

  // operator[] reads a single bit.
  Bitset<64> Indexed({5, 15, 25});
  EXPECT_TRUE(Indexed[5]);
  EXPECT_FALSE(Indexed[10]);
  EXPECT_TRUE(Indexed[15]);
  static_assert(Bitset<64>({10, 20})[10]);
  static_assert(!Bitset<64>({10, 20})[15]);
  static_assert(Bitset<128>({127})[127]);
  static_assert(Bitset<96>({63, 64})[63]);
  static_assert(Bitset<96>({63, 64})[64]);

  // size() reports the template width, not the storage capacity.
  EXPECT_EQ(Full.size(), 64u);
  Bitset<33> Odd;
  EXPECT_EQ(Odd.size(), 33u);
  static_assert(Bitset<64>().size() == 64);
  static_assert(Bitset<128>().size() == 128);
  static_assert(Bitset<33>().size() == 33);

  // any() and none() are complementary.
  static_assert(!Bitset<64>().any());
  static_assert(Bitset<64>().none());
  static_assert(Bitset<64>({10}).any());
  static_assert(!Bitset<64>({10}).none());
}
TEST(BitsetTest, ComparisonOperators) {
  // Equality: identical bit patterns compare equal, a single differing bit
  // makes them unequal.
  Bitset<64> Base({10, 20, 30});
  Bitset<64> Same({10, 20, 30});
  Bitset<64> Different({10, 20, 31});
  EXPECT_TRUE(Base == Same);
  EXPECT_FALSE(Base == Different);
  static_assert(Bitset<64>({10, 20}) == Bitset<64>({10, 20}));
  static_assert(Bitset<64>({10, 20}) != Bitset<64>({10, 21}));

  // Ordering: bits are compared starting at index 0. The operands below first
  // differ at bit 10, where the left one has the bit clear and the right one
  // has it set, and the left operand is expected to order first.
  static_assert(Bitset<64>({5, 11}) < Bitset<64>({5, 10}));
  // A value is never less than itself.
  static_assert(!(Bitset<64>({5, 10}) < Bitset<64>({5, 10})));
}
TEST(BitsetTest, BitwiseNot) {
  // Complementing a fully set bitset must leave no bit set.
  Bitset<64> Full;
  Full.set();
  Bitset<64> Inverted = ~Full;
  EXPECT_TRUE(Inverted.none());

  // Same property at compile time, plus the reverse direction (empty -> full)
  // and a width (33) that is not a multiple of the storage word size.
  static_assert((~Bitset<64>()).all());
  static_assert((~Bitset<64>().set()).none());
  static_assert((~Bitset<33>().set()).none());
}
TEST(BitsetTest, BitwiseOperators) {
  // operator&: only bits present in both operands survive.
  Bitset<64> AndLhs({10, 20, 30});
  Bitset<64> AndRhs({20, 30, 40});
  Bitset<64> AndResult = AndLhs & AndRhs;
  EXPECT_FALSE(AndResult.test(10));
  EXPECT_TRUE(AndResult.test(20));
  EXPECT_TRUE(AndResult.test(30));
  EXPECT_FALSE(AndResult.test(40));
  EXPECT_EQ(AndResult.count(), 2u);
  static_assert((Bitset<64>({10, 20}) & Bitset<64>({20, 30})).test(20));
  static_assert(!(Bitset<64>({10, 20}) & Bitset<64>({20, 30})).test(10));
  static_assert((Bitset<64>({10, 20}) & Bitset<64>({20, 30})).count() == 1);
  static_assert(
      (Bitset<96>({31, 32, 63, 64}) & Bitset<96>({32, 64, 95})).count() == 2);
  static_assert((Bitset<33>({0, 32}) & Bitset<33>({32})).test(32));

  // operator&=: in-place intersection.
  Bitset<64> AndInPlace({10, 20, 30});
  AndInPlace &= Bitset<64>({20, 30, 40});
  EXPECT_FALSE(AndInPlace.test(10));
  EXPECT_TRUE(AndInPlace.test(20));
  EXPECT_TRUE(AndInPlace.test(30));
  EXPECT_FALSE(AndInPlace.test(40));
  constexpr Bitset<64> ConstAnd = [] {
    Bitset<64> Acc({10, 20, 30});
    Acc &= Bitset<64>({20, 30, 40});
    return Acc;
  }();
  static_assert(ConstAnd.test(20) && ConstAnd.test(30) && !ConstAnd.test(10));
  constexpr Bitset<100> ConstAnd100 = [] {
    Bitset<100> Acc({10, 50, 99});
    Acc &= Bitset<100>({50, 99});
    return Acc;
  }();
  static_assert(ConstAnd100.count() == 2 && ConstAnd100.test(50) &&
                ConstAnd100.test(99));

  // operator|: bits present in either operand are set.
  Bitset<64> OrLhs({10, 20});
  Bitset<64> OrRhs({20, 30});
  Bitset<64> OrResult = OrLhs | OrRhs;
  EXPECT_TRUE(OrResult.test(10));
  EXPECT_TRUE(OrResult.test(20));
  EXPECT_TRUE(OrResult.test(30));
  EXPECT_EQ(OrResult.count(), 3u);
  static_assert((Bitset<64>({10}) | Bitset<64>({20})).count() == 2);
  static_assert((Bitset<128>({0, 64, 127}) | Bitset<128>({64, 100})).count() ==
                4);
  static_assert((Bitset<33>({0, 16}) | Bitset<33>({16, 32})).count() == 3);

  // operator|=: in-place union.
  Bitset<64> OrInPlace({10, 20});
  OrInPlace |= Bitset<64>({20, 30});
  EXPECT_TRUE(OrInPlace.test(10));
  EXPECT_TRUE(OrInPlace.test(20));
  EXPECT_TRUE(OrInPlace.test(30));
  constexpr Bitset<64> ConstOr = [] {
    Bitset<64> Acc({10});
    Acc |= Bitset<64>({20});
    return Acc;
  }();
  static_assert(ConstOr.test(10) && ConstOr.test(20));
  constexpr Bitset<96> ConstOr96 = [] {
    Bitset<96> Acc({31, 63});
    Acc |= Bitset<96>({32, 64});
    return Acc;
  }();
  static_assert(ConstOr96.count() == 4);

  // operator^: bits present in exactly one operand are set.
  Bitset<64> XorLhs({10, 20, 30});
  Bitset<64> XorRhs({20, 30, 40});
  Bitset<64> XorResult = XorLhs ^ XorRhs;
  EXPECT_TRUE(XorResult.test(10));
  EXPECT_FALSE(XorResult.test(20));
  EXPECT_FALSE(XorResult.test(30));
  EXPECT_TRUE(XorResult.test(40));
  EXPECT_EQ(XorResult.count(), 2u);
  static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(10));
  static_assert(!(Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(20));
  static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(30));
  static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).count() == 2);
  static_assert((Bitset<100>({0, 50, 99}) ^ Bitset<100>({50})).count() == 2);
  static_assert((Bitset<33>({0, 32}) ^ Bitset<33>({0, 16})).count() == 2);

  // operator^=: in-place symmetric difference.
  Bitset<64> XorInPlace({10, 20, 30});
  XorInPlace ^= Bitset<64>({20, 30, 40});
  EXPECT_TRUE(XorInPlace.test(10));
  EXPECT_FALSE(XorInPlace.test(20));
  EXPECT_FALSE(XorInPlace.test(30));
  EXPECT_TRUE(XorInPlace.test(40));
  constexpr Bitset<64> ConstXor = [] {
    Bitset<64> Acc({10, 20});
    Acc ^= Bitset<64>({20, 30});
    return Acc;
  }();
  static_assert(ConstXor.test(10) && !ConstXor.test(20) && ConstXor.test(30));
  constexpr Bitset<128> ConstXor128 = [] {
    Bitset<128> Acc({0, 64, 127});
    Acc ^= Bitset<128>({64});
    return Acc;
  }();
  static_assert(ConstXor128.count() == 2 && ConstXor128.test(0) &&
                ConstXor128.test(127));
}
} // namespace