diff --git a/llvm/include/llvm/ADT/Bitset.h b/llvm/include/llvm/ADT/Bitset.h index 3541db49dd97..9dc0f24b1d9f 100644 --- a/llvm/include/llvm/ADT/Bitset.h +++ b/llvm/include/llvm/ADT/Bitset.h @@ -16,21 +16,24 @@ #ifndef LLVM_ADT_BITSET_H #define LLVM_ADT_BITSET_H +#include "llvm/ADT/bit.h" #include #include #include -#include namespace llvm { /// This is a constexpr reimplementation of a subset of std::bitset. It would be /// nice to use std::bitset directly, but it doesn't support constant /// initialization. -template -class Bitset { +template class Bitset { using BitWord = uintptr_t; static constexpr unsigned BitwordBits = sizeof(BitWord) * CHAR_BIT; + static constexpr unsigned RemainderNumBits = NumBits % BitwordBits; + static constexpr BitWord RemainderMask = + RemainderNumBits == 0 ? ~BitWord(0) + : ((BitWord(1) << RemainderNumBits) - 1); static_assert(BitwordBits == 64 || BitwordBits == 32, "Unsupported word size"); @@ -38,9 +41,16 @@ class Bitset { static constexpr unsigned NumWords = (NumBits + BitwordBits - 1) / BitwordBits; + // Returns the index of the last word (0-based). The last word may be + // partially filled and requires masking to maintain the invariant that + // unused high bits are always zero. + static constexpr unsigned getLastWordIndex() { return NumWords - 1; } + using StorageType = std::array; StorageType Bits{}; + constexpr void maskLastWord() { Bits[getLastWordIndex()] &= RemainderMask; } + protected: constexpr Bitset(const std::array &B) { if constexpr (sizeof(BitWord) == sizeof(uint64_t)) { @@ -52,12 +62,13 @@ protected: uint64_t Elt = B[I]; // On a 32-bit system the storage type will be 32-bit, so we may only // need half of a uint64_t. - for (size_t offset = 0; offset != 2 && BitsToAssign; ++offset) { - Bits[2 * I + offset] = static_cast(Elt >> (32 * offset)); + for (size_t Offset = 0; Offset != 2 && BitsToAssign; ++Offset) { + Bits[2 * I + Offset] = static_cast(Elt >> (32 * Offset)); BitsToAssign = BitsToAssign >= 32 ? BitsToAssign - 32 : 0; } } } + maskLastWord(); } public: @@ -67,8 +78,11 @@ public: set(I); } - Bitset &set() { - llvm::fill(Bits, -BitWord(0)); + constexpr Bitset &set() { + constexpr const BitWord AllOnes = ~BitWord(0); + for (BitWord &B : Bits) + B = AllOnes; + maskLastWord(); return *this; } @@ -96,14 +110,27 @@ public: constexpr size_t size() const { return NumBits; } - bool any() const { - return llvm::any_of(Bits, [](BitWord I) { return I != 0; }); + constexpr bool any() const { + for (BitWord B : Bits) + if (B != 0) + return true; + return false; } - bool none() const { return !any(); } - size_t count() const { + + constexpr bool none() const { return !any(); } + + constexpr bool all() const { + constexpr const BitWord AllOnes = ~BitWord(0); + for (unsigned I = 0; I < getLastWordIndex(); ++I) + if (Bits[I] != AllOnes) + return false; + return Bits[getLastWordIndex()] == RemainderMask; + } + + constexpr size_t count() const { size_t Count = 0; - for (auto B : Bits) - Count += llvm::popcount(B); + for (BitWord Word : Bits) + Count += popcount(Word); return Count; } @@ -144,18 +171,22 @@ public: constexpr Bitset operator~() const { Bitset Result = *this; - for (auto &B : Result.Bits) + for (BitWord &B : Result.Bits) B = ~B; + Result.maskLastWord(); return Result; } - bool operator==(const Bitset &RHS) const { - return std::equal(std::begin(Bits), std::end(Bits), std::begin(RHS.Bits)); + constexpr bool operator==(const Bitset &RHS) const { + for (unsigned I = 0; I < NumWords; ++I) + if (Bits[I] != RHS.Bits[I]) + return false; + return true; } - bool operator!=(const Bitset &RHS) const { return !(*this == RHS); } + constexpr bool operator!=(const Bitset &RHS) const { return !(*this == RHS); } - bool operator<(const Bitset &Other) const { + constexpr bool operator<(const Bitset &Other) const { for (unsigned I = 0, E = size(); I != E; ++I) { bool LHS = test(I), RHS = Other.test(I); if (LHS != RHS) diff --git a/llvm/unittests/ADT/BitsetTest.cpp b/llvm/unittests/ADT/BitsetTest.cpp index 0ecd213d6a78..678197e31a37 100644 --- a/llvm/unittests/ADT/BitsetTest.cpp +++ b/llvm/unittests/ADT/BitsetTest.cpp @@ -68,4 +68,230 @@ TEST(BitsetTest, Construction) { EXPECT_TRUE(Test33.verifyValue(TestSingleVal)); Test33.verifyStorageSize(1, 2); } + +TEST(BitsetTest, SetAndQuery) { + // Test set() with all bits. + Bitset<64> A; + A.set(); + EXPECT_TRUE(A.all()); + EXPECT_TRUE(A.any()); + EXPECT_FALSE(A.none()); + + static_assert(Bitset<64>().set().all()); + static_assert(Bitset<33>().set().all()); + + // Test set() with single bit. + Bitset<64> B; + B.set(10); + B.set(20); + EXPECT_TRUE(B.test(10)); + EXPECT_TRUE(B.test(20)); + EXPECT_FALSE(B.test(15)); + + static_assert(Bitset<64>().set(10).test(10)); + static_assert(Bitset<64>().set(0).set(63).test(0) && + Bitset<64>().set(0).set(63).test(63)); + static_assert(Bitset<33>().set(32).test(32)); + static_assert(Bitset<128>().set(64).set(127).test(64) && + Bitset<128>().set(64).set(127).test(127)); + + // Test reset() with single bit. + Bitset<64> C({10, 20, 30}); + C.reset(20); + EXPECT_TRUE(C.test(10)); + EXPECT_FALSE(C.test(20)); + EXPECT_TRUE(C.test(30)); + + static_assert(!Bitset<64>({10, 20}).reset(10).test(10)); + static_assert(Bitset<64>({10, 20}).reset(10).test(20)); + static_assert(!Bitset<96>({31, 32, 63}).reset(32).test(32)); + static_assert(Bitset<33>({0, 32}).reset(0).test(32)); + + // Test flip() with single bit. + Bitset<64> D({10, 20}); + D.flip(10); + D.flip(30); + EXPECT_FALSE(D.test(10)); + EXPECT_TRUE(D.test(20)); + EXPECT_TRUE(D.test(30)); + + static_assert(!Bitset<64>({10, 20}).flip(10).test(10)); + static_assert(Bitset<64>({10, 20}).flip(30).test(30)); + static_assert(Bitset<100>({50, 99}).flip(50).test(99) && + !Bitset<100>({50, 99}).flip(50).test(50)); + static_assert(Bitset<33>().flip(32).test(32)); + + // Test operator[]. + Bitset<64> E({5, 15, 25}); + EXPECT_TRUE(E[5]); + EXPECT_FALSE(E[10]); + EXPECT_TRUE(E[15]); + + static_assert(Bitset<64>({10, 20})[10]); + static_assert(!Bitset<64>({10, 20})[15]); + static_assert(Bitset<128>({127})[127]); + static_assert(Bitset<96>({63, 64})[63] && Bitset<96>({63, 64})[64]); + + // Test size(). + EXPECT_EQ(A.size(), 64u); + Bitset<33> F; + EXPECT_EQ(F.size(), 33u); + + static_assert(Bitset<64>().size() == 64); + static_assert(Bitset<128>().size() == 128); + static_assert(Bitset<33>().size() == 33); + + // Test any() and none(). + static_assert(!Bitset<64>().any()); + static_assert(Bitset<64>().none()); + static_assert(Bitset<64>({10}).any()); + static_assert(!Bitset<64>({10}).none()); +} + +TEST(BitsetTest, ComparisonOperators) { + // Test operator==. + Bitset<64> A({10, 20, 30}); + Bitset<64> B({10, 20, 30}); + Bitset<64> C({10, 20, 31}); + EXPECT_TRUE(A == B); + EXPECT_FALSE(A == C); + + static_assert(Bitset<64>({10, 20}) == Bitset<64>({10, 20})); + static_assert(Bitset<64>({10, 20}) != Bitset<64>({10, 21})); + + // Test operator< (lexicographic comparison, bit 0 is least significant). + static_assert(Bitset<64>({5, 11}) < + Bitset<64>({5, 10})); // At bit 10: A=0, B=1. + static_assert(!(Bitset<64>({5, 10}) < Bitset<64>({5, 10}))); +} + +TEST(BitsetTest, BitwiseNot) { + // Test operator~. + Bitset<64> A; + A.set(); + Bitset<64> B = ~A; + EXPECT_TRUE(B.none()); + + static_assert((~Bitset<64>()).all()); + static_assert((~Bitset<64>().set()).none()); + static_assert((~Bitset<33>().set()).none()); +} + +TEST(BitsetTest, BitwiseOperators) { + // Test operator&. + Bitset<64> A({10, 20, 30}); + Bitset<64> B({20, 30, 40}); + Bitset<64> Result1 = A & B; + EXPECT_FALSE(Result1.test(10)); + EXPECT_TRUE(Result1.test(20)); + EXPECT_TRUE(Result1.test(30)); + EXPECT_FALSE(Result1.test(40)); + EXPECT_EQ(Result1.count(), 2u); + + static_assert((Bitset<64>({10, 20}) & Bitset<64>({20, 30})).test(20)); + static_assert(!(Bitset<64>({10, 20}) & Bitset<64>({20, 30})).test(10)); + static_assert((Bitset<64>({10, 20}) & Bitset<64>({20, 30})).count() == 1); + static_assert( + (Bitset<96>({31, 32, 63, 64}) & Bitset<96>({32, 64, 95})).count() == 2); + static_assert((Bitset<33>({0, 32}) & Bitset<33>({32})).test(32)); + + // Test operator&=. + Bitset<64> C({10, 20, 30}); + C &= Bitset<64>({20, 30, 40}); + EXPECT_FALSE(C.test(10)); + EXPECT_TRUE(C.test(20)); + EXPECT_TRUE(C.test(30)); + EXPECT_FALSE(C.test(40)); + + constexpr Bitset<64> TestAnd = [] { + Bitset<64> X({10, 20, 30}); + X &= Bitset<64>({20, 30, 40}); + return X; + }(); + static_assert(TestAnd.test(20) && TestAnd.test(30) && !TestAnd.test(10)); + + constexpr Bitset<100> TestAnd100 = [] { + Bitset<100> X({10, 50, 99}); + X &= Bitset<100>({50, 99}); + return X; + }(); + static_assert(TestAnd100.count() == 2 && TestAnd100.test(50) && + TestAnd100.test(99)); + + // Test operator|. + Bitset<64> D({10, 20}); + Bitset<64> E({20, 30}); + Bitset<64> Result2 = D | E; + EXPECT_TRUE(Result2.test(10)); + EXPECT_TRUE(Result2.test(20)); + EXPECT_TRUE(Result2.test(30)); + EXPECT_EQ(Result2.count(), 3u); + + static_assert((Bitset<64>({10}) | Bitset<64>({20})).count() == 2); + static_assert((Bitset<128>({0, 64, 127}) | Bitset<128>({64, 100})).count() == + 4); + static_assert((Bitset<33>({0, 16}) | Bitset<33>({16, 32})).count() == 3); + + // Test operator|=. + Bitset<64> F({10, 20}); + F |= Bitset<64>({20, 30}); + EXPECT_TRUE(F.test(10)); + EXPECT_TRUE(F.test(20)); + EXPECT_TRUE(F.test(30)); + + constexpr Bitset<64> TestOr = [] { + Bitset<64> X({10}); + X |= Bitset<64>({20}); + return X; + }(); + static_assert(TestOr.test(10) && TestOr.test(20)); + + constexpr Bitset<96> TestOr96 = [] { + Bitset<96> X({31, 63}); + X |= Bitset<96>({32, 64}); + return X; + }(); + static_assert(TestOr96.count() == 4); + + // Test operator^. + Bitset<64> G({10, 20, 30}); + Bitset<64> H({20, 30, 40}); + Bitset<64> Result3 = G ^ H; + EXPECT_TRUE(Result3.test(10)); + EXPECT_FALSE(Result3.test(20)); + EXPECT_FALSE(Result3.test(30)); + EXPECT_TRUE(Result3.test(40)); + EXPECT_EQ(Result3.count(), 2u); + + static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(10)); + static_assert(!(Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(20)); + static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).test(30)); + static_assert((Bitset<64>({10, 20}) ^ Bitset<64>({20, 30})).count() == 2); + static_assert((Bitset<100>({0, 50, 99}) ^ Bitset<100>({50})).count() == 2); + static_assert((Bitset<33>({0, 32}) ^ Bitset<33>({0, 16})).count() == 2); + + // Test operator^=. + Bitset<64> I({10, 20, 30}); + I ^= Bitset<64>({20, 30, 40}); + EXPECT_TRUE(I.test(10)); + EXPECT_FALSE(I.test(20)); + EXPECT_FALSE(I.test(30)); + EXPECT_TRUE(I.test(40)); + + constexpr Bitset<64> TestXor = [] { + Bitset<64> X({10, 20}); + X ^= Bitset<64>({20, 30}); + return X; + }(); + static_assert(TestXor.test(10) && !TestXor.test(20) && TestXor.test(30)); + + constexpr Bitset<128> TestXor128 = [] { + Bitset<128> X({0, 64, 127}); + X ^= Bitset<128>({64}); + return X; + }(); + static_assert(TestXor128.count() == 2 && TestXor128.test(0) && + TestXor128.test(127)); +} + } // namespace