//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to convert MLIR standard and builtin dialects // into the LLVM IR dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "../PassDetail.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include using namespace mlir; #define PASS_NAME "convert-cf-to-llvm" namespace { /// Lower `std.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is /// ignored by the default lowering but should be propagated by any custom /// lowering. struct AssertOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); // Insert the `abort` declaration if necessary. auto module = op->getParentOfType(); auto abortFunc = module.lookupSymbol("abort"); if (!abortFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); abortFunc = rewriter.create(rewriter.getUnknownLoc(), "abort", abortFuncTy); } // Split block at `assert` operation. Block *opBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); // Generate IR to call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); rewriter.create(loc, abortFunc, llvm::None); rewriter.create(loc); // Generate assertion test. rewriter.setInsertionPointToEnd(opBlock); rewriter.replaceOpWithNewOp( op, adaptor.getArg(), continuationBlock, failureBlock); return success(); } }; // Base class for LLVM IR lowering terminator operations with successors. template struct OneToOneLLVMTerminatorLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Base = OneToOneLLVMTerminatorLowering; LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); return success(); } }; // FIXME: this should be tablegen'ed as well. struct BranchOpLowering : public OneToOneLLVMTerminatorLowering { using Base::Base; }; struct CondBranchOpLowering : public OneToOneLLVMTerminatorLowering { using Base::Base; }; struct SwitchOpLowering : public OneToOneLLVMTerminatorLowering { using Base::Base; }; } // namespace void mlir::cf::populateControlFlowToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< AssertOpLowering, BranchOpLowering, CondBranchOpLowering, SwitchOpLowering>(converter); // clang-format on } //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct ConvertControlFlowToLLVM : public ConvertControlFlowToLLVMBase { ConvertControlFlowToLLVM() = default; /// Run the dialect converter on the module. void runOnOperation() override { LLVMConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(&getContext()); if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); LLVMTypeConverter converter(&getContext(), options); mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace std::unique_ptr mlir::cf::createConvertControlFlowToLLVMPass() { return std::make_unique(); }