Valentin Clement ea55503d7c
[fir] Add fir.extract_value and fir.insert_value conversion
This patch adds the conversion patterns for fir.extract_value
and fir.insert_value. fir.extract_value is lowered to llvm.extractvalue
and fir.insert_value is lowered to llvm.insertvalue.
This patch also adds the type conversion for the BoxType and RecordType
needed to have some comprehensive tests.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: awarzynski

Differential Revision: https://reviews.llvm.org/D112961

Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
2021-11-05 15:53:42 +01:00

517 lines
19 KiB
C++

//===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "PassDetail.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/ArrayRef.h"
#define DEBUG_TYPE "flang-codegen"
// fir::LLVMTypeConverter for converting to LLVM IR dialect types.
#include "TypeConverter.h"
namespace {
/// Base class for the FIR-to-LLVM conversion patterns of one FIR op type.
///
/// The pattern is constructed from the FIR-specific `fir::LLVMTypeConverter`
/// and gives derived patterns typed access to that converter.
template <typename FromOp>
class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
public:
  explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
      : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}

protected:
  /// The type converter this pattern was constructed with, viewed as the
  /// FIR converter type passed to the constructor.
  fir::LLVMTypeConverter &lowerTy() const {
    auto *converter = this->getTypeConverter();
    return *static_cast<fir::LLVMTypeConverter *>(converter);
  }

  /// Convert `ty` using the FIR type converter.
  mlir::Type convertType(mlir::Type ty) const {
    return lowerTy().convertType(ty);
  }
};
/// Base class for FIR conversion patterns that need the op's converted
/// result type. `matchAndRewrite` converts `op.getType()` and hands the
/// result to the derived pattern's `doRewrite`.
template <typename FromOp>
class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
public:
  using FIROpConversion<FromOp>::FIROpConversion;
  using OpAdaptor = typename FromOp::Adaptor;

  mlir::LogicalResult
  matchAndRewrite(FromOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const final {
    mlir::Type loweredTy = this->convertType(op.getType());
    return doRewrite(op, loweredTy, adaptor, rewriter);
  }

  /// Rewrite `op`, whose result type converts to `ty`.
  virtual mlir::LogicalResult
  doRewrite(FromOp op, mlir::Type ty, OpAdaptor adaptor,
            mlir::ConversionPatternRewriter &rewriter) const = 0;
};
/// Lower `fir.address_of` operation to `llvm.mlir.addressof` operation.
struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type resultTy = convertType(addr.getType());
    // The LLVM op names the global by the root of the FIR symbol reference.
    auto globalName = addr.symbol().getRootReference().getValue();
    rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(addr, resultTy,
                                                         globalName);
    return success();
  }
};
/// Lower `fir.has_value` operation to `llvm.return` operation, forwarding
/// the (already converted) operands as the returned values.
struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto returnedValues = adaptor.getOperands();
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, returnedValues);
    return success();
  }
};
/// Lower `fir.global` operation to `llvm.mlir.global` operation.
/// `fir.insert_on_range` operations in the initializer region are replaced
/// with a constant dense attribute if they are applied on the full range and
/// the inserted value is an `arith.constant` (optionally behind a single
/// `fir.convert`). Any other `fir.insert_on_range` is left for
/// InsertOnRangeOpConversion.
struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto tyAttr = convertType(global.getType());
    // A fir.box converts to an LLVM pointer type; the global itself is given
    // the pointee type.
    if (global.getType().isa<fir::BoxType>())
      tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
    auto loc = global.getLoc();
    mlir::Attribute initAttr{};
    if (global.initVal())
      initAttr = global.initVal().getValue();
    auto linkage = convertLinkage(global.linkName());
    auto isConst = global.constant().hasValue();
    auto g = rewriter.create<mlir::LLVM::GlobalOp>(
        loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
    auto &gr = g.getInitializerRegion();
    rewriter.inlineRegionBefore(global.region(), gr, gr.end());
    if (!gr.empty()) {
      // Replace insert_on_range with a constant dense attribute if the
      // initialization is on the full range.
      auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
      for (auto insertOp : insertOnRangeOps) {
        if (!isFullRange(insertOp.coor(), insertOp.getType()))
          continue;
        mlir::arith::ConstantOp constant = getInsertedConstant(insertOp);
        // Not a (converted) constant: keep the op; it is lowered generically.
        if (!constant)
          continue;
        auto seqTyAttr = convertType(insertOp.getType());
        mlir::Type vecType = mlir::VectorType::get(
            insertOp.getType().getShape(), constant.getType());
        auto denseAttr = mlir::DenseElementsAttr::get(
            vecType.cast<ShapedType>(), constant.value());
        rewriter.setInsertionPointAfter(insertOp);
        rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
            insertOp, seqTyAttr, denseAttr);
      }
    }
    rewriter.eraseOp(global);
    return success();
  }

  /// Return the `arith.constant` that produces the value inserted by
  /// `insertOp`, looking through one `fir.convert`. Returns a null op when
  /// the value is a block argument or is not produced that way, instead of
  /// asserting.
  static mlir::arith::ConstantOp
  getInsertedConstant(fir::InsertOnRangeOp insertOp) {
    mlir::Operation *def = insertOp.val().getDefiningOp();
    if (auto convertOp = mlir::dyn_cast_or_null<fir::ConvertOp>(def))
      def = convertOp.value().getDefiningOp();
    return mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(def);
  }

  /// True when `indexes` (one lower/upper pair per dimension) covers every
  /// element of `seqTy`, i.e. each pair is [0, extent - 1].
  bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
    auto extents = seqTy.getShape();
    // Require exactly one pair per dimension; an odd-length list would
    // otherwise make the loop below read past the end of `indexes`.
    if (indexes.size() != 2 * extents.size())
      return false;
    for (unsigned i = 0; i < indexes.size(); i += 2) {
      if (indexes[i].cast<IntegerAttr>().getInt() != 0)
        return false;
      if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
        return false;
    }
    return true;
  }

  // TODO: String comparison should be avoided. Replace linkName with an
  // enumeration.
  /// Map the optional FIR link name to an LLVM linkage; anything absent or
  /// unrecognized becomes external linkage.
  mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const {
    if (optLinkage.hasValue()) {
      auto name = optLinkage.getValue();
      if (name == "internal")
        return mlir::LLVM::Linkage::Internal;
      if (name == "linkonce")
        return mlir::LLVM::Linkage::Linkonce;
      if (name == "common")
        return mlir::LLVM::Linkage::Common;
      if (name == "weak")
        return mlir::LLVM::Linkage::Weak;
    }
    return mlir::LLVM::Linkage::External;
  }
};
/// Shared lowering of `fir.select` and `fir.select_rank` to `llvm.switch`.
///
/// Each integer case attribute becomes a switch case. The unit attribute,
/// which must be the last case, becomes the default destination, and this
/// lowering requires one to be present. The selector is brought to i32,
/// which is the type used for the switch here.
template <typename OP>
void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
                           typename OP::Adaptor adaptor,
                           mlir::ConversionPatternRewriter &rewriter) {
  unsigned conds = select.getNumConditions();
  auto cases = select.getCases().getValue();
  mlir::Value selector = adaptor.selector();
  auto loc = select.getLoc();
  assert(conds > 0 && "select must have cases");

  llvm::SmallVector<mlir::Block *> destinations;
  llvm::SmallVector<mlir::ValueRange> destinationsOperands;
  // Initialized so that a missing default case is diagnosed by the assert
  // below rather than passing an indeterminate pointer to llvm.switch.
  mlir::Block *defaultDestination = nullptr;
  mlir::ValueRange defaultOperands;
  llvm::SmallVector<int32_t> caseValues;
  for (unsigned t = 0; t != conds; ++t) {
    mlir::Block *dest = select.getSuccessor(t);
    auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
    const mlir::Attribute &attr = cases[t];
    if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) {
      destinations.push_back(dest);
      destinationsOperands.push_back(destOps.hasValue() ? *destOps
                                                        : ValueRange());
      caseValues.push_back(intAttr.getInt());
      continue;
    }
    assert(attr.template dyn_cast_or_null<mlir::UnitAttr>());
    assert((t + 1 == conds) && "unit must be last");
    defaultDestination = dest;
    defaultOperands = destOps.hasValue() ? *destOps : ValueRange();
  }
  assert(defaultDestination && "select must have a default (unit) case");

  // LLVM::SwitchOp takes a i32 type for the selector. Inspect the *converted*
  // selector: FIR types that already lower to i32 need no cast, and integers
  // narrower than i32 must be extended since llvm.trunc can only narrow.
  mlir::Type i32Ty = rewriter.getI32Type();
  mlir::Type selectorTy = selector.getType();
  if (selectorTy != i32Ty) {
    auto intTy = selectorTy.dyn_cast<mlir::IntegerType>();
    if (intTy && intTy.getWidth() == 1)
      // i1 holds a boolean (0/1): zero-extend so "true" stays 1.
      selector = rewriter.create<mlir::LLVM::ZExtOp>(loc, i32Ty, selector);
    else if (intTy && intTy.getWidth() < 32)
      // Case values are signed, so sign-extend narrower integers.
      selector = rewriter.create<mlir::LLVM::SExtOp>(loc, i32Ty, selector);
    else
      selector = rewriter.create<LLVM::TruncOp>(loc, i32Ty, selector);
  }

  rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
      select, selector,
      /*defaultDestination=*/defaultDestination,
      /*defaultOperands=*/defaultOperands,
      /*caseValues=*/caseValues,
      /*caseDestinations=*/destinations,
      /*caseOperands=*/destinationsOperands,
      /*branchWeights=*/ArrayRef<int32_t>());
}
/// Lower `fir.select` to `llvm.switch` via selectMatchAndRewrite.
struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    fir::LLVMTypeConverter &converter = lowerTy();
    selectMatchAndRewrite<fir::SelectOp>(converter, op, adaptor, rewriter);
    return success();
  }
};
/// Lower `fir.select_rank` to `llvm.switch` via selectMatchAndRewrite.
struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    fir::LLVMTypeConverter &converter = lowerTy();
    selectMatchAndRewrite<fir::SelectRankOp>(converter, op, adaptor, rewriter);
    return success();
  }
};
/// Lower `fir::UndefOp` to an LLVM IR dialect `undef` of the converted type.
struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::UndefOp undef, OpAdaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type loweredTy = convertType(undef.getType());
    rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(undef, loweredTy);
    return success();
  }
};
/// Lower `fir::UnreachableOp` to the LLVM IR dialect `unreachable`
/// terminator.
struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // The op has no operands or results, so it maps one-to-one.
    rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach);
    return success();
  }
};
/// Lower `fir::ZeroOp` to a null pointer (`llvm.mlir.null`) or a zero
/// `llvm.mlir.constant` of the converted scalar type. Aggregate types are
/// not supported yet and make the pattern fail to match.
struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::ZeroOp zero, OpAdaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto ty = convertType(zero.getType());
    if (ty.isa<mlir::LLVM::LLVMPointerType>()) {
      rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty);
    } else if (ty.isa<mlir::IntegerType>()) {
      // Build the attribute on the converted type: the FIR type (e.g.
      // `index` or a FIR integer kind) either is not a valid IntegerAttr type
      // or does not match the constant's result type.
      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
          zero, ty, mlir::IntegerAttr::get(ty, 0));
    } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) {
      // Likewise, FloatAttr needs the converted builtin float type.
      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
          zero, ty, mlir::FloatAttr::get(ty, 0.0));
    } else {
      // TODO: create ConstantAggregateZero for FIR aggregate/array types.
      return rewriter.notifyMatchFailure(
          zero,
          "conversion of fir.zero with aggregate type not implemented yet");
    }
    return success();
  }
};
// Code shared between insert_value and extract_value Ops.
struct ValueOpCommon {
// Translate the arguments pertaining to any multidimensional array to
// row-major order for LLVM-IR.
static void toRowMajor(SmallVectorImpl<mlir::Attribute> &attrs,
mlir::Type ty) {
assert(ty && "type is null");
const auto end = attrs.size();
for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) {
if (auto seq = ty.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
const auto dim = getDimension(seq);
if (dim > 1) {
auto ub = std::min(i + dim, end);
std::reverse(attrs.begin() + i, attrs.begin() + ub);
i += dim - 1;
}
ty = getArrayElementType(seq);
} else if (auto st = ty.dyn_cast<mlir::LLVM::LLVMStructType>()) {
ty = st.getBody()[attrs[i].cast<mlir::IntegerAttr>().getInt()];
} else {
llvm_unreachable("index into invalid type");
}
}
}
static llvm::SmallVector<mlir::Attribute>
collectIndices(mlir::ConversionPatternRewriter &rewriter,
mlir::ArrayAttr arrAttr) {
llvm::SmallVector<mlir::Attribute> attrs;
for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) {
if (i->isa<mlir::IntegerAttr>()) {
attrs.push_back(*i);
} else {
auto fieldName = i->cast<mlir::StringAttr>().getValue();
++i;
auto ty = i->cast<mlir::TypeAttr>().getValue();
auto index = ty.cast<fir::RecordType>().getFieldIndex(fieldName);
attrs.push_back(mlir::IntegerAttr::get(rewriter.getI32Type(), index));
}
}
return attrs;
}
private:
static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
unsigned result = 1;
for (auto eleTy = ty.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>();
eleTy;
eleTy = eleTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>())
++result;
return result;
}
static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) {
auto eleTy = ty.getElementType();
while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
eleTy = arrTy.getElementType();
return eleTy;
}
};
/// Extract a subobject value from an ssa-value of aggregate type by lowering
/// `fir.extract_value` to `llvm.extractvalue`. The FIR coordinates are mapped
/// to integer positions and put in row-major order (see ValueOpCommon).
struct ExtractValueOpConversion
    : public FIROpAndTypeConversion<fir::ExtractValueOp>,
      public ValueOpCommon {
  using FIROpAndTypeConversion::FIROpAndTypeConversion;

  mlir::LogicalResult
  doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor,
            mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Value aggregate = adaptor.getOperands()[0];
    auto indices = collectIndices(rewriter, extractVal.coor());
    toRowMajor(indices, aggregate.getType());
    auto position = mlir::ArrayAttr::get(extractVal.getContext(), indices);
    rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(extractVal, ty,
                                                            aggregate, position);
    return success();
  }
};
/// InsertValue is the generalized instruction for the composition of new
/// aggregate type values. `fir.insert_value` becomes `llvm.insertvalue`, with
/// the FIR coordinates mapped to row-major integer positions
/// (see ValueOpCommon).
struct InsertValueOpConversion
    : public FIROpAndTypeConversion<fir::InsertValueOp>,
      public ValueOpCommon {
  using FIROpAndTypeConversion::FIROpAndTypeConversion;

  mlir::LogicalResult
  doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
            mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Value aggregate = adaptor.getOperands()[0];
    mlir::Value element = adaptor.getOperands()[1];
    auto indices = collectIndices(rewriter, insertVal.coor());
    toRowMajor(indices, aggregate.getType());
    auto position = mlir::ArrayAttr::get(insertVal.getContext(), indices);
    rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
        insertVal, ty, aggregate, element, position);
    return success();
  }
};
/// InsertOnRange inserts a value into a sequence over a range of offsets.
/// It is lowered to a chain of `llvm.insertvalue` ops, one for each position
/// from the lower to the upper coordinate (inclusive). Positions are visited
/// in row-major order over the converted LLVM array type.
struct InsertOnRangeOpConversion
    : public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
  using FIROpAndTypeConversion::FIROpAndTypeConversion;

  // Advance `subscripts` to the next row-major position. The last subscript
  // is bumped first and wraps to 0 at its extent in `dims`, carrying into
  // the subscripts before it.
  void incrementSubscripts(const SmallVector<uint64_t> &dims,
                           SmallVector<uint64_t> &subscripts) const {
    for (size_t d = dims.size(); d-- > 0;) {
      if (++subscripts[d] < dims[d])
        return;
      subscripts[d] = 0;
    }
  }

  mlir::LogicalResult
  doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
            mlir::ConversionPatternRewriter &rewriter) const override {
    // Extents of the nested LLVM array types, outermost first.
    llvm::SmallVector<uint64_t> dims;
    for (auto arrTy = adaptor.getOperands()[0]
                          .getType()
                          .dyn_cast<mlir::LLVM::LLVMArrayType>();
         arrTy;
         arrTy = arrTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>())
      dims.push_back(arrTy.getNumElements());

    // The coordinates are (lower, upper) pairs in FIR order. Walking the
    // pairs from the back yields row-major lower and upper bounds.
    SmallVector<int64_t> coordinates;
    for (Attribute a : range.coor())
      coordinates.push_back(a.cast<IntegerAttr>().getInt());
    SmallVector<uint64_t> lBounds;
    SmallVector<uint64_t> uBounds;
    for (size_t p = coordinates.size(); p >= 2; p -= 2) {
      uBounds.push_back(coordinates[p - 1]);
      lBounds.push_back(coordinates[p - 2]);
    }

    auto loc = range.getLoc();
    auto i64Ty = rewriter.getI64Type();
    // Build the llvm.insertvalue position attribute for `subs`.
    auto toPosition = [&](const SmallVector<uint64_t> &subs) {
      SmallVector<mlir::Attribute> attrs;
      for (uint64_t s : subs)
        attrs.push_back(IntegerAttr::get(i64Ty, s));
      return ArrayAttr::get(range.getContext(), attrs);
    };

    // Insert at every position before the upper bound. The final insertion,
    // at the upper bound itself, replaces `range`.
    mlir::Value aggregate = adaptor.getOperands()[0];
    mlir::Value element = adaptor.getOperands()[1];
    SmallVector<uint64_t> &subscripts = lBounds;
    for (; subscripts != uBounds; incrementSubscripts(dims, subscripts))
      aggregate = rewriter.create<mlir::LLVM::InsertValueOp>(
          loc, ty, aggregate, element, toPosition(subscripts));
    rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
        range, ty, aggregate, element, toPosition(subscripts));
    return success();
  }
};
} // namespace
namespace {
/// Convert FIR dialect to LLVM dialect
///
/// This pass lowers all FIR dialect operations to LLVM IR dialect. Residual
/// Std and Arithmetic dialect operations are lowered to LLVM IR dialect by
/// the upstream MLIR conversion patterns, which are registered together with
/// the FIR patterns.
///
/// This pass is not complete yet. We are upstreaming it in small patches.
class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
public:
  mlir::ModuleOp getModule() { return getOperation(); }

  void runOnOperation() override final {
    mlir::ModuleOp module = getModule();
    auto *context = module.getContext();
    fir::LLVMTypeConverter typeConverter{module};

    // FIR-specific patterns plus the upstream Std/Arithmetic ones.
    mlir::OwningRewritePatternList patterns(context);
    patterns.insert<
        AddrOfOpConversion, ExtractValueOpConversion, HasValueOpConversion,
        GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
        SelectOpConversion, SelectRankOpConversion, UndefOpConversion,
        UnreachableOpConversion, ZeroOpConversion>(typeConverter);
    mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
    mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                            patterns);

    // Everything must end in the LLVM dialect. The enclosing module op is
    // kept legal, since a full conversion needs it to stay in place.
    mlir::ConversionTarget target{*context};
    target.addLegalDialect<mlir::LLVM::LLVMDialect>();
    target.addLegalOp<mlir::ModuleOp>();

    auto result =
        mlir::applyFullConversion(module, target, std::move(patterns));
    if (mlir::failed(result))
      signalPassFailure();
  }
};
} // namespace
/// Create the pass that lowers FIR (and residual Std/Arithmetic operations)
/// to the LLVM IR dialect.
std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<FIRToLLVMLowering>();
  return pass;
}