
This patch adds the ZeroOpConversion pattern to LLVM IR dialect. Conversion of aggregate types is not implemented yet and will trigger a failure to legalize the operation. This is tested in the convert-to-llvm-invalid.fir test file. This patch is part of the upstreaming effort from fir-dev branch. Reviewed By: kiranchandramohan Differential Revision: https://reviews.llvm.org/D113014 Co-authored-by: Jean Perier <jperier@nvidia.com> Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
233 lines
8.7 KiB
C++
233 lines
8.7 KiB
C++
//===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "flang/Optimizer/CodeGen/CodeGen.h"
|
|
#include "PassDetail.h"
|
|
#include "flang/Optimizer/Dialect/FIROps.h"
|
|
#include "flang/Optimizer/Dialect/FIRType.h"
|
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
|
|
#define DEBUG_TYPE "flang-codegen"
|
|
|
|
// fir::LLVMTypeConverter for converting to LLVM IR dialect types.
|
|
#include "TypeConverter.h"
|
|
|
|
namespace {
|
|
/// FIR conversion pattern template
|
|
template <typename FromOp>
|
|
class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
|
|
public:
|
|
explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
|
|
: mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}
|
|
|
|
protected:
|
|
mlir::Type convertType(mlir::Type ty) const {
|
|
return lowerTy().convertType(ty);
|
|
}
|
|
|
|
fir::LLVMTypeConverter &lowerTy() const {
|
|
return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
|
|
}
|
|
};
|
|
|
|
// Lower `fir.address_of` operation to `llvm.address_of` operation.
|
|
struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
|
|
using FIROpConversion::FIROpConversion;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
|
|
mlir::ConversionPatternRewriter &rewriter) const override {
|
|
auto ty = convertType(addr.getType());
|
|
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
|
|
addr, ty, addr.symbol().getRootReference().getValue());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `fir.has_value` operation to `llvm.return` operation.
|
|
struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
|
|
using FIROpConversion::FIROpConversion;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
|
|
mlir::ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `fir.global` operation to `llvm.global` operation.
|
|
/// `fir.insert_on_range` operations are replaced with constant dense attribute
|
|
/// if they are applied on the full range.
|
|
struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
|
|
using FIROpConversion::FIROpConversion;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
|
|
mlir::ConversionPatternRewriter &rewriter) const override {
|
|
auto tyAttr = convertType(global.getType());
|
|
if (global.getType().isa<fir::BoxType>())
|
|
tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
|
|
auto loc = global.getLoc();
|
|
mlir::Attribute initAttr{};
|
|
if (global.initVal())
|
|
initAttr = global.initVal().getValue();
|
|
auto linkage = convertLinkage(global.linkName());
|
|
auto isConst = global.constant().hasValue();
|
|
auto g = rewriter.create<mlir::LLVM::GlobalOp>(
|
|
loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
|
|
auto &gr = g.getInitializerRegion();
|
|
rewriter.inlineRegionBefore(global.region(), gr, gr.end());
|
|
if (!gr.empty()) {
|
|
// Replace insert_on_range with a constant dense attribute if the
|
|
// initialization is on the full range.
|
|
auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
|
|
for (auto insertOp : insertOnRangeOps) {
|
|
if (isFullRange(insertOp.coor(), insertOp.getType())) {
|
|
auto seqTyAttr = convertType(insertOp.getType());
|
|
auto *op = insertOp.val().getDefiningOp();
|
|
auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
|
|
if (!constant) {
|
|
auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op);
|
|
if (!convertOp)
|
|
continue;
|
|
constant = cast<mlir::arith::ConstantOp>(
|
|
convertOp.value().getDefiningOp());
|
|
}
|
|
mlir::Type vecType = mlir::VectorType::get(
|
|
insertOp.getType().getShape(), constant.getType());
|
|
auto denseAttr = mlir::DenseElementsAttr::get(
|
|
vecType.cast<ShapedType>(), constant.value());
|
|
rewriter.setInsertionPointAfter(insertOp);
|
|
rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
|
|
insertOp, seqTyAttr, denseAttr);
|
|
}
|
|
}
|
|
}
|
|
rewriter.eraseOp(global);
|
|
return success();
|
|
}
|
|
|
|
bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
|
|
auto extents = seqTy.getShape();
|
|
if (indexes.size() / 2 != extents.size())
|
|
return false;
|
|
for (unsigned i = 0; i < indexes.size(); i += 2) {
|
|
if (indexes[i].cast<IntegerAttr>().getInt() != 0)
|
|
return false;
|
|
if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// TODO: String comparaison should be avoided. Replace linkName with an
|
|
// enumeration.
|
|
mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const {
|
|
if (optLinkage.hasValue()) {
|
|
auto name = optLinkage.getValue();
|
|
if (name == "internal")
|
|
return mlir::LLVM::Linkage::Internal;
|
|
if (name == "linkonce")
|
|
return mlir::LLVM::Linkage::Linkonce;
|
|
if (name == "common")
|
|
return mlir::LLVM::Linkage::Common;
|
|
if (name == "weak")
|
|
return mlir::LLVM::Linkage::Weak;
|
|
}
|
|
return mlir::LLVM::Linkage::External;
|
|
}
|
|
};
|
|
|
|
// convert to LLVM IR dialect `undef`
|
|
struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
|
|
using FIROpConversion::FIROpConversion;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(fir::UndefOp undef, OpAdaptor,
|
|
mlir::ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
|
|
undef, convertType(undef.getType()));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower fir::ZeroOp to the LLVM IR dialect.
///
/// The replacement depends on the converted result type:
///  - LLVM pointer type: `mlir::LLVM::NullOp`;
///  - integer type: `mlir::LLVM::ConstantOp` holding the integer 0;
///  - LLVM-compatible floating point type: `mlir::LLVM::ConstantOp` holding
///    0.0;
///  - any other type: an error is emitted on the operation and the pattern
///    fails.
struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::ZeroOp zero, OpAdaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type ty = convertType(zero.getType());
    if (ty.isa<mlir::LLVM::LLVMPointerType>()) {
      rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty);
    } else if (ty.isa<mlir::IntegerType>()) {
      // Build the attribute with the converted type: the branch condition
      // guarantees `ty` is an integer type, whereas the original FIR type of
      // the operation is not necessarily one.
      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
          zero, ty, mlir::IntegerAttr::get(ty, 0));
    } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) {
      // Same reasoning as above: use the converted floating point type.
      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
          zero, ty, mlir::FloatAttr::get(ty, 0.0));
    } else {
      // TODO: create ConstantAggregateZero for FIR aggregate/array types.
      return zero.emitOpError(
          "conversion of fir.zero with aggregate type not implemented yet");
    }
    return mlir::success();
  }
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
/// Convert FIR dialect to LLVM dialect
|
|
///
|
|
/// This pass lowers all FIR dialect operations to LLVM IR dialect. An
|
|
/// MLIR pass is used to lower residual Std dialect to LLVM IR dialect.
|
|
///
|
|
/// This pass is not complete yet. We are upstreaming it in small patches.
|
|
class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
|
|
public:
|
|
mlir::ModuleOp getModule() { return getOperation(); }
|
|
|
|
void runOnOperation() override final {
|
|
auto *context = getModule().getContext();
|
|
fir::LLVMTypeConverter typeConverter{getModule()};
|
|
mlir::OwningRewritePatternList pattern(context);
|
|
pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
|
|
UndefOpConversion, ZeroOpConversion>(typeConverter);
|
|
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
|
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
|
pattern);
|
|
mlir::ConversionTarget target{*context};
|
|
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
|
|
|
|
// required NOPs for applying a full conversion
|
|
target.addLegalOp<mlir::ModuleOp>();
|
|
|
|
// apply the patterns
|
|
if (mlir::failed(mlir::applyFullConversion(getModule(), target,
|
|
std::move(pattern)))) {
|
|
signalPassFailure();
|
|
}
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
|
|
return std::make_unique<FIRToLLVMLowering>();
|
|
}
|