[LoopFusion] Fix sink instructions (#147501)

If we have instructions in second loop's preheader which can be sunk, we
should also be adjusting PHI nodes to receive values from the fused loop's latch block.

Fixes #128600
This commit is contained in:
Madhur Amilkanthwar 2025-07-28 12:08:43 +05:30 committed by GitHub
parent 495774d6d5
commit 90de4a4ac9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 90 additions and 0 deletions

View File

@ -1176,6 +1176,28 @@ private:
return true;
}
/// This function fixes PHI nodes after fusion in \p SafeToSink.
/// \p SafeToSink instructions are the instructions that are to be moved past
/// the fused loop. Thus, the PHI nodes in \p SafeToSink should be updated to
/// receive values from the fused loop if they are currently taking values
/// from the first loop (i.e. FC0)'s latch.
void fixPHINodes(ArrayRef<Instruction *> SafeToSink,
const FusionCandidate &FC0,
const FusionCandidate &FC1) const {
for (Instruction *Inst : SafeToSink) {
// No update needed for non-PHI nodes.
PHINode *Phi = dyn_cast<PHINode>(Inst);
if (!Phi)
continue;
for (unsigned I = 0; I < Phi->getNumIncomingValues(); I++) {
if (Phi->getIncomingBlock(I) != FC0.Latch)
continue;
assert(FC1.Latch && "FC1 latch is not set");
Phi->setIncomingBlock(I, FC1.Latch);
}
}
}
/// Collect instructions in the \p FC1 Preheader that can be hoisted
/// to the \p FC0 Preheader or sunk into the \p FC1 Body
bool collectMovablePreheaderInsts(
@ -1481,6 +1503,9 @@ private:
assert(I->getParent() == FC1.Preheader);
I->moveBefore(*FC1.ExitBlock, FC1.ExitBlock->getFirstInsertionPt());
}
// PHI nodes in SinkInsts need to be updated to receive values from the
// fused loop.
fixPHINodes(SinkInsts, FC0, FC1);
}
/// Determine if two fusion candidates have identical guards

View File

@ -0,0 +1,65 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes=loop-fusion -S < %s 2>&1 | FileCheck %s
define dso_local i32 @check_sunk_phi_nodes() {
; CHECK-LABEL: define dso_local i32 @check_sunk_phi_nodes() {
; CHECK-NEXT: [[ENTRY:.*]]:
; CHECK-NEXT: br label %[[FOR_BODY:.*]]
; CHECK: [[FOR_BODY]]:
; CHECK-NEXT: [[SUM1_02:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[ADD:%.*]], %[[FOR_INC6:.*]] ]
; CHECK-NEXT: [[I_01:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[INC:%.*]], %[[FOR_INC6]] ]
; CHECK-NEXT: [[I1_04:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[INC7:%.*]], %[[FOR_INC6]] ]
; CHECK-NEXT: [[SUM2_03:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[ADD5:%.*]], %[[FOR_INC6]] ]
; CHECK-NEXT: [[ADD]] = add nsw i32 [[SUM1_02]], [[I_01]]
; CHECK-NEXT: br label %[[FOR_INC:.*]]
; CHECK: [[FOR_INC]]:
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[I1_04]], [[I1_04]]
; CHECK-NEXT: [[ADD5]] = add nsw i32 [[SUM2_03]], [[MUL]]
; CHECK-NEXT: br label %[[FOR_INC6]]
; CHECK: [[FOR_INC6]]:
; CHECK-NEXT: [[INC]] = add nsw i32 [[I_01]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[INC]], 10
; CHECK-NEXT: [[INC7]] = add nsw i32 [[I1_04]], 1
; CHECK-NEXT: [[CMP3:%.*]] = icmp slt i32 [[INC7]], 10
; CHECK-NEXT: br i1 [[CMP3]], label %[[FOR_BODY]], label %[[FOR_END8:.*]]
; CHECK: [[FOR_END8]]:
; CHECK-NEXT: [[SUM2_0_LCSSA:%.*]] = phi i32 [ [[ADD5]], %[[FOR_INC6]] ]
; CHECK-NEXT: [[SUM1_0_LCSSA:%.*]] = phi i32 [ [[ADD]], %[[FOR_INC6]] ]
; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[SUM1_0_LCSSA]], [[SUM2_0_LCSSA]]
; CHECK-NEXT: ret i32 [[TMP0]]
;
entry:
br label %for.body
for.body: ; preds = %entry, %for.inc
%sum1.02 = phi i32 [ 0, %entry ], [ %add, %for.inc ]
%i.01 = phi i32 [ 0, %entry ], [ %inc, %for.inc ]
%add = add nsw i32 %sum1.02, %i.01
br label %for.inc
for.inc: ; preds = %for.body
%inc = add nsw i32 %i.01, 1
%cmp = icmp slt i32 %inc, 10
br i1 %cmp, label %for.body, label %for.end
for.end: ; preds = %for.inc
%sum1.0.lcssa = phi i32 [ %add, %for.inc ]
br label %for.body4
for.body4: ; preds = %for.end, %for.inc6
%i1.04 = phi i32 [ 0, %for.end ], [ %inc7, %for.inc6 ]
%sum2.03 = phi i32 [ 0, %for.end ], [ %add5, %for.inc6 ]
%mul = mul nsw i32 %i1.04, %i1.04
%add5 = add nsw i32 %sum2.03, %mul
br label %for.inc6
for.inc6: ; preds = %for.body4
%inc7 = add nsw i32 %i1.04, 1
%cmp3 = icmp slt i32 %inc7, 10
br i1 %cmp3, label %for.body4, label %for.end8
for.end8: ; preds = %for.inc6
%sum2.0.lcssa = phi i32 [ %add5, %for.inc6 ]
%0 = add i32 %sum1.0.lcssa, %sum2.0.lcssa
ret i32 %0
}