//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
|
|
|
|
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
|
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
|
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
|
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
/// A pattern that converts the result and operand types, attributes, and region
|
|
/// arguments of an OpenMP operation to the LLVM dialect.
|
|
///
|
|
/// Attributes are copied verbatim by default, and only translated if they are
|
|
/// type attributes.
|
|
///
|
|
/// Region bodies, if any, are not modified and expected to either be processed
|
|
/// by the conversion infrastructure or already contain ops compatible with LLVM
|
|
/// dialect types.
|
|
template <typename T>
|
|
struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
|
|
using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
|
|
|
|
OpenMPOpConversion(LLVMTypeConverter &typeConverter,
|
|
PatternBenefit benefit = 1)
|
|
: ConvertOpToLLVMPattern<T>(typeConverter, benefit) {
|
|
// Operations using CanonicalLoopInfoType are lowered only by
|
|
// mlir::translateModuleToLLVMIR() using the OpenMPIRBuilder. Until then,
|
|
// the type and operations using it must be preserved.
|
|
typeConverter.addConversion(
|
|
[&](::mlir::omp::CanonicalLoopInfoType type) { return type; });
|
|
}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(T op, typename T::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// Translate result types.
|
|
const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
|
|
SmallVector<Type> resTypes;
|
|
if (failed(converter->convertTypes(op->getResultTypes(), resTypes)))
|
|
return failure();
|
|
|
|
// Translate type attributes.
|
|
// They are kept unmodified except if they are type attributes.
|
|
SmallVector<NamedAttribute> convertedAttrs;
|
|
for (NamedAttribute attr : op->getAttrs()) {
|
|
if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
|
|
Type convertedType = converter->convertType(typeAttr.getValue());
|
|
convertedAttrs.emplace_back(attr.getName(),
|
|
TypeAttr::get(convertedType));
|
|
} else {
|
|
convertedAttrs.push_back(attr);
|
|
}
|
|
}
|
|
|
|
// Translate operands.
|
|
SmallVector<Value> convertedOperands;
|
|
convertedOperands.reserve(op->getNumOperands());
|
|
for (auto [originalOperand, convertedOperand] :
|
|
llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
|
|
if (!originalOperand)
|
|
return failure();
|
|
|
|
// TODO: Revisit whether we need to trigger an error specifically for this
|
|
// set of operations. Consider removing this check or updating the list.
|
|
if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
|
|
omp::FlushOp, omp::MapBoundsOp,
|
|
omp::ThreadprivateOp>::value) {
|
|
if (isa<MemRefType>(originalOperand.getType())) {
|
|
// TODO: Support memref type in variable operands
|
|
return rewriter.notifyMatchFailure(op, "memref is not supported yet");
|
|
}
|
|
}
|
|
convertedOperands.push_back(convertedOperand);
|
|
}
|
|
|
|
// Create new operation.
|
|
auto newOp = T::create(rewriter, op.getLoc(), resTypes, convertedOperands,
|
|
convertedAttrs);
|
|
|
|
// Translate regions.
|
|
for (auto [originalRegion, convertedRegion] :
|
|
llvm::zip_equal(op->getRegions(), newOp->getRegions())) {
|
|
rewriter.inlineRegionBefore(originalRegion, convertedRegion,
|
|
convertedRegion.end());
|
|
if (failed(rewriter.convertRegionTypes(&convertedRegion,
|
|
*this->getTypeConverter())))
|
|
return failure();
|
|
}
|
|
|
|
// Delete old operation and replace result uses with those of the new one.
|
|
rewriter.replaceOp(op, newOp->getResults());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Mark every OpenMP operation as dynamically legal. An op counts as legal
/// once all of the following are legal according to `typeConverter`:
///   - its operand types,
///   - its result types,
///   - its regions,
///   - the type wrapped by each of its `TypeAttr` attributes.
///
/// The legality callback captures `typeConverter` by reference and is handed
/// to `target`. The converter must therefore outlive any conversion run with
/// `target`.
void mlir::configureOpenMPToLLVMConversionLegality(
    ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
  target.addDynamicallyLegalOp<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >([&](Operation *op) {
    return typeConverter.isLegal(op->getOperandTypes()) &&
           typeConverter.isLegal(op->getResultTypes()) &&
           llvm::all_of(op->getRegions(),
                        [&](Region &region) {
                          return typeConverter.isLegal(&region);
                        }) &&
           llvm::all_of(op->getAttrs(), [&](NamedAttribute attr) {
             auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
             return !typeAttr || typeConverter.isLegal(typeAttr.getValue());
           });
  });
}
|
|
|
|
/// Add an `OpenMPOpConversion<T>` conversion pattern for each operation type
|
|
/// passed as template argument.
|
|
template <typename... Ts>
|
|
static inline RewritePatternSet &
|
|
addOpenMPOpConversions(LLVMTypeConverter &converter,
|
|
RewritePatternSet &patterns) {
|
|
return patterns.add<OpenMPOpConversion<Ts>...>(converter);
|
|
}
|
|
|
|
/// Add to `patterns` a generic conversion pattern for every OpenMP operation.
/// Also register in `converter` the identity conversions that OpenMP-specific
/// types need.
void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // `omp::MapBoundsType` carries bounds information for map clauses. It is
  // legal in the LLVM dialect output and maps onto itself. The type and its
  // operation are only dropped later, when the OpenMP dialect is translated
  // to LLVM IR.
  converter.addConversion(
      [](omp::MapBoundsType boundsType) -> Type { return boundsType; });

  // One generic OpenMPOpConversion pattern per OpenMP operation.
  addOpenMPOpConversions<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >(converter, patterns);
}
|
|
|
|
namespace {
|
|
struct ConvertOpenMPToLLVMPass
|
|
: public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
|
|
using Base::Base;
|
|
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void ConvertOpenMPToLLVMPass::runOnOperation() {
|
|
auto module = getOperation();
|
|
|
|
// Convert to OpenMP operations with LLVM IR dialect
|
|
RewritePatternSet patterns(&getContext());
|
|
LLVMTypeConverter converter(&getContext());
|
|
arith::populateArithToLLVMConversionPatterns(converter, patterns);
|
|
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
|
|
cf::populateAssertToLLVMConversionPattern(converter, patterns);
|
|
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
|
|
populateFuncToLLVMConversionPatterns(converter, patterns);
|
|
populateOpenMPToLLVMConversionPatterns(converter, patterns);
|
|
|
|
LLVMConversionTarget target(getContext());
|
|
target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
|
|
omp::TaskyieldOp, omp::TerminatorOp>();
|
|
configureOpenMPToLLVMConversionLegality(target, converter);
|
|
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConvertToLLVMPatternInterface implementation
|
|
//===----------------------------------------------------------------------===//
|
|
namespace {
|
|
/// Implement the interface to convert OpenMP to LLVM.
|
|
struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
|
|
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
|
|
void loadDependentDialects(MLIRContext *context) const final {
|
|
context->loadDialect<LLVM::LLVMDialect>();
|
|
}
|
|
|
|
/// Hook for derived dialect interface to provide conversion patterns
|
|
/// and mark dialect legal for the conversion target.
|
|
void populateConvertToLLVMConversionPatterns(
|
|
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns) const final {
|
|
configureOpenMPToLLVMConversionLegality(target, typeConverter);
|
|
populateOpenMPToLLVMConversionPatterns(typeConverter, patterns);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
void mlir::registerConvertOpenMPToLLVMInterface(DialectRegistry ®istry) {
|
|
registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
|
|
dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
|
|
});
|
|
}
|