
Previously, cf.br, cf.cond_br, and cf.switch always lowered to their LLVM equivalents. These ops all take in values of given types and jump to other blocks whose argument lists must have the same types; if the types are not the same, a verification failure will occur later. This led to confusion: everything works when the func->llvm and cf->llvm lowerings both run, because func->llvm updates the blocks and their argument lists while cf->llvm updates the branching ops. Without func->llvm, though, there can be a type mismatch. This change now only lowers the CF ops if the result will later pass verification. This is possible because the parent op and its blocks are updated before the contained branching ops, so the branching ops can test their new operand types against the types of the blocks they jump to. Another plan was to have func->llvm only update the entry block signature and to let cf->llvm update all other blocks, but this had two problems: 1. It would create a FuncOp lowering in the cf->llvm lowering, which is awkward. 2. The new pattern would only be applied if the containing FuncOp is marked invalid, which is infeasible with the shared LLVM type conversion/target infrastructure. See previous discussions at https://discourse.llvm.org/t/lowering-cf-to-llvm/63863 and https://github.com/llvm/llvm-project/issues/55301 Differential Revision: https://reviews.llvm.org/D130971
224 lines
9.0 KiB
C++
224 lines
9.0 KiB
C++
//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass to convert the MLIR ControlFlow (`cf`) dialect
// into the LLVM IR dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
|
#include "../PassDetail.h"
|
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include <functional>
|
|
|
|
using namespace mlir;
|
|
|
|
#define PASS_NAME "convert-cf-to-llvm"
|
|
|
|
namespace {
|
|
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
|
|
/// assertion is violated and has no effect otherwise. The failure message is
|
|
/// ignored by the default lowering but should be propagated by any custom
|
|
/// lowering.
|
|
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
|
|
using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
|
|
// Insert the `abort` declaration if necessary.
|
|
auto module = op->getParentOfType<ModuleOp>();
|
|
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
|
|
if (!abortFunc) {
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointToStart(module.getBody());
|
|
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
|
|
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
|
|
"abort", abortFuncTy);
|
|
}
|
|
|
|
// Split block at `assert` operation.
|
|
Block *opBlock = rewriter.getInsertionBlock();
|
|
auto opPosition = rewriter.getInsertionPoint();
|
|
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
|
|
|
|
// Generate IR to call `abort`.
|
|
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
|
|
rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
|
|
rewriter.create<LLVM::UnreachableOp>(loc);
|
|
|
|
// Generate assertion test.
|
|
rewriter.setInsertionPointToEnd(opBlock);
|
|
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
|
op, adaptor.getArg(), continuationBlock, failureBlock);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// The cf->LLVM lowerings for branching ops require that the blocks they jump
|
|
/// to first have updated types which should be handled by a pattern operating
|
|
/// on the parent op.
|
|
static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
|
|
ValueRange operands,
|
|
ValueRange blockArgs, Location loc,
|
|
llvm::StringRef messagePrefix) {
|
|
for (const auto &idxAndTypes :
|
|
llvm::enumerate(llvm::zip(blockArgs, operands))) {
|
|
int64_t i = idxAndTypes.index();
|
|
Value argValue =
|
|
rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
|
|
Type operandType = std::get<1>(idxAndTypes.value()).getType();
|
|
// In the case of an invalid jump, the block argument will have been
|
|
// remapped to an UnrealizedConversionCast. In the case of a valid jump,
|
|
// there might still be a no-op conversion cast with both types being equal.
|
|
// Consider both of these details to see if the jump would be invalid.
|
|
if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
|
|
argValue.getDefiningOp())) {
|
|
if (op.getOperandTypes().front() != operandType) {
|
|
return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
|
|
diag << messagePrefix;
|
|
diag << "mismatched types from operand # " << i << " ";
|
|
diag << operandType;
|
|
diag << " not compatible with destination block argument type ";
|
|
diag << argValue.getType();
|
|
diag << " which should be converted with the parent op.";
|
|
});
|
|
}
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Ensure that all block types were updated and then create an LLVM::BrOp
|
|
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
|
|
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
|
|
op.getSuccessor()->getArguments(),
|
|
op.getLoc(),
|
|
/*messagePrefix=*/"")))
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::BrOp>(
|
|
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Ensure that all block types were updated and then create an LLVM::CondBrOp
|
|
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
|
|
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(cf::CondBranchOp op,
|
|
typename cf::CondBranchOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
|
|
op.getFalseDest()->getArguments(),
|
|
op.getLoc(), "in false case branch ")))
|
|
return failure();
|
|
if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
|
|
op.getTrueDest()->getArguments(),
|
|
op.getLoc(), "in true case branch ")))
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
|
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Ensure that all block types were updated and then create an LLVM::SwitchOp
|
|
struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
|
|
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
|
|
op.getDefaultDestination()->getArguments(),
|
|
op.getLoc(), "in switch default case ")))
|
|
return failure();
|
|
|
|
for (const auto &i : llvm::enumerate(
|
|
llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
|
|
if (failed(verifyMatchingValues(
|
|
rewriter, std::get<0>(i.value()),
|
|
std::get<1>(i.value())->getArguments(), op.getLoc(),
|
|
"in switch case " + std::to_string(i.index()) + " "))) {
|
|
return failure();
|
|
}
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
|
|
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds the cf->LLVM lowering patterns to `patterns`. This covers `cf.assert`,
/// `cf.br`, `cf.cond_br` and `cf.switch`, and each pattern is constructed with
/// `converter`.
void mlir::cf::populateControlFlowToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<AssertOpLowering, BranchOpLowering, CondBranchOpLowering,
               SwitchOpLowering>(converter);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pass Definition
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// A pass converting MLIR operations into the LLVM IR dialect.
|
|
struct ConvertControlFlowToLLVM
|
|
: public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
|
|
ConvertControlFlowToLLVM() = default;
|
|
|
|
/// Run the dialect converter on the module.
|
|
void runOnOperation() override {
|
|
LLVMConversionTarget target(getContext());
|
|
RewritePatternSet patterns(&getContext());
|
|
|
|
LowerToLLVMOptions options(&getContext());
|
|
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
|
|
options.overrideIndexBitwidth(indexBitwidth);
|
|
|
|
LLVMTypeConverter converter(&getContext(), options);
|
|
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
|
|
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
|
|
return std::make_unique<ConvertControlFlowToLLVM>();
|
|
}
|