[CIR] Upstream GotoSolver pass (#154596)
This PR upstreams the GotoSolver pass. It works by walking the function and matching each label to a goto. If a label is not matched to a goto, it is removed and not lowered.
This commit is contained in:
parent
3923adfa3f
commit
fc62990657
@ -26,6 +26,7 @@ std::unique_ptr<Pass> createCIRSimplifyPass();
|
||||
std::unique_ptr<Pass> createHoistAllocasPass();
|
||||
std::unique_ptr<Pass> createLoweringPreparePass();
|
||||
std::unique_ptr<Pass> createLoweringPreparePass(clang::ASTContext *astCtx);
|
||||
std::unique_ptr<Pass> createGotoSolverPass();
|
||||
|
||||
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
|
||||
|
||||
|
@ -72,6 +72,16 @@ def CIRFlattenCFG : Pass<"cir-flatten-cfg"> {
|
||||
let dependentDialects = ["cir::CIRDialect"];
|
||||
}
|
||||
|
||||
def GotoSolver : Pass<"cir-goto-solver"> {
|
||||
let summary = "Replaces goto operations with branches";
|
||||
let description = [{
|
||||
This pass transforms CIR and replaces goto-s with branch
|
||||
operations to the proper blocks.
|
||||
}];
|
||||
let constructor = "mlir::createGotoSolverPass()";
|
||||
let dependentDialects = ["cir::CIRDialect"];
|
||||
}
|
||||
|
||||
def LoweringPrepare : Pass<"cir-lowering-prepare"> {
|
||||
let summary = "Lower to more fine-grained CIR operations before lowering to "
|
||||
"other dialects";
|
||||
|
@ -4,6 +4,7 @@ add_clang_library(MLIRCIRTransforms
|
||||
FlattenCFG.cpp
|
||||
HoistAllocas.cpp
|
||||
LoweringPrepare.cpp
|
||||
GotoSolver.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRCIRPassIncGen
|
||||
|
57
clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp
Normal file
57
clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp
Normal file
@ -0,0 +1,57 @@
|
||||
//====- GotoSolver.cpp -----------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include "PassDetail.h"
|
||||
#include "clang/CIR/Dialect/IR/CIRDialect.h"
|
||||
#include "clang/CIR/Dialect/Passes.h"
|
||||
#include "llvm/Support/TimeProfiler.h"
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace cir;
|
||||
|
||||
namespace {
|
||||
|
||||
struct GotoSolverPass : public GotoSolverBase<GotoSolverPass> {
|
||||
GotoSolverPass() = default;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
static void process(cir::FuncOp func) {
|
||||
mlir::OpBuilder rewriter(func.getContext());
|
||||
llvm::StringMap<Block *> labels;
|
||||
llvm::SmallVector<cir::GotoOp, 4> gotos;
|
||||
|
||||
func.getBody().walk([&](mlir::Operation *op) {
|
||||
if (auto lab = dyn_cast<cir::LabelOp>(op)) {
|
||||
// Will construct a string copy inplace. Safely erase the label
|
||||
labels.try_emplace(lab.getLabel(), lab->getBlock());
|
||||
lab.erase();
|
||||
} else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
|
||||
gotos.push_back(goTo);
|
||||
}
|
||||
});
|
||||
|
||||
for (auto goTo : gotos) {
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(goTo);
|
||||
Block *dest = labels[goTo.getLabel()];
|
||||
cir::BrOp::create(rewriter, goTo.getLoc(), dest);
|
||||
goTo.erase();
|
||||
}
|
||||
}
|
||||
|
||||
void GotoSolverPass::runOnOperation() {
|
||||
llvm::TimeTraceScope scope("Goto Solver");
|
||||
getOperation()->walk(&process);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createGotoSolverPass() {
|
||||
return std::make_unique<GotoSolverPass>();
|
||||
}
|
@ -45,6 +45,7 @@ namespace mlir {
|
||||
void populateCIRPreLoweringPasses(OpPassManager &pm) {
|
||||
pm.addPass(createHoistAllocasPass());
|
||||
pm.addPass(createCIRFlattenCFGPass());
|
||||
pm.addPass(createGotoSolverPass());
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -1,5 +1,7 @@
|
||||
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
|
||||
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
|
||||
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
|
||||
// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
|
||||
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
|
||||
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
|
||||
|
||||
@ -27,6 +29,24 @@ err:
|
||||
// CIR: cir.store [[MINUS]], [[RETVAL]] : !s32i, !cir.ptr<!s32i>
|
||||
// CIR: cir.br ^bb1
|
||||
|
||||
// LLVM: define dso_local i32 @_Z21shouldNotGenBranchReti
|
||||
// LLVM: [[COND:%.*]] = load i32, ptr {{.*}}, align 4
|
||||
// LLVM: [[CMP:%.*]] = icmp sgt i32 [[COND]], 5
|
||||
// LLVM: br i1 [[CMP]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
|
||||
// LLVM: [[IFTHEN]]:
|
||||
// LLVM: br label %[[ERR:.*]]
|
||||
// LLVM: [[IFEND]]:
|
||||
// LLVM: br label %[[BB9:.*]]
|
||||
// LLVM: [[BB9]]:
|
||||
// LLVM: store i32 0, ptr %[[RETVAL:.*]], align 4
|
||||
// LLVM: br label %[[BBRET:.*]]
|
||||
// LLVM: [[BBRET]]:
|
||||
// LLVM: [[RET:%.*]] = load i32, ptr %[[RETVAL]], align 4
|
||||
// LLVM: ret i32 [[RET]]
|
||||
// LLVM: [[ERR]]:
|
||||
// LLVM: store i32 -1, ptr %[[RETVAL]], align 4
|
||||
// LLVM: br label %10
|
||||
|
||||
// OGCG: define dso_local noundef i32 @_Z21shouldNotGenBranchReti
|
||||
// OGCG: if.then:
|
||||
// OGCG: br label %err
|
||||
@ -51,6 +71,17 @@ err:
|
||||
// CIR: ^bb1:
|
||||
// CIR: cir.label "err"
|
||||
|
||||
// LLVM: define dso_local i32 @_Z15shouldGenBranchi
|
||||
// LLVM: br i1 [[CMP:%.*]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
|
||||
// LLVM: [[IFTHEN]]:
|
||||
// LLVM: br label %[[ERR:.*]]
|
||||
// LLVM: [[IFEND]]:
|
||||
// LLVM: br label %[[BB9:.*]]
|
||||
// LLVM: [[BB9]]:
|
||||
// LLVM: br label %[[ERR]]
|
||||
// LLVM: [[ERR]]:
|
||||
// LLVM: ret i32 [[RET:%.*]]
|
||||
|
||||
// OGCG: define dso_local noundef i32 @_Z15shouldGenBranchi
|
||||
// OGCG: if.then:
|
||||
// OGCG: br label %err
|
||||
@ -78,6 +109,15 @@ end2:
|
||||
// CIR: ^bb[[#BLK3]]:
|
||||
// CIR: cir.label "end2"
|
||||
|
||||
// LLVM: define dso_local void @_Z19severalLabelsInARowi
|
||||
// LLVM: br label %[[END1:.*]]
|
||||
// LLVM: [[UNRE:.*]]: ; No predecessors!
|
||||
// LLVM: br label %[[END2:.*]]
|
||||
// LLVM: [[END1]]:
|
||||
// LLVM: br label %[[END2]]
|
||||
// LLVM: [[END2]]:
|
||||
// LLVM: ret
|
||||
|
||||
// OGCG: define dso_local void @_Z19severalLabelsInARowi
|
||||
// OGCG: br label %end1
|
||||
// OGCG: end1:
|
||||
@ -99,6 +139,13 @@ end:
|
||||
// CIR: ^bb[[#BLK2:]]:
|
||||
// CIR: cir.label "end"
|
||||
|
||||
// LLVM: define dso_local void @_Z18severalGotosInARowi
|
||||
// LLVM: br label %[[END:.*]]
|
||||
// LLVM: [[UNRE:.*]]: ; No predecessors!
|
||||
// LLVM: br label %[[END]]
|
||||
// LLVM: [[END]]:
|
||||
// LLVM: ret void
|
||||
|
||||
// OGCG: define dso_local void @_Z18severalGotosInARowi(i32 noundef %a) #0 {
|
||||
// OGCG: br label %end
|
||||
// OGCG: end:
|
||||
@ -126,6 +173,14 @@ extern "C" void multiple_non_case(int v) {
|
||||
// CIR: cir.call @action2()
|
||||
// CIR: cir.break
|
||||
|
||||
// LLVM: define dso_local void @multiple_non_case
|
||||
// LLVM: [[SWDEFAULT:.*]]:
|
||||
// LLVM: call void @action1()
|
||||
// LLVM: br label %[[L2:.*]]
|
||||
// LLVM: [[L2]]:
|
||||
// LLVM: call void @action2()
|
||||
// LLVM: br label %[[BREAK:.*]]
|
||||
|
||||
// OGCG: define dso_local void @multiple_non_case
|
||||
// OGCG: sw.default:
|
||||
// OGCG: call void @action1()
|
||||
@ -158,6 +213,26 @@ extern "C" void case_follow_label(int v) {
|
||||
// CIR: cir.call @action2()
|
||||
// CIR: cir.goto "label"
|
||||
|
||||
// LLVM: define dso_local void @case_follow_label
|
||||
// LLVM: switch i32 {{.*}}, label %[[SWDEFAULT:.*]] [
|
||||
// LLVM: i32 1, label %[[LABEL:.*]]
|
||||
// LLVM: i32 2, label %[[CASE2:.*]]
|
||||
// LLVM: ]
|
||||
// LLVM: [[LABEL]]:
|
||||
// LLVM: br label %[[CASE2]]
|
||||
// LLVM: [[CASE2]]:
|
||||
// LLVM: call void @action1()
|
||||
// LLVM: br label %[[BREAK:.*]]
|
||||
// LLVM: [[BREAK]]:
|
||||
// LLVM: br label %[[END:.*]]
|
||||
// LLVM: [[SWDEFAULT]]:
|
||||
// LLVM: call void @action2()
|
||||
// LLVM: br label %[[LABEL]]
|
||||
// LLVM: [[END]]:
|
||||
// LLVM: br label %[[RET:.*]]
|
||||
// LLVM: [[RET]]:
|
||||
// LLVM: ret void
|
||||
|
||||
// OGCG: define dso_local void @case_follow_label
|
||||
// OGCG: sw.bb:
|
||||
// OGCG: br label %label
|
||||
@ -197,6 +272,26 @@ extern "C" void default_follow_label(int v) {
|
||||
// CIR: cir.call @action2()
|
||||
// CIR: cir.goto "label"
|
||||
|
||||
// LLVM: define dso_local void @default_follow_label
|
||||
// LLVM: [[CASE1:.*]]:
|
||||
// LLVM: br label %[[BB8:.*]]
|
||||
// LLVM: [[BB8]]:
|
||||
// LLVM: br label %[[CASE2:.*]]
|
||||
// LLVM: [[CASE2]]:
|
||||
// LLVM: call void @action1()
|
||||
// LLVM: br label %[[BREAK:.*]]
|
||||
// LLVM: [[LABEL:.*]]:
|
||||
// LLVM: br label %[[SWDEFAULT:.*]]
|
||||
// LLVM: [[SWDEFAULT]]:
|
||||
// LLVM: call void @action2()
|
||||
// LLVM: br label %[[BB9:.*]]
|
||||
// LLVM: [[BB9]]:
|
||||
// LLVM: br label %[[LABEL]]
|
||||
// LLVM: [[BREAK]]:
|
||||
// LLVM: br label %[[RET:.*]]
|
||||
// LLVM: [[RET]]:
|
||||
// LLVM: ret void
|
||||
|
||||
// OGCG: define dso_local void @default_follow_label
|
||||
// OGCG: sw.bb:
|
||||
// OGCG: call void @action1()
|
||||
|
@ -1,5 +1,7 @@
|
||||
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
|
||||
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
|
||||
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
|
||||
// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
|
||||
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
|
||||
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
|
||||
|
||||
@ -12,8 +14,8 @@ labelA:
|
||||
// CIR: cir.label "labelA"
|
||||
// CIR: cir.return
|
||||
|
||||
// Note: We are not lowering to LLVM IR via CIR at this stage because that
|
||||
// process depends on the GotoSolver.
|
||||
// LLVM:define dso_local void @label
|
||||
// LLVM: ret void
|
||||
|
||||
// OGCG: define dso_local void @label
|
||||
// OGCG: br label %labelA
|
||||
@ -33,6 +35,11 @@ labelC:
|
||||
// CIR: cir.label "labelC"
|
||||
// CIR: cir.return
|
||||
|
||||
// LLVM: define dso_local void @multiple_labels()
|
||||
// LLVM: br label %1
|
||||
// LLVM: 1:
|
||||
// LLVM: ret void
|
||||
|
||||
// OGCG: define dso_local void @multiple_labels
|
||||
// OGCG: br label %labelB
|
||||
// OGCG: labelB:
|
||||
@ -56,6 +63,22 @@ labelD:
|
||||
// CIR: }
|
||||
// CIR: cir.return
|
||||
|
||||
// LLVM: define dso_local void @label_in_if
|
||||
// LLVM: br label %3
|
||||
// LLVM: 3:
|
||||
// LLVM: [[LOAD:%.*]] = load i32, ptr [[COND:%.*]], align 4
|
||||
// LLVM: [[CMP:%.*]] = icmp ne i32 [[LOAD]], 0
|
||||
// LLVM: br i1 [[CMP]], label %6, label %9
|
||||
// LLVM: 6:
|
||||
// LLVM: [[LOAD2:%.*]] = load i32, ptr [[COND]], align 4
|
||||
// LLVM: [[ADD1:%.*]] = add nsw i32 [[LOAD2]], 1
|
||||
// LLVM: store i32 [[ADD1]], ptr [[COND]], align 4
|
||||
// LLVM: br label %9
|
||||
// LLVM: 9:
|
||||
// LLVM: br label %10
|
||||
// LLVM: 10:
|
||||
// LLVM: ret void
|
||||
|
||||
// OGCG: define dso_local void @label_in_if
|
||||
// OGCG: if.then:
|
||||
// OGCG: br label %labelD
|
||||
@ -80,6 +103,13 @@ void after_return() {
|
||||
// CIR: cir.label "label"
|
||||
// CIR: cir.br ^bb1
|
||||
|
||||
// LLVM: define dso_local void @after_return
|
||||
// LLVM: br label %1
|
||||
// LLVM: 1:
|
||||
// LLVM: ret void
|
||||
// LLVM: 2:
|
||||
// LLVM: br label %1
|
||||
|
||||
// OGCG: define dso_local void @after_return
|
||||
// OGCG: br label %label
|
||||
// OGCG: label:
|
||||
@ -97,6 +127,11 @@ void after_unreachable() {
|
||||
// CIR: cir.label "label"
|
||||
// CIR: cir.return
|
||||
|
||||
// LLVM: define dso_local void @after_unreachable
|
||||
// LLVM: unreachable
|
||||
// LLVM: 1:
|
||||
// LLVM: ret void
|
||||
|
||||
// OGCG: define dso_local void @after_unreachable
|
||||
// OGCG: unreachable
|
||||
// OGCG: label:
|
||||
@ -111,6 +146,9 @@ end:
|
||||
// CIR: cir.return
|
||||
// CIR: }
|
||||
|
||||
// LLVM: define dso_local void @labelWithoutMatch
|
||||
// LLVM: ret void
|
||||
|
||||
// OGCG: define dso_local void @labelWithoutMatch
|
||||
// OGCG: br label %end
|
||||
// OGCG: end:
|
||||
@ -132,6 +170,15 @@ void foo() {
|
||||
// CIR: cir.label "label"
|
||||
// CIR: %0 = cir.alloca !rec_S, !cir.ptr<!rec_S>, ["agg.tmp0"]
|
||||
|
||||
// LLVM:define dso_local void @foo() {
|
||||
// LLVM: [[ALLOC:%.*]] = alloca %struct.S, i64 1, align 1
|
||||
// LLVM: br label %2
|
||||
// LLVM:2:
|
||||
// LLVM: [[CALL:%.*]] = call %struct.S @get()
|
||||
// LLVM: store %struct.S [[CALL]], ptr [[ALLOC]], align 1
|
||||
// LLVM: [[LOAD:%.*]] = load %struct.S, ptr [[ALLOC]], align 1
|
||||
// LLVM: call void @bar(%struct.S [[LOAD]])
|
||||
|
||||
// OGCG: define dso_local void @foo()
|
||||
// OGCG: %agg.tmp = alloca %struct.S, align 1
|
||||
// OGCG: %undef.agg.tmp = alloca %struct.S, align 1
|
||||
|
52
clang/test/CIR/Lowering/goto.cir
Normal file
52
clang/test/CIR/Lowering/goto.cir
Normal file
@ -0,0 +1,52 @@
|
||||
// RUN: cir-opt %s --pass-pipeline='builtin.module(cir-to-llvm,canonicalize{region-simplify=disabled})' -o - | FileCheck %s -check-prefix=MLIR
|
||||
|
||||
!s32i = !cir.int<s, 32>
|
||||
|
||||
module {
|
||||
|
||||
cir.func @gotoFromIf(%arg0: !s32i) -> !s32i {
|
||||
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
|
||||
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
|
||||
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
|
||||
cir.scope {
|
||||
%6 = cir.load %0 : !cir.ptr<!s32i>, !s32i
|
||||
%7 = cir.const #cir.int<5> : !s32i
|
||||
%8 = cir.cmp(gt, %6, %7) : !s32i, !cir.bool
|
||||
cir.if %8 {
|
||||
cir.goto "err"
|
||||
}
|
||||
}
|
||||
%2 = cir.const #cir.int<0> : !s32i
|
||||
cir.store %2, %1 : !s32i, !cir.ptr<!s32i>
|
||||
cir.br ^bb1
|
||||
^bb1:
|
||||
%3 = cir.load %1 : !cir.ptr<!s32i>, !s32i
|
||||
cir.return %3 : !s32i
|
||||
^bb2:
|
||||
cir.label "err"
|
||||
%4 = cir.const #cir.int<1> : !s32i
|
||||
%5 = cir.unary(minus, %4) : !s32i, !s32i
|
||||
cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
|
||||
cir.br ^bb1
|
||||
}
|
||||
|
||||
// MLIR: llvm.func @gotoFromIf
|
||||
// MLIR: %[[#One:]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// MLIR: %[[#Zero:]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// MLIR: llvm.cond_br {{.*}}, ^bb[[#COND_YES:]], ^bb[[#COND_NO:]]
|
||||
// MLIR: ^bb[[#COND_YES]]:
|
||||
// MLIR: llvm.br ^bb[[#GOTO_BLK:]]
|
||||
// MLIR: ^bb[[#COND_NO]]:
|
||||
// MLIR: llvm.br ^bb[[#BLK:]]
|
||||
// MLIR: ^bb[[#BLK]]:
|
||||
// MLIR: llvm.store %[[#Zero]], %[[#Ret_val_addr:]] {{.*}}: i32, !llvm.ptr
|
||||
// MLIR: llvm.br ^bb[[#RETURN:]]
|
||||
// MLIR: ^bb[[#RETURN]]:
|
||||
// MLIR: %[[#Ret_val:]] = llvm.load %[[#Ret_val_addr]] {alignment = 4 : i64} : !llvm.ptr -> i32
|
||||
// MLIR: llvm.return %[[#Ret_val]] : i32
|
||||
// MLIR: ^bb[[#GOTO_BLK]]:
|
||||
// MLIR: %[[#Neg_one:]] = llvm.sub %[[#Zero]], %[[#One]] : i32
|
||||
// MLIR: llvm.store %[[#Neg_one]], %[[#Ret_val_addr]] {{.*}}: i32, !llvm.ptr
|
||||
// MLIR: llvm.br ^bb[[#RETURN]]
|
||||
// MLIR: }
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user