[SimplifyCFG] Fold the contiguous wrapping cases into ICmp. (#161000)

Fixes #157113.

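Take the following IR as an example. `%arg` is known to be in `[0, 6)`, so
the `%if` cases `0`, `4` and `5` wrap around the ends of that range, and the
remaining values `[1, 3]` form a contiguous range whose destination is
`%else`. The switch can therefore be folded into a single range check on
`[1, 3]`.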

```llvm
define i32 @src(i8 range(i8 0, 6) %arg) {
  switch i8 %arg, label %else [
    i8 0, label %if
    i8 4, label %if
    i8 5, label %if
  ]

if:
  ret i32 0

else:
  ret i32 1
}
```

We could first try the non-wrapping range for both destinations before
falling back to the wrapping one, but I don't see how that would be any
better.

Proof: https://alive2.llvm.org/ce/z/acdWD4.
This commit is contained in:
dianqk 2025-10-06 12:56:43 +08:00 committed by GitHub
parent 36cfdebe92
commit bbdcba9b85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 331 additions and 47 deletions

View File

@ -5734,15 +5734,66 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
return Changed;
}
static bool casesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) {
/// Describes a contiguous range of switch case values, [Min, Max], together
/// with the two destinations a switch can be folded into: values inside the
/// range branch to Dest, everything else branches to OtherDest.
///
/// Field order matters: the struct is built with positional brace
/// initialization and unpacked with a structured binding.
struct ContiguousCasesResult {
/// First value of the range. The consumer negates it to form the offset added
/// to the switch condition. It is not necessarily an element of Cases: for a
/// wrapping range it is a newly created constant.
ConstantInt *Min;
/// Last value of the range; the range holds Max - Min + 1 values. As with Min,
/// it may be a newly created constant rather than an existing case value.
ConstantInt *Max;
/// Successor taken when the condition lies inside [Min, Max].
BasicBlock *Dest;
/// Successor taken when the condition lies outside [Min, Max].
BasicBlock *OtherDest;
/// Case values that branch to Dest. Its size is used to count the PHI edges
/// that must be pruned from Dest.
SmallVectorImpl<ConstantInt *> *Cases;
/// Case values that branch to OtherDest. Its size is used to count the PHI
/// edges that must be pruned from OtherDest.
SmallVectorImpl<ConstantInt *> *OtherCases;
};
static std::optional<ContiguousCasesResult>
findContiguousCases(Value *Condition, SmallVectorImpl<ConstantInt *> &Cases,
SmallVectorImpl<ConstantInt *> &OtherCases,
BasicBlock *Dest, BasicBlock *OtherDest) {
assert(Cases.size() >= 1);
array_pod_sort(Cases.begin(), Cases.end(), constantIntSortPredicate);
for (size_t I = 1, E = Cases.size(); I != E; ++I) {
if (Cases[I - 1]->getValue() != Cases[I]->getValue() + 1)
return false;
const APInt &Min = Cases.back()->getValue();
const APInt &Max = Cases.front()->getValue();
APInt Offset = Max - Min;
size_t ContiguousOffset = Cases.size() - 1;
if (Offset == ContiguousOffset) {
return ContiguousCasesResult{
/*Min=*/Cases.back(),
/*Max=*/Cases.front(),
/*Dest=*/Dest,
/*OtherDest=*/OtherDest,
/*Cases=*/&Cases,
/*OtherCases=*/&OtherCases,
};
}
return true;
ConstantRange CR = computeConstantRange(Condition, /*ForSigned=*/false);
// If this is a wrapping contiguous range, that is, [Min, OtherMin] +
// [OtherMax, Max] (also [OtherMax, OtherMin]), [OtherMin+1, OtherMax-1] is a
// contiguous range for the other destination. N.B. If CR is not a full range,
// Max+1 is not equal to Min. It's not continuous in arithmetic.
if (Max == CR.getUnsignedMax() && Min == CR.getUnsignedMin()) {
assert(Cases.size() >= 2);
auto *It =
std::adjacent_find(Cases.begin(), Cases.end(), [](auto L, auto R) {
return L->getValue() != R->getValue() + 1;
});
if (It == Cases.end())
return std::nullopt;
auto [OtherMax, OtherMin] = std::make_pair(*It, *std::next(It));
if ((Max - OtherMax->getValue()) + (OtherMin->getValue() - Min) ==
Cases.size() - 2) {
return ContiguousCasesResult{
/*Min=*/cast<ConstantInt>(
ConstantInt::get(OtherMin->getType(), OtherMin->getValue() + 1)),
/*Max=*/
cast<ConstantInt>(
ConstantInt::get(OtherMax->getType(), OtherMax->getValue() - 1)),
/*Dest=*/OtherDest,
/*OtherDest=*/Dest,
/*Cases=*/&OtherCases,
/*OtherCases=*/&Cases,
};
}
}
return std::nullopt;
}
static void createUnreachableSwitchDefault(SwitchInst *Switch,
@ -5779,7 +5830,6 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
bool HasDefault = !SI->defaultDestUnreachable();
auto *BB = SI->getParent();
// Partition the cases into two sets with different destinations.
BasicBlock *DestA = HasDefault ? SI->getDefaultDest() : nullptr;
BasicBlock *DestB = nullptr;
@ -5813,37 +5863,62 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
assert(!CasesA.empty() || HasDefault);
// Figure out if one of the sets of cases form a contiguous range.
SmallVectorImpl<ConstantInt *> *ContiguousCases = nullptr;
BasicBlock *ContiguousDest = nullptr;
BasicBlock *OtherDest = nullptr;
if (!CasesA.empty() && casesAreContiguous(CasesA)) {
ContiguousCases = &CasesA;
ContiguousDest = DestA;
OtherDest = DestB;
} else if (casesAreContiguous(CasesB)) {
ContiguousCases = &CasesB;
ContiguousDest = DestB;
OtherDest = DestA;
} else
std::optional<ContiguousCasesResult> ContiguousCases;
// Only one icmp is needed when there is only one case.
if (!HasDefault && CasesA.size() == 1)
ContiguousCases = ContiguousCasesResult{
/*Min=*/CasesA[0],
/*Max=*/CasesA[0],
/*Dest=*/DestA,
/*OtherDest=*/DestB,
/*Cases=*/&CasesA,
/*OtherCases=*/&CasesB,
};
else if (CasesB.size() == 1)
ContiguousCases = ContiguousCasesResult{
/*Min=*/CasesB[0],
/*Max=*/CasesB[0],
/*Dest=*/DestB,
/*OtherDest=*/DestA,
/*Cases=*/&CasesB,
/*OtherCases=*/&CasesA,
};
// Correctness: Cases to the default destination cannot be contiguous cases.
else if (!HasDefault)
ContiguousCases =
findContiguousCases(SI->getCondition(), CasesA, CasesB, DestA, DestB);
if (!ContiguousCases)
ContiguousCases =
findContiguousCases(SI->getCondition(), CasesB, CasesA, DestB, DestA);
if (!ContiguousCases)
return false;
auto [Min, Max, Dest, OtherDest, Cases, OtherCases] = *ContiguousCases;
// Start building the compare and branch.
Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back());
Constant *NumCases =
ConstantInt::get(Offset->getType(), ContiguousCases->size());
Value *Sub = SI->getCondition();
if (!Offset->isNullValue())
Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
Value *Cmp;
Constant *Offset = ConstantExpr::getNeg(Min);
Constant *NumCases = ConstantInt::get(Offset->getType(),
Max->getValue() - Min->getValue() + 1);
BranchInst *NewBI;
if (NumCases->isOneValue()) {
assert(Max->getValue() == Min->getValue());
Value *Cmp = Builder.CreateICmpEQ(SI->getCondition(), Min);
NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
}
// If NumCases overflowed, then all possible values jump to the successor.
if (NumCases->isNullValue() && !ContiguousCases->empty())
Cmp = ConstantInt::getTrue(SI->getContext());
else
Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
else if (NumCases->isNullValue() && !Cases->empty()) {
NewBI = Builder.CreateBr(Dest);
} else {
Value *Sub = SI->getCondition();
if (!Offset->isNullValue())
Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
}
// Update weight for the newly-created conditional branch.
if (hasBranchWeightMD(*SI)) {
@ -5853,7 +5928,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
uint64_t TrueWeight = 0;
uint64_t FalseWeight = 0;
for (size_t I = 0, E = Weights.size(); I != E; ++I) {
if (SI->getSuccessor(I) == ContiguousDest)
if (SI->getSuccessor(I) == Dest)
TrueWeight += Weights[I];
else
FalseWeight += Weights[I];
@ -5868,15 +5943,15 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
}
// Prune obsolete incoming values off the successors' PHI nodes.
for (auto BBI = ContiguousDest->begin(); isa<PHINode>(BBI); ++BBI) {
unsigned PreviousEdges = ContiguousCases->size();
if (ContiguousDest == SI->getDefaultDest())
for (auto BBI = Dest->begin(); isa<PHINode>(BBI); ++BBI) {
unsigned PreviousEdges = Cases->size();
if (Dest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
cast<PHINode>(BBI)->removeIncomingValue(SI->getParent());
}
for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) {
unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size();
unsigned PreviousEdges = OtherCases->size();
if (OtherDest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)

View File

@ -80,8 +80,8 @@ cleanup2:
; CHECK: cleanup2.corodispatch:
; CHECK: %1 = phi i8 [ 0, %handler2 ], [ 1, %catch.dispatch.2 ]
; CHECK: %2 = cleanuppad within %h1 []
; CHECK: %switch = icmp ult i8 %1, 1
; CHECK: br i1 %switch, label %cleanup2.from.handler2, label %cleanup2.from.catch.dispatch.2
; CHECK: %3 = icmp eq i8 %1, 0
; CHECK: br i1 %3, label %cleanup2.from.handler2, label %cleanup2.from.catch.dispatch.2
; CHECK: cleanup2.from.handler2:
; CHECK: %valueB.reload = load i32, ptr %valueB.spill.addr, align 4

View File

@ -7,8 +7,7 @@ declare void @foo(i32)
define void @test(i1 %a) {
; CHECK-LABEL: define void @test(
; CHECK-SAME: i1 [[A:%.*]]) {
; CHECK-NEXT: [[A_OFF:%.*]] = add i1 [[A]], true
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i1 [[A_OFF]], true
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i1 [[A]], true
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
@ -209,8 +208,7 @@ define void @test5(i8 %a) {
; CHECK-SAME: i8 [[A:%.*]]) {
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[A]], 2
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], -1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], 1
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
@ -243,8 +241,7 @@ define void @test6(i8 %a) {
; CHECK-NEXT: [[AND:%.*]] = and i8 [[A]], -2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], -2
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], -1
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
@ -279,8 +276,7 @@ define void @test7(i8 %a) {
; CHECK-NEXT: [[AND:%.*]] = and i8 [[A]], -2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], -2
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], -1
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void

View File

@ -188,4 +188,217 @@ exit:
ret void
}
define i32 @wrapping_known_range(i8 range(i8 0, 6) %arg) {
; CHECK-LABEL: @wrapping_known_range(
; CHECK-NEXT: [[ARG_OFF:%.*]] = add i8 [[ARG:%.*]], -1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[ARG_OFF]], 3
; CHECK-NEXT: br i1 [[SWITCH]], label [[ELSE:%.*]], label [[IF:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[I0:%.*]], [[IF]] ], [ [[I1:%.*]], [[ELSE]] ]
; CHECK-NEXT: ret i32 [[COMMON_RET_OP]]
; CHECK: if:
; CHECK-NEXT: [[I0]] = call i32 @f(i32 0)
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: else:
; CHECK-NEXT: [[I1]] = call i32 @f(i32 1)
; CHECK-NEXT: br label [[COMMON_RET]]
;
switch i8 %arg, label %else [
i8 0, label %if
i8 4, label %if
i8 5, label %if
]
if:
%i0 = call i32 @f(i32 0)
ret i32 %i0
else:
%i1 = call i32 @f(i32 1)
ret i32 %i1
}
define i32 @wrapping_known_range_2(i8 range(i8 0, 6) %arg) {
; CHECK-LABEL: @wrapping_known_range_2(
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[ARG:%.*]], 1
; CHECK-NEXT: br i1 [[SWITCH]], label [[ELSE:%.*]], label [[IF:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[I0:%.*]], [[IF]] ], [ [[I1:%.*]], [[ELSE]] ]
; CHECK-NEXT: ret i32 [[COMMON_RET_OP]]
; CHECK: if:
; CHECK-NEXT: [[I0]] = call i32 @f(i32 0)
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: else:
; CHECK-NEXT: [[I1]] = call i32 @f(i32 1)
; CHECK-NEXT: br label [[COMMON_RET]]
;
switch i8 %arg, label %else [
i8 0, label %if
i8 2, label %if
i8 3, label %if
i8 4, label %if
i8 5, label %if
]
if:
%i0 = call i32 @f(i32 0)
ret i32 %i0
else:
%i1 = call i32 @f(i32 1)
ret i32 %i1
}
define i32 @wrapping_range(i8 %arg) {
; CHECK-LABEL: @wrapping_range(
; CHECK-NEXT: [[ARG_OFF:%.*]] = add i8 [[ARG:%.*]], -1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[ARG_OFF]], -4
; CHECK-NEXT: br i1 [[SWITCH]], label [[ELSE:%.*]], label [[IF:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[I0:%.*]], [[IF]] ], [ [[I1:%.*]], [[ELSE]] ]
; CHECK-NEXT: ret i32 [[COMMON_RET_OP]]
; CHECK: if:
; CHECK-NEXT: [[I0]] = call i32 @f(i32 0)
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: else:
; CHECK-NEXT: [[I1]] = call i32 @f(i32 1)
; CHECK-NEXT: br label [[COMMON_RET]]
;
switch i8 %arg, label %else [
i8 0, label %if
i8 -3, label %if
i8 -2, label %if
i8 -1, label %if
]
if:
%i0 = call i32 @f(i32 0)
ret i32 %i0
else:
%i1 = call i32 @f(i32 1)
ret i32 %i1
}
define i8 @wrapping_range_phi(i8 %arg) {
; CHECK-LABEL: @wrapping_range_phi(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[ARG_OFF:%.*]] = add i8 [[ARG:%.*]], -1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[ARG_OFF]], -2
; CHECK-NEXT: [[SPEC_SELECT:%.*]] = select i1 [[SWITCH]], i8 0, i8 1
; CHECK-NEXT: ret i8 [[SPEC_SELECT]]
;
entry:
switch i8 %arg, label %else [
i8 0, label %if
i8 -1, label %if
]
if:
%i = phi i8 [ 0, %else ], [ 1, %entry ], [ 1, %entry ]
ret i8 %i
else:
br label %if
}
define i32 @no_continuous_wrapping_range(i8 %arg) {
; CHECK-LABEL: @no_continuous_wrapping_range(
; CHECK-NEXT: switch i8 [[ARG:%.*]], label [[ELSE:%.*]] [
; CHECK-NEXT: i8 0, label [[IF:%.*]]
; CHECK-NEXT: i8 -3, label [[IF]]
; CHECK-NEXT: i8 -1, label [[IF]]
; CHECK-NEXT: ]
; CHECK: common.ret:
; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[I0:%.*]], [[IF]] ], [ [[I1:%.*]], [[ELSE]] ]
; CHECK-NEXT: ret i32 [[COMMON_RET_OP]]
; CHECK: if:
; CHECK-NEXT: [[I0]] = call i32 @f(i32 0)
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: else:
; CHECK-NEXT: [[I1]] = call i32 @f(i32 1)
; CHECK-NEXT: br label [[COMMON_RET]]
;
switch i8 %arg, label %else [
i8 0, label %if
i8 -3, label %if
i8 -1, label %if
]
if:
%i0 = call i32 @f(i32 0)
ret i32 %i0
else:
%i1 = call i32 @f(i32 1)
ret i32 %i1
}
define i32 @one_case_1(i32 %x) {
; CHECK-LABEL: @one_case_1(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i32 [[X:%.*]], 10
; CHECK-NEXT: br i1 [[SWITCH]], label [[A:%.*]], label [[B:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[TMP0:%.*]], [[B]] ], [ [[TMP1:%.*]], [[A]] ]
; CHECK-NEXT: ret i32 [[COMMON_RET_OP]]
; CHECK: a:
; CHECK-NEXT: [[TMP0]] = call i32 @f(i32 0)
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: b:
; CHECK-NEXT: [[TMP1]] = call i32 @f(i32 1)
; CHECK-NEXT: br label [[COMMON_RET]]
;
entry:
switch i32 %x, label %unreachable [
i32 5, label %a
i32 6, label %a
i32 7, label %a
i32 10, label %b
]
unreachable:
unreachable
a:
%0 = call i32 @f(i32 0)
ret i32 %0
b:
%1 = call i32 @f(i32 1)
ret i32 %1
}
define i32 @one_case_2(i32 %x) {
; CHECK-LABEL: @one_case_2(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i32 [[X:%.*]], 5
; CHECK-NEXT: br i1 [[SWITCH]], label [[A:%.*]], label [[B:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[TMP0:%.*]], [[A]] ], [ [[TMP1:%.*]], [[B]] ]
; CHECK-NEXT: ret i32 [[COMMON_RET_OP]]
; CHECK: a:
; CHECK-NEXT: [[TMP0]] = call i32 @f(i32 0)
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: b:
; CHECK-NEXT: [[TMP1]] = call i32 @f(i32 1)
; CHECK-NEXT: br label [[COMMON_RET]]
;
entry:
switch i32 %x, label %unreachable [
i32 5, label %a
i32 10, label %b
i32 11, label %b
i32 12, label %b
i32 13, label %b
]
unreachable:
unreachable
a:
%0 = call i32 @f(i32 0)
ret i32 %0
b:
%1 = call i32 @f(i32 1)
ret i32 %1
}
declare void @bar(ptr nonnull dereferenceable(4))