Alex Zinenko 684dfe8adb [mlir] factor out ConvertToLLVMPattern
This class and classes that extend it are general utilities for any dialect
that is being converted into the LLVM dialect. They are in no way specific to
Standard-to-LLVM conversion and should not make their users depend on it.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D105542
2021-07-08 10:27:05 +02:00

80 lines
3.1 KiB
C++

//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
using namespace mlir;
namespace {
/// Conversion pattern for OpenMP operations that carry exactly one region.
///
/// The operation is re-created with the operands supplied by the conversion
/// framework and with its original attributes, the region is moved over to
/// the new operation, and the types in that region are then converted with
/// the pattern's type converter. Operations inside the region body are left
/// alone: they are expected to be handled by other patterns or to already be
/// compatible with LLVM dialect types.
template <typename OpType>
struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
  using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;

  /// Replaces `curOp` with an equivalent operation built from `operands`.
  /// Returns failure if the region types cannot be converted.
  LogicalResult
  matchAndRewrite(OpType curOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Build the replacement without results, reusing the original attributes.
    auto convertedOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(),
                                               operands, curOp->getAttrs());

    // Move the body of the original operation to the end of the new region.
    auto &convertedRegion = convertedOp.region();
    rewriter.inlineRegionBefore(curOp.region(), convertedRegion,
                                convertedRegion.end());

    // Only drop the original operation once the region types converted.
    if (succeeded(rewriter.convertRegionTypes(&convertedRegion,
                                              *this->getTypeConverter()))) {
      rewriter.eraseOp(curOp);
      return success();
    }
    return failure();
  }
};
} // namespace
/// Adds to `patterns` the region-converting patterns for the OpenMP
/// operations handled by this file, each constructed with `converter`.
void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // Registered one at a time, in the same order as before.
  patterns.add<RegionOpConversion<omp::ParallelOp>>(converter);
  patterns.add<RegionOpConversion<omp::WsLoopOp>>(converter);
}
namespace {
/// Pass that applies the Standard-to-LLVM and OpenMP-to-LLVM conversion
/// patterns to the operation it runs on. It is exposed to users only through
/// createConvertOpenMPToLLVMPass(), which returns it as an
/// OperationPass<ModuleOp>.
// NOTE(review): ConvertOpenMPToLLVMBase is presumably the generated pass base
// pulled in via PassDetail.h - confirm.
struct ConvertOpenMPToLLVMPass
: public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
/// Runs the partial conversion; defined out of line below.
void runOnOperation() override;
};
} // namespace
void ConvertOpenMPToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to OpenMP operations with LLVM IR dialect
RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
[&](Operation *op) { return converter.isLegal(&op->getRegion(0)); });
target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
omp::BarrierOp, omp::TaskwaitOp>();
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
/// Creates a new instance of the OpenMP-to-LLVM conversion pass, returned as
/// a pass operating on ModuleOp.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
return std::make_unique<ConvertOpenMPToLLVMPass>();
}