[VPlan] Use bitfield to store Cmp predicates and GEP wrap flags. (NFC) (#181571)
Instead of storing CmpInst::Predicate/GepNoWrapFlags, only store their raw bitfield values. This reduces the size of VPIRFlags from 12 to 3 bytes. PR: https://github.com/llvm/llvm-project/pull/181571
This commit is contained in:
parent
899080a87a
commit
17aaa0e590
@ -25,6 +25,7 @@
|
||||
#define LLVM_TRANSFORMS_VECTORIZE_VPLAN_H
|
||||
|
||||
#include "VPlanValue.h"
|
||||
#include "llvm/ADT/Bitfields.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
@ -726,7 +727,7 @@ private:
|
||||
/// Holds both the predicate and fast-math flags for floating-point
|
||||
/// comparisons.
|
||||
struct FCmpFlagsTy {
|
||||
CmpInst::Predicate Pred;
|
||||
uint8_t CmpPredStorage;
|
||||
FastMathFlagsTy FMFs;
|
||||
};
|
||||
/// Holds reduction-specific flags: RecurKind, IsOrdered, IsInLoop, and FMFs.
|
||||
@ -748,30 +749,33 @@ private:
|
||||
OperationType OpType;
|
||||
|
||||
union {
|
||||
CmpInst::Predicate CmpPredicate;
|
||||
uint8_t CmpPredStorage;
|
||||
WrapFlagsTy WrapFlags;
|
||||
TruncFlagsTy TruncFlags;
|
||||
DisjointFlagsTy DisjointFlags;
|
||||
ExactFlagsTy ExactFlags;
|
||||
GEPNoWrapFlags GEPFlags;
|
||||
uint8_t GEPFlagsStorage;
|
||||
NonNegFlagsTy NonNegFlags;
|
||||
FastMathFlagsTy FMFs;
|
||||
FCmpFlagsTy FCmpFlags;
|
||||
ReductionFlagsTy ReductionFlags;
|
||||
unsigned AllFlags;
|
||||
};
|
||||
|
||||
public:
|
||||
VPIRFlags() : OpType(OperationType::Other), AllFlags(0) {}
|
||||
VPIRFlags() : OpType(OperationType::Other), CmpPredStorage(0) {}
|
||||
|
||||
VPIRFlags(Instruction &I) {
|
||||
if (auto *FCmp = dyn_cast<FCmpInst>(&I)) {
|
||||
OpType = OperationType::FCmp;
|
||||
FCmpFlags.Pred = FCmp->getPredicate();
|
||||
Bitfield::set<CmpInst::PredicateField>(FCmpFlags.CmpPredStorage,
|
||||
FCmp->getPredicate());
|
||||
assert(getPredicate() == FCmp->getPredicate() && "predicate truncated");
|
||||
FCmpFlags.FMFs = FCmp->getFastMathFlags();
|
||||
} else if (auto *Op = dyn_cast<CmpInst>(&I)) {
|
||||
OpType = OperationType::Cmp;
|
||||
CmpPredicate = Op->getPredicate();
|
||||
Bitfield::set<CmpInst::PredicateField>(CmpPredStorage,
|
||||
Op->getPredicate());
|
||||
assert(getPredicate() == Op->getPredicate() && "predicate truncated");
|
||||
} else if (auto *Op = dyn_cast<PossiblyDisjointInst>(&I)) {
|
||||
OpType = OperationType::DisjointOp;
|
||||
DisjointFlags.IsDisjoint = Op->isDisjoint();
|
||||
@ -786,7 +790,9 @@ public:
|
||||
ExactFlags.IsExact = Op->isExact();
|
||||
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
|
||||
OpType = OperationType::GEPOp;
|
||||
GEPFlags = GEP->getNoWrapFlags();
|
||||
GEPFlagsStorage = GEP->getNoWrapFlags().getRaw();
|
||||
assert(getGEPNoWrapFlags() == GEP->getNoWrapFlags() &&
|
||||
"wrap flags truncated");
|
||||
} else if (auto *PNNI = dyn_cast<PossiblyNonNegInst>(&I)) {
|
||||
OpType = OperationType::NonNegOp;
|
||||
NonNegFlags.NonNeg = PNNI->hasNonNeg();
|
||||
@ -795,16 +801,19 @@ public:
|
||||
FMFs = Op->getFastMathFlags();
|
||||
} else {
|
||||
OpType = OperationType::Other;
|
||||
AllFlags = 0;
|
||||
CmpPredStorage = 0;
|
||||
}
|
||||
}
|
||||
|
||||
VPIRFlags(CmpInst::Predicate Pred)
|
||||
: OpType(OperationType::Cmp), CmpPredicate(Pred) {}
|
||||
VPIRFlags(CmpInst::Predicate Pred) : OpType(OperationType::Cmp) {
|
||||
Bitfield::set<CmpInst::PredicateField>(CmpPredStorage, Pred);
|
||||
assert(getPredicate() == Pred && "predicate truncated");
|
||||
}
|
||||
|
||||
VPIRFlags(CmpInst::Predicate Pred, FastMathFlags FMFs)
|
||||
: OpType(OperationType::FCmp) {
|
||||
FCmpFlags.Pred = Pred;
|
||||
Bitfield::set<CmpInst::PredicateField>(FCmpFlags.CmpPredStorage, Pred);
|
||||
assert(getPredicate() == Pred && "predicate truncated");
|
||||
FCmpFlags.FMFs = FMFs;
|
||||
}
|
||||
|
||||
@ -826,16 +835,13 @@ public:
|
||||
: OpType(OperationType::PossiblyExactOp), ExactFlags(ExactFlags) {}
|
||||
|
||||
VPIRFlags(GEPNoWrapFlags GEPFlags)
|
||||
: OpType(OperationType::GEPOp), GEPFlags(GEPFlags) {}
|
||||
: OpType(OperationType::GEPOp), GEPFlagsStorage(GEPFlags.getRaw()) {}
|
||||
|
||||
VPIRFlags(RecurKind Kind, bool IsOrdered, bool IsInLoop, FastMathFlags FMFs)
|
||||
: OpType(OperationType::ReductionOp),
|
||||
ReductionFlags(Kind, IsOrdered, IsInLoop, FMFs) {}
|
||||
|
||||
void transferFlags(VPIRFlags &Other) {
|
||||
OpType = Other.OpType;
|
||||
AllFlags = Other.AllFlags;
|
||||
}
|
||||
void transferFlags(VPIRFlags &Other) { *this = Other; }
|
||||
|
||||
/// Only keep flags also present in \p Other. \p Other must have the same
|
||||
/// OpType as the current object.
|
||||
@ -861,7 +867,7 @@ public:
|
||||
ExactFlags.IsExact = false;
|
||||
break;
|
||||
case OperationType::GEPOp:
|
||||
GEPFlags = GEPNoWrapFlags::none();
|
||||
GEPFlagsStorage = 0;
|
||||
break;
|
||||
case OperationType::FPMathOp:
|
||||
case OperationType::FCmp:
|
||||
@ -896,7 +902,8 @@ public:
|
||||
I.setIsExact(ExactFlags.IsExact);
|
||||
break;
|
||||
case OperationType::GEPOp:
|
||||
cast<GetElementPtrInst>(&I)->setNoWrapFlags(GEPFlags);
|
||||
cast<GetElementPtrInst>(&I)->setNoWrapFlags(
|
||||
GEPNoWrapFlags::fromRaw(GEPFlagsStorage));
|
||||
break;
|
||||
case OperationType::FPMathOp:
|
||||
case OperationType::FCmp: {
|
||||
@ -924,19 +931,24 @@ public:
|
||||
CmpInst::Predicate getPredicate() const {
|
||||
assert((OpType == OperationType::Cmp || OpType == OperationType::FCmp) &&
|
||||
"recipe doesn't have a compare predicate");
|
||||
return OpType == OperationType::FCmp ? FCmpFlags.Pred : CmpPredicate;
|
||||
uint8_t Storage = OpType == OperationType::FCmp ? FCmpFlags.CmpPredStorage
|
||||
: CmpPredStorage;
|
||||
return Bitfield::get<CmpInst::PredicateField>(Storage);
|
||||
}
|
||||
|
||||
void setPredicate(CmpInst::Predicate Pred) {
|
||||
assert((OpType == OperationType::Cmp || OpType == OperationType::FCmp) &&
|
||||
"recipe doesn't have a compare predicate");
|
||||
if (OpType == OperationType::FCmp)
|
||||
FCmpFlags.Pred = Pred;
|
||||
Bitfield::set<CmpInst::PredicateField>(FCmpFlags.CmpPredStorage, Pred);
|
||||
else
|
||||
CmpPredicate = Pred;
|
||||
Bitfield::set<CmpInst::PredicateField>(CmpPredStorage, Pred);
|
||||
assert(getPredicate() == Pred && "predicate truncated");
|
||||
}
|
||||
|
||||
GEPNoWrapFlags getGEPNoWrapFlags() const { return GEPFlags; }
|
||||
GEPNoWrapFlags getGEPNoWrapFlags() const {
|
||||
return GEPNoWrapFlags::fromRaw(GEPFlagsStorage);
|
||||
}
|
||||
|
||||
/// Returns true if the recipe has a comparison predicate.
|
||||
bool hasPredicate() const {
|
||||
@ -1042,6 +1054,8 @@ public:
|
||||
#endif
|
||||
};
|
||||
|
||||
static_assert(sizeof(VPIRFlags) <= 3, "VPIRFlags should not grow");
|
||||
|
||||
/// A pure-virtual common base class for recipes defining a single VPValue and
|
||||
/// using IR flags.
|
||||
struct VPRecipeWithIRFlags : public VPSingleDefRecipe, public VPIRFlags {
|
||||
|
||||
@ -334,12 +334,12 @@ void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
|
||||
ExactFlags.IsExact &= Other.ExactFlags.IsExact;
|
||||
break;
|
||||
case OperationType::GEPOp:
|
||||
GEPFlags &= Other.GEPFlags;
|
||||
GEPFlagsStorage &= Other.GEPFlagsStorage;
|
||||
break;
|
||||
case OperationType::FPMathOp:
|
||||
case OperationType::FCmp:
|
||||
assert((OpType != OperationType::FCmp ||
|
||||
FCmpFlags.Pred == Other.FCmpFlags.Pred) &&
|
||||
FCmpFlags.CmpPredStorage == Other.FCmpFlags.CmpPredStorage) &&
|
||||
"Cannot drop CmpPredicate");
|
||||
getFMFsRef().NoNaNs &= Other.getFMFsRef().NoNaNs;
|
||||
getFMFsRef().NoInfs &= Other.getFMFsRef().NoInfs;
|
||||
@ -348,7 +348,8 @@ void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
|
||||
NonNegFlags.NonNeg &= Other.NonNegFlags.NonNeg;
|
||||
break;
|
||||
case OperationType::Cmp:
|
||||
assert(CmpPredicate == Other.CmpPredicate && "Cannot drop CmpPredicate");
|
||||
assert(CmpPredStorage == Other.CmpPredStorage &&
|
||||
"Cannot drop CmpPredicate");
|
||||
break;
|
||||
case OperationType::ReductionOp:
|
||||
assert(ReductionFlags.Kind == Other.ReductionFlags.Kind &&
|
||||
@ -361,7 +362,6 @@ void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
|
||||
getFMFsRef().NoInfs &= Other.getFMFsRef().NoInfs;
|
||||
break;
|
||||
case OperationType::Other:
|
||||
assert(AllFlags == Other.AllFlags && "Cannot drop other flags");
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -2219,14 +2219,16 @@ void VPIRFlags::printFlags(raw_ostream &O) const {
|
||||
case OperationType::FPMathOp:
|
||||
getFastMathFlags().print(O);
|
||||
break;
|
||||
case OperationType::GEPOp:
|
||||
if (GEPFlags.isInBounds())
|
||||
case OperationType::GEPOp: {
|
||||
GEPNoWrapFlags Flags = getGEPNoWrapFlags();
|
||||
if (Flags.isInBounds())
|
||||
O << " inbounds";
|
||||
else if (GEPFlags.hasNoUnsignedSignedWrap())
|
||||
else if (Flags.hasNoUnsignedSignedWrap())
|
||||
O << " nusw";
|
||||
if (GEPFlags.hasNoUnsignedWrap())
|
||||
if (Flags.hasNoUnsignedWrap())
|
||||
O << " nuw";
|
||||
break;
|
||||
}
|
||||
case OperationType::NonNegOp:
|
||||
if (NonNegFlags.NonNeg)
|
||||
O << " nneg";
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user