[CIR] Upstream GotoSolver pass (#154596)

This PR upstreams the GotoSolver pass.  
It works by walking the function and matching each label to a goto. If a
label is not matched to a goto, it is removed and not lowered.
This commit is contained in:
Andres-Salamanca 2025-08-21 11:02:29 -05:00 committed by GitHub
parent 3923adfa3f
commit fc62990657
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 266 additions and 2 deletions

View File

@ -26,6 +26,7 @@ std::unique_ptr<Pass> createCIRSimplifyPass();
std::unique_ptr<Pass> createHoistAllocasPass();
std::unique_ptr<Pass> createLoweringPreparePass();
std::unique_ptr<Pass> createLoweringPreparePass(clang::ASTContext *astCtx);
std::unique_ptr<Pass> createGotoSolverPass();
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);

View File

@ -72,6 +72,16 @@ def CIRFlattenCFG : Pass<"cir-flatten-cfg"> {
let dependentDialects = ["cir::CIRDialect"];
}
def GotoSolver : Pass<"cir-goto-solver"> {
let summary = "Replaces goto operations with branches";
let description = [{
This pass transforms CIR and replaces goto-s with branch
operations to the proper blocks.
}];
let constructor = "mlir::createGotoSolverPass()";
let dependentDialects = ["cir::CIRDialect"];
}
def LoweringPrepare : Pass<"cir-lowering-prepare"> {
let summary = "Lower to more fine-grained CIR operations before lowering to "
"other dialects";

View File

@ -4,6 +4,7 @@ add_clang_library(MLIRCIRTransforms
FlattenCFG.cpp
HoistAllocas.cpp
LoweringPrepare.cpp
GotoSolver.cpp
DEPENDS
MLIRCIRPassIncGen

View File

@ -0,0 +1,57 @@
//====- GotoSolver.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"
#include "llvm/Support/TimeProfiler.h"
#include <memory>
using namespace mlir;
using namespace cir;
namespace {
struct GotoSolverPass : public GotoSolverBase<GotoSolverPass> {
GotoSolverPass() = default;
void runOnOperation() override;
};
static void process(cir::FuncOp func) {
mlir::OpBuilder rewriter(func.getContext());
llvm::StringMap<Block *> labels;
llvm::SmallVector<cir::GotoOp, 4> gotos;
func.getBody().walk([&](mlir::Operation *op) {
if (auto lab = dyn_cast<cir::LabelOp>(op)) {
// Will construct a string copy inplace. Safely erase the label
labels.try_emplace(lab.getLabel(), lab->getBlock());
lab.erase();
} else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
gotos.push_back(goTo);
}
});
for (auto goTo : gotos) {
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(goTo);
Block *dest = labels[goTo.getLabel()];
cir::BrOp::create(rewriter, goTo.getLoc(), dest);
goTo.erase();
}
}
void GotoSolverPass::runOnOperation() {
llvm::TimeTraceScope scope("Goto Solver");
getOperation()->walk(&process);
}
} // namespace
std::unique_ptr<Pass> mlir::createGotoSolverPass() {
return std::make_unique<GotoSolverPass>();
}

View File

@ -45,6 +45,7 @@ namespace mlir {
void populateCIRPreLoweringPasses(OpPassManager &pm) {
pm.addPass(createHoistAllocasPass());
pm.addPass(createCIRFlattenCFGPass());
pm.addPass(createGotoSolverPass());
}
} // namespace mlir

View File

@ -1,5 +1,7 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
@ -27,6 +29,24 @@ err:
// CIR: cir.store [[MINUS]], [[RETVAL]] : !s32i, !cir.ptr<!s32i>
// CIR: cir.br ^bb1
// LLVM: define dso_local i32 @_Z21shouldNotGenBranchReti
// LLVM: [[COND:%.*]] = load i32, ptr {{.*}}, align 4
// LLVM: [[CMP:%.*]] = icmp sgt i32 [[COND]], 5
// LLVM: br i1 [[CMP]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
// LLVM: [[IFTHEN]]:
// LLVM: br label %[[ERR:.*]]
// LLVM: [[IFEND]]:
// LLVM: br label %[[BB9:.*]]
// LLVM: [[BB9]]:
// LLVM: store i32 0, ptr %[[RETVAL:.*]], align 4
// LLVM: br label %[[BBRET:.*]]
// LLVM: [[BBRET]]:
// LLVM: [[RET:%.*]] = load i32, ptr %[[RETVAL]], align 4
// LLVM: ret i32 [[RET]]
// LLVM: [[ERR]]:
// LLVM: store i32 -1, ptr %[[RETVAL]], align 4
// LLVM: br label %10
// OGCG: define dso_local noundef i32 @_Z21shouldNotGenBranchReti
// OGCG: if.then:
// OGCG: br label %err
@ -51,6 +71,17 @@ err:
// CIR: ^bb1:
// CIR: cir.label "err"
// LLVM: define dso_local i32 @_Z15shouldGenBranchi
// LLVM: br i1 [[CMP:%.*]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
// LLVM: [[IFTHEN]]:
// LLVM: br label %[[ERR:.*]]
// LLVM: [[IFEND]]:
// LLVM: br label %[[BB9:.*]]
// LLVM: [[BB9]]:
// LLVM: br label %[[ERR]]
// LLVM: [[ERR]]:
// LLVM: ret i32 [[RET:%.*]]
// OGCG: define dso_local noundef i32 @_Z15shouldGenBranchi
// OGCG: if.then:
// OGCG: br label %err
@ -78,6 +109,15 @@ end2:
// CIR: ^bb[[#BLK3]]:
// CIR: cir.label "end2"
// LLVM: define dso_local void @_Z19severalLabelsInARowi
// LLVM: br label %[[END1:.*]]
// LLVM: [[UNRE:.*]]: ; No predecessors!
// LLVM: br label %[[END2:.*]]
// LLVM: [[END1]]:
// LLVM: br label %[[END2]]
// LLVM: [[END2]]:
// LLVM: ret
// OGCG: define dso_local void @_Z19severalLabelsInARowi
// OGCG: br label %end1
// OGCG: end1:
@ -99,6 +139,13 @@ end:
// CIR: ^bb[[#BLK2:]]:
// CIR: cir.label "end"
// LLVM: define dso_local void @_Z18severalGotosInARowi
// LLVM: br label %[[END:.*]]
// LLVM: [[UNRE:.*]]: ; No predecessors!
// LLVM: br label %[[END]]
// LLVM: [[END]]:
// LLVM: ret void
// OGCG: define dso_local void @_Z18severalGotosInARowi(i32 noundef %a) #0 {
// OGCG: br label %end
// OGCG: end:
@ -126,6 +173,14 @@ extern "C" void multiple_non_case(int v) {
// CIR: cir.call @action2()
// CIR: cir.break
// LLVM: define dso_local void @multiple_non_case
// LLVM: [[SWDEFAULT:.*]]:
// LLVM: call void @action1()
// LLVM: br label %[[L2:.*]]
// LLVM: [[L2]]:
// LLVM: call void @action2()
// LLVM: br label %[[BREAK:.*]]
// OGCG: define dso_local void @multiple_non_case
// OGCG: sw.default:
// OGCG: call void @action1()
@ -158,6 +213,26 @@ extern "C" void case_follow_label(int v) {
// CIR: cir.call @action2()
// CIR: cir.goto "label"
// LLVM: define dso_local void @case_follow_label
// LLVM: switch i32 {{.*}}, label %[[SWDEFAULT:.*]] [
// LLVM: i32 1, label %[[LABEL:.*]]
// LLVM: i32 2, label %[[CASE2:.*]]
// LLVM: ]
// LLVM: [[LABEL]]:
// LLVM: br label %[[CASE2]]
// LLVM: [[CASE2]]:
// LLVM: call void @action1()
// LLVM: br label %[[BREAK:.*]]
// LLVM: [[BREAK]]:
// LLVM: br label %[[END:.*]]
// LLVM: [[SWDEFAULT]]:
// LLVM: call void @action2()
// LLVM: br label %[[LABEL]]
// LLVM: [[END]]:
// LLVM: br label %[[RET:.*]]
// LLVM: [[RET]]:
// LLVM: ret void
// OGCG: define dso_local void @case_follow_label
// OGCG: sw.bb:
// OGCG: br label %label
@ -197,6 +272,26 @@ extern "C" void default_follow_label(int v) {
// CIR: cir.call @action2()
// CIR: cir.goto "label"
// LLVM: define dso_local void @default_follow_label
// LLVM: [[CASE1:.*]]:
// LLVM: br label %[[BB8:.*]]
// LLVM: [[BB8]]:
// LLVM: br label %[[CASE2:.*]]
// LLVM: [[CASE2]]:
// LLVM: call void @action1()
// LLVM: br label %[[BREAK:.*]]
// LLVM: [[LABEL:.*]]:
// LLVM: br label %[[SWDEFAULT:.*]]
// LLVM: [[SWDEFAULT]]:
// LLVM: call void @action2()
// LLVM: br label %[[BB9:.*]]
// LLVM: [[BB9]]:
// LLVM: br label %[[LABEL]]
// LLVM: [[BREAK]]:
// LLVM: br label %[[RET:.*]]
// LLVM: [[RET]]:
// LLVM: ret void
// OGCG: define dso_local void @default_follow_label
// OGCG: sw.bb:
// OGCG: call void @action1()

View File

@ -1,5 +1,7 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
@ -12,8 +14,8 @@ labelA:
// CIR: cir.label "labelA"
// CIR: cir.return
// Note: We are not lowering to LLVM IR via CIR at this stage because that
// process depends on the GotoSolver.
// LLVM:define dso_local void @label
// LLVM: ret void
// OGCG: define dso_local void @label
// OGCG: br label %labelA
@ -33,6 +35,11 @@ labelC:
// CIR: cir.label "labelC"
// CIR: cir.return
// LLVM: define dso_local void @multiple_labels()
// LLVM: br label %1
// LLVM: 1:
// LLVM: ret void
// OGCG: define dso_local void @multiple_labels
// OGCG: br label %labelB
// OGCG: labelB:
@ -56,6 +63,22 @@ labelD:
// CIR: }
// CIR: cir.return
// LLVM: define dso_local void @label_in_if
// LLVM: br label %3
// LLVM: 3:
// LLVM: [[LOAD:%.*]] = load i32, ptr [[COND:%.*]], align 4
// LLVM: [[CMP:%.*]] = icmp ne i32 [[LOAD]], 0
// LLVM: br i1 [[CMP]], label %6, label %9
// LLVM: 6:
// LLVM: [[LOAD2:%.*]] = load i32, ptr [[COND]], align 4
// LLVM: [[ADD1:%.*]] = add nsw i32 [[LOAD2]], 1
// LLVM: store i32 [[ADD1]], ptr [[COND]], align 4
// LLVM: br label %9
// LLVM: 9:
// LLVM: br label %10
// LLVM: 10:
// LLVM: ret void
// OGCG: define dso_local void @label_in_if
// OGCG: if.then:
// OGCG: br label %labelD
@ -80,6 +103,13 @@ void after_return() {
// CIR: cir.label "label"
// CIR: cir.br ^bb1
// LLVM: define dso_local void @after_return
// LLVM: br label %1
// LLVM: 1:
// LLVM: ret void
// LLVM: 2:
// LLVM: br label %1
// OGCG: define dso_local void @after_return
// OGCG: br label %label
// OGCG: label:
@ -97,6 +127,11 @@ void after_unreachable() {
// CIR: cir.label "label"
// CIR: cir.return
// LLVM: define dso_local void @after_unreachable
// LLVM: unreachable
// LLVM: 1:
// LLVM: ret void
// OGCG: define dso_local void @after_unreachable
// OGCG: unreachable
// OGCG: label:
@ -111,6 +146,9 @@ end:
// CIR: cir.return
// CIR: }
// LLVM: define dso_local void @labelWithoutMatch
// LLVM: ret void
// OGCG: define dso_local void @labelWithoutMatch
// OGCG: br label %end
// OGCG: end:
@ -132,6 +170,15 @@ void foo() {
// CIR: cir.label "label"
// CIR: %0 = cir.alloca !rec_S, !cir.ptr<!rec_S>, ["agg.tmp0"]
// LLVM:define dso_local void @foo() {
// LLVM: [[ALLOC:%.*]] = alloca %struct.S, i64 1, align 1
// LLVM: br label %2
// LLVM:2:
// LLVM: [[CALL:%.*]] = call %struct.S @get()
// LLVM: store %struct.S [[CALL]], ptr [[ALLOC]], align 1
// LLVM: [[LOAD:%.*]] = load %struct.S, ptr [[ALLOC]], align 1
// LLVM: call void @bar(%struct.S [[LOAD]])
// OGCG: define dso_local void @foo()
// OGCG: %agg.tmp = alloca %struct.S, align 1
// OGCG: %undef.agg.tmp = alloca %struct.S, align 1

View File

@ -0,0 +1,52 @@
// RUN: cir-opt %s --pass-pipeline='builtin.module(cir-to-llvm,canonicalize{region-simplify=disabled})' -o - | FileCheck %s -check-prefix=MLIR
!s32i = !cir.int<s, 32>
module {
cir.func @gotoFromIf(%arg0: !s32i) -> !s32i {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
cir.scope {
%6 = cir.load %0 : !cir.ptr<!s32i>, !s32i
%7 = cir.const #cir.int<5> : !s32i
%8 = cir.cmp(gt, %6, %7) : !s32i, !cir.bool
cir.if %8 {
cir.goto "err"
}
}
%2 = cir.const #cir.int<0> : !s32i
cir.store %2, %1 : !s32i, !cir.ptr<!s32i>
cir.br ^bb1
^bb1:
%3 = cir.load %1 : !cir.ptr<!s32i>, !s32i
cir.return %3 : !s32i
^bb2:
cir.label "err"
%4 = cir.const #cir.int<1> : !s32i
%5 = cir.unary(minus, %4) : !s32i, !s32i
cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
cir.br ^bb1
}
// MLIR: llvm.func @gotoFromIf
// MLIR: %[[#One:]] = llvm.mlir.constant(1 : i32) : i32
// MLIR: %[[#Zero:]] = llvm.mlir.constant(0 : i32) : i32
// MLIR: llvm.cond_br {{.*}}, ^bb[[#COND_YES:]], ^bb[[#COND_NO:]]
// MLIR: ^bb[[#COND_YES]]:
// MLIR: llvm.br ^bb[[#GOTO_BLK:]]
// MLIR: ^bb[[#COND_NO]]:
// MLIR: llvm.br ^bb[[#BLK:]]
// MLIR: ^bb[[#BLK]]:
// MLIR: llvm.store %[[#Zero]], %[[#Ret_val_addr:]] {{.*}}: i32, !llvm.ptr
// MLIR: llvm.br ^bb[[#RETURN:]]
// MLIR: ^bb[[#RETURN]]:
// MLIR: %[[#Ret_val:]] = llvm.load %[[#Ret_val_addr]] {alignment = 4 : i64} : !llvm.ptr -> i32
// MLIR: llvm.return %[[#Ret_val]] : i32
// MLIR: ^bb[[#GOTO_BLK]]:
// MLIR: %[[#Neg_one:]] = llvm.sub %[[#Zero]], %[[#One]] : i32
// MLIR: llvm.store %[[#Neg_one]], %[[#Ret_val_addr]] {{.*}}: i32, !llvm.ptr
// MLIR: llvm.br ^bb[[#RETURN]]
// MLIR: }
}