diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 5a75f28b21ba..978b467521a5 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -25,6 +25,7 @@ #define LLVM_TRANSFORMS_VECTORIZE_VPLAN_H #include "VPlanValue.h" +#include "llvm/ADT/Bitfields.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -726,7 +727,7 @@ private: /// Holds both the predicate and fast-math flags for floating-point /// comparisons. struct FCmpFlagsTy { - CmpInst::Predicate Pred; + uint8_t CmpPredStorage; FastMathFlagsTy FMFs; }; /// Holds reduction-specific flags: RecurKind, IsOrdered, IsInLoop, and FMFs. @@ -748,30 +749,33 @@ private: OperationType OpType; union { - CmpInst::Predicate CmpPredicate; + uint8_t CmpPredStorage; WrapFlagsTy WrapFlags; TruncFlagsTy TruncFlags; DisjointFlagsTy DisjointFlags; ExactFlagsTy ExactFlags; - GEPNoWrapFlags GEPFlags; + uint8_t GEPFlagsStorage; NonNegFlagsTy NonNegFlags; FastMathFlagsTy FMFs; FCmpFlagsTy FCmpFlags; ReductionFlagsTy ReductionFlags; - unsigned AllFlags; }; public: - VPIRFlags() : OpType(OperationType::Other), AllFlags(0) {} + VPIRFlags() : OpType(OperationType::Other), CmpPredStorage(0) {} VPIRFlags(Instruction &I) { if (auto *FCmp = dyn_cast(&I)) { OpType = OperationType::FCmp; - FCmpFlags.Pred = FCmp->getPredicate(); + Bitfield::set(FCmpFlags.CmpPredStorage, + FCmp->getPredicate()); + assert(getPredicate() == FCmp->getPredicate() && "predicate truncated"); FCmpFlags.FMFs = FCmp->getFastMathFlags(); } else if (auto *Op = dyn_cast(&I)) { OpType = OperationType::Cmp; - CmpPredicate = Op->getPredicate(); + Bitfield::set(CmpPredStorage, + Op->getPredicate()); + assert(getPredicate() == Op->getPredicate() && "predicate truncated"); } else if (auto *Op = dyn_cast(&I)) { OpType = OperationType::DisjointOp; DisjointFlags.IsDisjoint = Op->isDisjoint(); @@ -786,7 +790,9 @@ public: ExactFlags.IsExact = Op->isExact(); } else if (auto *GEP = dyn_cast(&I)) { OpType = OperationType::GEPOp; - GEPFlags = GEP->getNoWrapFlags(); + GEPFlagsStorage = GEP->getNoWrapFlags().getRaw(); + assert(getGEPNoWrapFlags() == GEP->getNoWrapFlags() && + "wrap flags truncated"); } else if (auto *PNNI = dyn_cast(&I)) { OpType = OperationType::NonNegOp; NonNegFlags.NonNeg = PNNI->hasNonNeg(); @@ -795,16 +801,19 @@ public: FMFs = Op->getFastMathFlags(); } else { OpType = OperationType::Other; - AllFlags = 0; + CmpPredStorage = 0; } } - VPIRFlags(CmpInst::Predicate Pred) - : OpType(OperationType::Cmp), CmpPredicate(Pred) {} + VPIRFlags(CmpInst::Predicate Pred) : OpType(OperationType::Cmp) { + Bitfield::set(CmpPredStorage, Pred); + assert(getPredicate() == Pred && "predicate truncated"); + } VPIRFlags(CmpInst::Predicate Pred, FastMathFlags FMFs) : OpType(OperationType::FCmp) { - FCmpFlags.Pred = Pred; + Bitfield::set(FCmpFlags.CmpPredStorage, Pred); + assert(getPredicate() == Pred && "predicate truncated"); FCmpFlags.FMFs = FMFs; } @@ -826,16 +835,13 @@ public: : OpType(OperationType::PossiblyExactOp), ExactFlags(ExactFlags) {} VPIRFlags(GEPNoWrapFlags GEPFlags) - : OpType(OperationType::GEPOp), GEPFlags(GEPFlags) {} + : OpType(OperationType::GEPOp), GEPFlagsStorage(GEPFlags.getRaw()) {} VPIRFlags(RecurKind Kind, bool IsOrdered, bool IsInLoop, FastMathFlags FMFs) : OpType(OperationType::ReductionOp), ReductionFlags(Kind, IsOrdered, IsInLoop, FMFs) {} - void transferFlags(VPIRFlags &Other) { - OpType = Other.OpType; - AllFlags = Other.AllFlags; - } + void transferFlags(VPIRFlags &Other) { *this = Other; } /// Only keep flags also present in \p Other. \p Other must have the same /// OpType as the current object. @@ -861,7 +867,7 @@ public: ExactFlags.IsExact = false; break; case OperationType::GEPOp: - GEPFlags = GEPNoWrapFlags::none(); + GEPFlagsStorage = 0; break; case OperationType::FPMathOp: case OperationType::FCmp: @@ -896,7 +902,8 @@ public: I.setIsExact(ExactFlags.IsExact); break; case OperationType::GEPOp: - cast(&I)->setNoWrapFlags(GEPFlags); + cast(&I)->setNoWrapFlags( + GEPNoWrapFlags::fromRaw(GEPFlagsStorage)); break; case OperationType::FPMathOp: case OperationType::FCmp: { @@ -924,19 +931,24 @@ public: CmpInst::Predicate getPredicate() const { assert((OpType == OperationType::Cmp || OpType == OperationType::FCmp) && "recipe doesn't have a compare predicate"); - return OpType == OperationType::FCmp ? FCmpFlags.Pred : CmpPredicate; + uint8_t Storage = OpType == OperationType::FCmp ? FCmpFlags.CmpPredStorage + : CmpPredStorage; + return Bitfield::get(Storage); } void setPredicate(CmpInst::Predicate Pred) { assert((OpType == OperationType::Cmp || OpType == OperationType::FCmp) && "recipe doesn't have a compare predicate"); if (OpType == OperationType::FCmp) - FCmpFlags.Pred = Pred; + Bitfield::set(FCmpFlags.CmpPredStorage, Pred); else - CmpPredicate = Pred; + Bitfield::set(CmpPredStorage, Pred); + assert(getPredicate() == Pred && "predicate truncated"); } - GEPNoWrapFlags getGEPNoWrapFlags() const { return GEPFlags; } + GEPNoWrapFlags getGEPNoWrapFlags() const { + return GEPNoWrapFlags::fromRaw(GEPFlagsStorage); + } /// Returns true if the recipe has a comparison predicate. bool hasPredicate() const { @@ -1042,6 +1054,8 @@ public: #endif }; +static_assert(sizeof(VPIRFlags) <= 3, "VPIRFlags should not grow"); + /// A pure-virtual common base class for recipes defining a single VPValue and /// using IR flags. struct VPRecipeWithIRFlags : public VPSingleDefRecipe, public VPIRFlags { diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index ce64ab927182..26183f15306f 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -334,12 +334,12 @@ void VPIRFlags::intersectFlags(const VPIRFlags &Other) { ExactFlags.IsExact &= Other.ExactFlags.IsExact; break; case OperationType::GEPOp: - GEPFlags &= Other.GEPFlags; + GEPFlagsStorage &= Other.GEPFlagsStorage; break; case OperationType::FPMathOp: case OperationType::FCmp: assert((OpType != OperationType::FCmp || - FCmpFlags.Pred == Other.FCmpFlags.Pred) && + FCmpFlags.CmpPredStorage == Other.FCmpFlags.CmpPredStorage) && "Cannot drop CmpPredicate"); getFMFsRef().NoNaNs &= Other.getFMFsRef().NoNaNs; getFMFsRef().NoInfs &= Other.getFMFsRef().NoInfs; @@ -348,7 +348,8 @@ void VPIRFlags::intersectFlags(const VPIRFlags &Other) { NonNegFlags.NonNeg &= Other.NonNegFlags.NonNeg; break; case OperationType::Cmp: - assert(CmpPredicate == Other.CmpPredicate && "Cannot drop CmpPredicate"); + assert(CmpPredStorage == Other.CmpPredStorage && + "Cannot drop CmpPredicate"); break; case OperationType::ReductionOp: assert(ReductionFlags.Kind == Other.ReductionFlags.Kind && @@ -361,7 +362,6 @@ void VPIRFlags::intersectFlags(const VPIRFlags &Other) { getFMFsRef().NoInfs &= Other.getFMFsRef().NoInfs; break; case OperationType::Other: - assert(AllFlags == Other.AllFlags && "Cannot drop other flags"); break; } } @@ -2219,14 +2219,16 @@ void VPIRFlags::printFlags(raw_ostream &O) const { case OperationType::FPMathOp: getFastMathFlags().print(O); break; - case OperationType::GEPOp: - if (GEPFlags.isInBounds()) + case OperationType::GEPOp: { + GEPNoWrapFlags Flags = getGEPNoWrapFlags(); + if (Flags.isInBounds()) O << " inbounds"; - else if (GEPFlags.hasNoUnsignedSignedWrap()) + else if (Flags.hasNoUnsignedSignedWrap()) O << " nusw"; - if (GEPFlags.hasNoUnsignedWrap()) + if (Flags.hasNoUnsignedWrap()) O << " nuw"; break; + } case OperationType::NonNegOp: if (NonNegFlags.NonNeg) O << " nneg";