This PR proposes to add `Stop` and `ErrorStop` PRIF call procedures to the MIF dialect. If the `-fcoarray` flag is passed, then all calls to `STOP` and `ERROR STOP` will use those of PRIF in flang-rt. These procedures are registered during initialization (mif::InitOp). --------- Co-authored-by: Dan Bonachea <dobonachea@lbl.gov>
841 lines
34 KiB
C++
841 lines
34 KiB
C++
//===-- MIFOpConversion.cpp -----------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "flang/Optimizer/Transforms/MIFOpConversion.h"
|
|
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
|
|
#include "flang/Optimizer/Builder/Todo.h"
|
|
#include "flang/Optimizer/CodeGen/TypeConverter.h"
|
|
#include "flang/Optimizer/Dialect/FIRDialect.h"
|
|
#include "flang/Optimizer/Dialect/FIROps.h"
|
|
#include "flang/Optimizer/Dialect/MIF/MIFOps.h"
|
|
#include "flang/Optimizer/HLFIR/HLFIROps.h"
|
|
#include "flang/Optimizer/Support/DataLayout.h"
|
|
#include "flang/Optimizer/Support/InternalNames.h"
|
|
#include "flang/Runtime/stop.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
namespace fir {
|
|
#define GEN_PASS_DEF_MIFOPCONVERSION
|
|
#include "flang/Optimizer/Transforms/Passes.h.inc"
|
|
} // namespace fir
|
|
|
|
using namespace mlir;
|
|
using namespace Fortran::runtime;
|
|
|
|
namespace {
|
|
|
|
// Default prefix for subroutines of PRIF compiled with LLVM
|
|
static std::string getPRIFProcName(std::string fmt) {
|
|
std::ostringstream oss;
|
|
oss << "prif_" << fmt;
|
|
return fir::NameUniquer::doProcedure({"prif"}, {}, oss.str());
|
|
}
|
|
|
|
/// Type of the optional `stat` argument of PRIF procedures: a reference to a
/// 32-bit integer.
static mlir::Type getPRIFStatType(fir::FirOpBuilder &builder) {
  mlir::Type statEleTy = builder.getI32Type();
  return builder.getRefType(statEleTy);
}
|
|
|
|
/// Type of the `errmsg` / `errmsg_alloc` arguments of PRIF procedures: a box
/// wrapping a kind-1 character of unknown length.
static mlir::Type getPRIFErrmsgType(fir::FirOpBuilder &builder) {
  mlir::MLIRContext *ctx = builder.getContext();
  mlir::Type charTy = fir::CharacterType::get(
      ctx, /*kind=*/1, fir::CharacterType::unknownLen());
  return fir::BoxType::get(charTy);
}
|
|
|
|
// Most PRIF functions take `errmsg` and `errmsg_alloc` as two optional
|
|
// arguments of intent (out). One is allocatable, the other is not.
|
|
// It is the responsibility of the compiler to ensure that the appropriate
|
|
// optional argument is passed, and at most one must be provided in a given
|
|
// call.
|
|
// Depending on the type of `errmsg`, this function will return the pair
|
|
// corresponding to (`errmsg`, `errmsg_alloc`).
|
|
static std::pair<mlir::Value, mlir::Value>
|
|
genErrmsgPRIF(fir::FirOpBuilder &builder, mlir::Location loc,
|
|
mlir::Value errmsg) {
|
|
mlir::Value absent =
|
|
fir::AbsentOp::create(builder, loc, getPRIFErrmsgType(builder));
|
|
if (!errmsg)
|
|
return {absent, absent};
|
|
|
|
bool isAllocatableErrmsg = fir::isAllocatableType(errmsg.getType());
|
|
mlir::Value errMsg = isAllocatableErrmsg ? absent : errmsg;
|
|
mlir::Value errMsgAlloc = isAllocatableErrmsg ? errmsg : absent;
|
|
return {errMsg, errMsgAlloc};
|
|
}
|
|
|
|
/// Return `stat` when the STAT= specifier is present; otherwise an absent
/// value of the PRIF stat type.
static mlir::Value genStatPRIF(fir::FirOpBuilder &builder, mlir::Location loc,
                               mlir::Value stat) {
  if (stat)
    return stat;
  return fir::AbsentOp::create(builder, loc, getPRIFStatType(builder));
}
|
|
|
|
/// Generate a call to the PRIF `stop` procedure, or `error_stop` when
/// \p isError is true.
///
/// The callee takes, in order: a reference to the i1 `quiet` flag, a
/// reference to an i32 integer stop code, and a kind-1 boxchar character stop
/// code. At most one stop-code argument is present; the other is passed as
/// absent.
///
/// \p stopCode may be:
///  - null: both stop-code arguments are absent;
///  - an integer: converted to i32 if needed and passed through a temporary;
///  - a character entity: passed as a boxchar, emboxed here if needed.
static fir::CallOp genPRIFStopErrorStop(fir::FirOpBuilder &builder,
                                        mlir::Location loc,
                                        mlir::Value stopCode,
                                        bool isError = false) {
  mlir::Type stopCharTy = fir::BoxCharType::get(builder.getContext(), 1);
  mlir::Type i1Ty = builder.getI1Type();
  mlir::Type i32Ty = builder.getI32Type();
  mlir::Type i32RefTy = builder.getRefType(i32Ty);

  mlir::FunctionType ftype = mlir::FunctionType::get(
      builder.getContext(),
      /*inputs*/ {builder.getRefType(i1Ty), i32RefTy, stopCharTy},
      /*results*/ {});
  mlir::func::FuncOp funcOp = builder.createFunction(
      loc, getPRIFProcName(isError ? "error_stop" : "stop"), ftype);

  // QUIET is managed in flang-rt, so its value is set to TRUE here.
  mlir::Value q = builder.createBool(loc, true);
  mlir::Value quiet = builder.createTemporary(loc, i1Ty);
  fir::StoreOp::create(builder, loc, q, quiet);

  mlir::Value stopCodeInt, stopCodeChar;
  if (!stopCode) {
    stopCodeChar = fir::AbsentOp::create(builder, loc, stopCharTy);
    stopCodeInt = fir::AbsentOp::create(builder, loc, i32RefTy);
  } else if (fir::isa_integer(stopCode.getType())) {
    stopCodeChar = fir::AbsentOp::create(builder, loc, stopCharTy);
    stopCodeInt = builder.createTemporary(loc, i32Ty);
    if (stopCode.getType() != i32Ty)
      stopCode = fir::ConvertOp::create(builder, loc, i32Ty, stopCode);
    fir::StoreOp::create(builder, loc, stopCode, stopCodeInt);
  } else {
    stopCodeChar = stopCode;
    if (!mlir::isa<fir::BoxCharType>(stopCodeChar.getType())) {
      // Emboxing requires an explicit length. Use the compile-time length of
      // the character type when it is known. An undefined length would be
      // read by the runtime as an arbitrary value.
      mlir::Type lenTy = builder.getCharacterLengthType();
      mlir::Value len;
      auto charTy = mlir::dyn_cast<fir::CharacterType>(
          fir::unwrapRefType(stopCodeChar.getType()));
      if (charTy && charTy.hasConstantLen()) {
        len = builder.createIntegerConstant(loc, lenTy, charTy.getLen());
      } else {
        // NOTE(review): the dynamic length is not recoverable here; callers
        // with dynamic-length stop codes should pass a fir.boxchar.
        len = fir::UndefOp::create(builder, loc, lenTy);
      }
      stopCodeChar =
          fir::EmboxCharOp::create(builder, loc, stopCharTy, stopCodeChar, len);
    }
    stopCodeInt = fir::AbsentOp::create(builder, loc, i32RefTy);
  }

  llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
      builder, loc, ftype, quiet, stopCodeInt, stopCodeChar);
  return fir::CallOp::create(builder, loc, funcOp, args);
}
|
|
|
|
/// Kinds of image termination for which a PRIF wrapper is registered with the
/// Fortran runtime.
enum class TerminationKind {
  Normal = 0,   ///< STOP
  Error = 1,    ///< ERROR STOP
  FailImage = 2 ///< FAIL IMAGE
};
|
|
// Generates a wrapper function for the different kind of termination in PRIF.
|
|
// This function will be used to register wrappers on PRIF runtime termination
|
|
// functions into the Fortran runtime.
|
|
mlir::Value genTerminationOperationWrapper(fir::FirOpBuilder &builder,
|
|
mlir::Location loc,
|
|
mlir::ModuleOp module,
|
|
TerminationKind termKind) {
|
|
std::string funcName;
|
|
mlir::FunctionType funcType =
|
|
mlir::FunctionType::get(builder.getContext(), {}, {});
|
|
mlir::Type i32Ty = builder.getI32Type();
|
|
if (termKind == TerminationKind::Normal) {
|
|
funcName = getPRIFProcName("stop");
|
|
funcType = mlir::FunctionType::get(builder.getContext(), {i32Ty}, {});
|
|
} else if (termKind == TerminationKind::Error) {
|
|
funcName = getPRIFProcName("error_stop");
|
|
funcType = mlir::FunctionType::get(builder.getContext(), {i32Ty}, {});
|
|
} else {
|
|
funcName = getPRIFProcName("fail_image");
|
|
}
|
|
funcName += "_termination_wrapper";
|
|
mlir::func::FuncOp funcWrapperOp =
|
|
module.lookupSymbol<mlir::func::FuncOp>(funcName);
|
|
|
|
if (!funcWrapperOp) {
|
|
funcWrapperOp = builder.createFunction(loc, funcName, funcType);
|
|
|
|
// generating the body of the function.
|
|
mlir::OpBuilder::InsertPoint saveInsertPoint = builder.saveInsertionPoint();
|
|
builder.setInsertionPointToStart(funcWrapperOp.addEntryBlock());
|
|
|
|
if (termKind == TerminationKind::Normal) {
|
|
genPRIFStopErrorStop(builder, loc, funcWrapperOp.getArgument(0),
|
|
/*isError*/ false);
|
|
} else if (termKind == TerminationKind::Error) {
|
|
genPRIFStopErrorStop(builder, loc, funcWrapperOp.getArgument(0),
|
|
/*isError*/ true);
|
|
} else {
|
|
mlir::func::FuncOp fOp = builder.createFunction(
|
|
loc, getPRIFProcName("fail_image"),
|
|
mlir::FunctionType::get(builder.getContext(), {}, {}));
|
|
fir::CallOp::create(builder, loc, fOp);
|
|
}
|
|
|
|
mlir::func::ReturnOp::create(builder, loc);
|
|
builder.restoreInsertionPoint(saveInsertPoint);
|
|
}
|
|
|
|
mlir::SymbolRefAttr symbolRef = mlir::SymbolRefAttr::get(
|
|
builder.getContext(), funcWrapperOp.getSymNameAttr());
|
|
return fir::AddrOfOp::create(builder, loc, funcType, symbolRef);
|
|
}
|
|
|
|
/// Convert mif.init operation to runtime call of 'prif_init'
|
|
struct MIFInitOpConversion : public mlir::OpRewritePattern<mif::InitOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::InitOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
mlir::Type i32Ty = builder.getI32Type();
|
|
mlir::Value result = builder.createTemporary(loc, i32Ty);
|
|
|
|
// Registering PRIF runtime termination to the Fortran runtime
|
|
// STOP
|
|
mlir::Value funcStopOp = genTerminationOperationWrapper(
|
|
builder, loc, mod, TerminationKind::Normal);
|
|
mlir::func::FuncOp normalEndFunc =
|
|
fir::runtime::getRuntimeFunc<mkRTKey(RegisterImagesNormalEndCallback)>(
|
|
loc, builder);
|
|
llvm::SmallVector<mlir::Value> args1 = fir::runtime::createArguments(
|
|
builder, loc, normalEndFunc.getFunctionType(), funcStopOp);
|
|
fir::CallOp::create(builder, loc, normalEndFunc, args1);
|
|
|
|
// ERROR STOP
|
|
mlir::Value funcErrorStopOp = genTerminationOperationWrapper(
|
|
builder, loc, mod, TerminationKind::Error);
|
|
mlir::func::FuncOp errorFunc =
|
|
fir::runtime::getRuntimeFunc<mkRTKey(RegisterImagesErrorCallback)>(
|
|
loc, builder);
|
|
llvm::SmallVector<mlir::Value> args2 = fir::runtime::createArguments(
|
|
builder, loc, errorFunc.getFunctionType(), funcErrorStopOp);
|
|
fir::CallOp::create(builder, loc, errorFunc, args2);
|
|
|
|
// FAIL IMAGE
|
|
mlir::Value failImageOp = genTerminationOperationWrapper(
|
|
builder, loc, mod, TerminationKind::FailImage);
|
|
mlir::func::FuncOp failImageFunc =
|
|
fir::runtime::getRuntimeFunc<mkRTKey(RegisterFailImageCallback)>(
|
|
loc, builder);
|
|
llvm::SmallVector<mlir::Value> args3 = fir::runtime::createArguments(
|
|
builder, loc, errorFunc.getFunctionType(), failImageOp);
|
|
fir::CallOp::create(builder, loc, failImageFunc, args3);
|
|
|
|
// Intialize the multi-image parallel environment
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/ {builder.getRefType(i32Ty)}, /*results*/ {});
|
|
mlir::func::FuncOp funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("init"), ftype);
|
|
llvm::SmallVector<mlir::Value> args =
|
|
fir::runtime::createArguments(builder, loc, ftype, result);
|
|
fir::CallOp::create(builder, loc, funcOp, args);
|
|
rewriter.replaceOpWithNewOp<fir::LoadOp>(op, result);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.this_image operation to PRIF runtime call
|
|
struct MIFThisImageOpConversion
|
|
: public mlir::OpRewritePattern<mif::ThisImageOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::ThisImageOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
if (op.getCoarray())
|
|
TODO(loc, "mif.this_image op with coarray argument.");
|
|
else {
|
|
mlir::Type i32Ty = builder.getI32Type();
|
|
mlir::Type boxTy = fir::BoxType::get(rewriter.getNoneType());
|
|
mlir::Value result = builder.createTemporary(loc, i32Ty);
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/ {boxTy, builder.getRefType(i32Ty)}, /*results*/ {});
|
|
mlir::Value teamArg = op.getTeam();
|
|
if (!op.getTeam())
|
|
teamArg = fir::AbsentOp::create(builder, loc, boxTy);
|
|
|
|
mlir::func::FuncOp funcOp = builder.createFunction(
|
|
loc, getPRIFProcName("this_image_no_coarray"), ftype);
|
|
llvm::SmallVector<mlir::Value> args =
|
|
fir::runtime::createArguments(builder, loc, ftype, teamArg, result);
|
|
fir::CallOp::create(builder, loc, funcOp, args);
|
|
rewriter.replaceOpWithNewOp<fir::LoadOp>(op, result);
|
|
return mlir::success();
|
|
}
|
|
}
|
|
};
|
|
|
|
/// Convert mif.num_images operation to runtime call of
|
|
/// prif_num_images_with_{team|team_number}
|
|
struct MIFNumImagesOpConversion
|
|
: public mlir::OpRewritePattern<mif::NumImagesOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::NumImagesOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
mlir::Type i32Ty = builder.getI32Type();
|
|
mlir::Type i64Ty = builder.getI64Type();
|
|
mlir::Type boxTy = fir::BoxType::get(rewriter.getNoneType());
|
|
mlir::Value result = builder.createTemporary(loc, i32Ty);
|
|
|
|
mlir::func::FuncOp funcOp;
|
|
llvm::SmallVector<mlir::Value> args;
|
|
if (!op.getTeam() && !op.getTeamNumber()) {
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/ {builder.getRefType(i32Ty)}, /*results*/ {});
|
|
funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("num_images"), ftype);
|
|
args = fir::runtime::createArguments(builder, loc, ftype, result);
|
|
} else {
|
|
if (op.getTeam()) {
|
|
mlir::FunctionType ftype =
|
|
mlir::FunctionType::get(builder.getContext(),
|
|
/*inputs*/
|
|
{boxTy, builder.getRefType(i32Ty)},
|
|
/*results*/ {});
|
|
funcOp = builder.createFunction(
|
|
loc, getPRIFProcName("num_images_with_team"), ftype);
|
|
args = fir::runtime::createArguments(builder, loc, ftype, op.getTeam(),
|
|
result);
|
|
} else {
|
|
mlir::Value teamNumber = builder.createTemporary(loc, i64Ty);
|
|
mlir::Value cst = op.getTeamNumber();
|
|
if (op.getTeamNumber().getType() != i64Ty)
|
|
cst = fir::ConvertOp::create(builder, loc, i64Ty, op.getTeamNumber());
|
|
fir::StoreOp::create(builder, loc, cst, teamNumber);
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/ {builder.getRefType(i64Ty), builder.getRefType(i32Ty)},
|
|
/*results*/ {});
|
|
funcOp = builder.createFunction(
|
|
loc, getPRIFProcName("num_images_with_team_number"), ftype);
|
|
args = fir::runtime::createArguments(builder, loc, ftype, teamNumber,
|
|
result);
|
|
}
|
|
}
|
|
fir::CallOp::create(builder, loc, funcOp, args);
|
|
rewriter.replaceOpWithNewOp<fir::LoadOp>(op, result);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.sync_all operation to runtime call of 'prif_sync_all'
|
|
struct MIFSyncAllOpConversion : public mlir::OpRewritePattern<mif::SyncAllOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::SyncAllOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
mlir::Type errmsgTy = getPRIFErrmsgType(builder);
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/ {getPRIFStatType(builder), errmsgTy, errmsgTy},
|
|
/*results*/ {});
|
|
mlir::func::FuncOp funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("sync_all"), ftype);
|
|
|
|
auto [errmsgArg, errmsgAllocArg] =
|
|
genErrmsgPRIF(builder, loc, op.getErrmsg());
|
|
mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
|
|
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
|
|
builder, loc, ftype, stat, errmsgArg, errmsgAllocArg);
|
|
rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.sync_images operation to runtime call of 'prif_sync_images'
|
|
struct MIFSyncImagesOpConversion
|
|
: public mlir::OpRewritePattern<mif::SyncImagesOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::SyncImagesOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
mlir::Type errmsgTy = getPRIFErrmsgType(builder);
|
|
mlir::Type imgSetTy = fir::BoxType::get(fir::SequenceType::get(
|
|
{fir::SequenceType::getUnknownExtent()}, builder.getI32Type()));
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/
|
|
{imgSetTy, getPRIFStatType(builder), errmsgTy, errmsgTy},
|
|
/*results*/ {});
|
|
mlir::func::FuncOp funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("sync_images"), ftype);
|
|
|
|
// If imageSet is scalar, PRIF require to pass an array of size 1.
|
|
mlir::Value imageSet = op.getImageSet();
|
|
if (!imageSet)
|
|
imageSet = fir::AbsentOp::create(builder, loc, imgSetTy);
|
|
else if (auto boxTy = mlir::dyn_cast<fir::BoxType>(imageSet.getType())) {
|
|
if (!mlir::isa<fir::SequenceType>(boxTy.getEleTy())) {
|
|
mlir::Value one =
|
|
builder.createIntegerConstant(loc, builder.getI32Type(), 1);
|
|
mlir::Value shape = fir::ShapeOp::create(builder, loc, one);
|
|
imageSet =
|
|
fir::ReboxOp::create(builder, loc,
|
|
fir::BoxType::get(fir::SequenceType::get(
|
|
{1}, builder.getI32Type())),
|
|
imageSet, shape, mlir::Value{});
|
|
}
|
|
}
|
|
auto [errmsgArg, errmsgAllocArg] =
|
|
genErrmsgPRIF(builder, loc, op.getErrmsg());
|
|
mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
|
|
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
|
|
builder, loc, ftype, imageSet, stat, errmsgArg, errmsgAllocArg);
|
|
rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.sync_memory operation to runtime call of 'prif_sync_memory'
|
|
struct MIFSyncMemoryOpConversion
|
|
: public mlir::OpRewritePattern<mif::SyncMemoryOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::SyncMemoryOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
mlir::Type errmsgTy = getPRIFErrmsgType(builder);
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/ {getPRIFStatType(builder), errmsgTy, errmsgTy},
|
|
/*results*/ {});
|
|
mlir::func::FuncOp funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("sync_memory"), ftype);
|
|
|
|
auto [errmsgArg, errmsgAllocArg] =
|
|
genErrmsgPRIF(builder, loc, op.getErrmsg());
|
|
mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
|
|
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
|
|
builder, loc, ftype, stat, errmsgArg, errmsgAllocArg);
|
|
rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.sync_team operation to runtime call of 'prif_sync_team'
|
|
struct MIFSyncTeamOpConversion
|
|
: public mlir::OpRewritePattern<mif::SyncTeamOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::SyncTeamOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
|
|
mlir::Type errmsgTy = getPRIFErrmsgType(builder);
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/ {boxTy, getPRIFStatType(builder), errmsgTy, errmsgTy},
|
|
/*results*/ {});
|
|
mlir::func::FuncOp funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("sync_team"), ftype);
|
|
|
|
auto [errmsgArg, errmsgAllocArg] =
|
|
genErrmsgPRIF(builder, loc, op.getErrmsg());
|
|
mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
|
|
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
|
|
builder, loc, ftype, op.getTeam(), stat, errmsgArg, errmsgAllocArg);
|
|
rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Generate call to collective subroutines except co_reduce
|
|
/// A must be lowered as a box
|
|
static fir::CallOp genCollectiveSubroutine(fir::FirOpBuilder &builder,
|
|
mlir::Location loc, mlir::Value A,
|
|
mlir::Value image, mlir::Value stat,
|
|
mlir::Value errmsg,
|
|
std::string coName) {
|
|
mlir::Value rootImage;
|
|
mlir::Type i32Ty = builder.getI32Type();
|
|
if (!image)
|
|
rootImage = fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty));
|
|
else {
|
|
rootImage = builder.createTemporary(loc, i32Ty);
|
|
if (image.getType() != i32Ty)
|
|
image = fir::ConvertOp::create(builder, loc, i32Ty, image);
|
|
fir::StoreOp::create(builder, loc, image, rootImage);
|
|
}
|
|
|
|
mlir::Type errmsgTy = getPRIFErrmsgType(builder);
|
|
mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
|
|
mlir::FunctionType ftype =
|
|
mlir::FunctionType::get(builder.getContext(),
|
|
/*inputs*/
|
|
{boxTy, builder.getRefType(builder.getI32Type()),
|
|
getPRIFStatType(builder), errmsgTy, errmsgTy},
|
|
/*results*/ {});
|
|
mlir::func::FuncOp funcOp = builder.createFunction(loc, coName, ftype);
|
|
|
|
auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, errmsg);
|
|
if (!stat)
|
|
stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder));
|
|
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
|
|
builder, loc, ftype, A, rootImage, stat, errmsgArg, errmsgAllocArg);
|
|
return fir::CallOp::create(builder, loc, funcOp, args);
|
|
}
|
|
|
|
/// Convert mif.co_broadcast operation to runtime call of 'prif_co_broadcast'
|
|
struct MIFCoBroadcastOpConversion
|
|
: public mlir::OpRewritePattern<mif::CoBroadcastOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::CoBroadcastOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
fir::CallOp callOp = genCollectiveSubroutine(
|
|
builder, loc, op.getA(), op.getSourceImage(), op.getStat(),
|
|
op.getErrmsg(), getPRIFProcName("co_broadcast"));
|
|
rewriter.replaceOp(op, callOp);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.co_max operation to runtime call of 'prif_co_max'
|
|
struct MIFCoMaxOpConversion : public mlir::OpRewritePattern<mif::CoMaxOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::CoMaxOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
fir::CallOp callOp;
|
|
mlir::Type argTy =
|
|
fir::unwrapSequenceType(fir::unwrapPassByRefType(op.getA().getType()));
|
|
if (mlir::isa<fir::CharacterType>(argTy))
|
|
callOp = genCollectiveSubroutine(
|
|
builder, loc, op.getA(), op.getResultImage(), op.getStat(),
|
|
op.getErrmsg(), getPRIFProcName("co_max_character"));
|
|
else
|
|
callOp = genCollectiveSubroutine(
|
|
builder, loc, op.getA(), op.getResultImage(), op.getStat(),
|
|
op.getErrmsg(), getPRIFProcName("co_max"));
|
|
rewriter.replaceOp(op, callOp);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.co_min operation to runtime call of 'prif_co_min'
|
|
struct MIFCoMinOpConversion : public mlir::OpRewritePattern<mif::CoMinOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::CoMinOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
fir::CallOp callOp;
|
|
mlir::Type argTy =
|
|
fir::unwrapSequenceType(fir::unwrapPassByRefType(op.getA().getType()));
|
|
if (mlir::isa<fir::CharacterType>(argTy))
|
|
callOp = genCollectiveSubroutine(
|
|
builder, loc, op.getA(), op.getResultImage(), op.getStat(),
|
|
op.getErrmsg(), getPRIFProcName("co_min_character"));
|
|
else
|
|
callOp = genCollectiveSubroutine(
|
|
builder, loc, op.getA(), op.getResultImage(), op.getStat(),
|
|
op.getErrmsg(), getPRIFProcName("co_min"));
|
|
rewriter.replaceOp(op, callOp);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.co_sum operation to runtime call of 'prif_co_sum'
|
|
struct MIFCoSumOpConversion : public mlir::OpRewritePattern<mif::CoSumOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::CoSumOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
fir::CallOp callOp = genCollectiveSubroutine(
|
|
builder, loc, op.getA(), op.getResultImage(), op.getStat(),
|
|
op.getErrmsg(), getPRIFProcName("co_sum"));
|
|
rewriter.replaceOp(op, callOp);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.form_team operation to runtime call of 'prif_form_team'
|
|
struct MIFFormTeamOpConversion
|
|
: public mlir::OpRewritePattern<mif::FormTeamOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::FormTeamOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
mlir::Type errmsgTy = getPRIFErrmsgType(builder);
|
|
mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/
|
|
{builder.getRefType(builder.getI64Type()), boxTy,
|
|
builder.getRefType(builder.getI32Type()), getPRIFStatType(builder),
|
|
errmsgTy, errmsgTy},
|
|
/*results*/ {});
|
|
mlir::func::FuncOp funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("form_team"), ftype);
|
|
|
|
mlir::Type i64Ty = builder.getI64Type();
|
|
mlir::Value teamNumber = builder.createTemporary(loc, i64Ty);
|
|
mlir::Value t =
|
|
(op.getTeamNumber().getType() == i64Ty)
|
|
? op.getTeamNumber()
|
|
: fir::ConvertOp::create(builder, loc, i64Ty, op.getTeamNumber());
|
|
fir::StoreOp::create(builder, loc, t, teamNumber);
|
|
|
|
mlir::Type i32Ty = builder.getI32Type();
|
|
mlir::Value newIndex;
|
|
if (op.getNewIndex()) {
|
|
newIndex = builder.createTemporary(loc, i32Ty);
|
|
mlir::Value ni =
|
|
(op.getNewIndex().getType() == i32Ty)
|
|
? op.getNewIndex()
|
|
: fir::ConvertOp::create(builder, loc, i32Ty, op.getNewIndex());
|
|
fir::StoreOp::create(builder, loc, ni, newIndex);
|
|
} else
|
|
newIndex = fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty));
|
|
|
|
mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
|
|
auto [errmsgArg, errmsgAllocArg] =
|
|
genErrmsgPRIF(builder, loc, op.getErrmsg());
|
|
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
|
|
builder, loc, ftype, teamNumber, op.getTeamVar(), newIndex, stat,
|
|
errmsgArg, errmsgAllocArg);
|
|
fir::CallOp callOp = fir::CallOp::create(builder, loc, funcOp, args);
|
|
rewriter.replaceOp(op, callOp);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.change_team operation to runtime call of 'prif_change_team'
|
|
struct MIFChangeTeamOpConversion
|
|
: public mlir::OpRewritePattern<mif::ChangeTeamOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::ChangeTeamOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
builder.setInsertionPoint(op);
|
|
|
|
mlir::Location loc = op.getLoc();
|
|
mlir::Type errmsgTy = getPRIFErrmsgType(builder);
|
|
mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/ {boxTy, getPRIFStatType(builder), errmsgTy, errmsgTy},
|
|
/*results*/ {});
|
|
mlir::func::FuncOp funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("change_team"), ftype);
|
|
|
|
mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
|
|
auto [errmsgArg, errmsgAllocArg] =
|
|
genErrmsgPRIF(builder, loc, op.getErrmsg());
|
|
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
|
|
builder, loc, ftype, op.getTeam(), stat, errmsgArg, errmsgAllocArg);
|
|
fir::CallOp::create(builder, loc, funcOp, args);
|
|
|
|
mlir::Operation *changeOp = op.getOperation();
|
|
auto &bodyRegion = op.getRegion();
|
|
mlir::Block &bodyBlock = bodyRegion.front();
|
|
|
|
rewriter.inlineBlockBefore(&bodyBlock, changeOp);
|
|
rewriter.eraseOp(op);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.end_team operation to runtime call of 'prif_end_team'
|
|
struct MIFEndTeamOpConversion : public mlir::OpRewritePattern<mif::EndTeamOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::EndTeamOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
mlir::Type errmsgTy = getPRIFErrmsgType(builder);
|
|
mlir::FunctionType ftype = mlir::FunctionType::get(
|
|
builder.getContext(),
|
|
/*inputs*/ {getPRIFStatType(builder), errmsgTy, errmsgTy},
|
|
/*results*/ {});
|
|
mlir::func::FuncOp funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("end_team"), ftype);
|
|
|
|
mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
|
|
auto [errmsgArg, errmsgAllocArg] =
|
|
genErrmsgPRIF(builder, loc, op.getErrmsg());
|
|
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
|
|
builder, loc, ftype, stat, errmsgArg, errmsgAllocArg);
|
|
fir::CallOp callOp = fir::CallOp::create(builder, loc, funcOp, args);
|
|
rewriter.replaceOp(op, callOp);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.get_team operation to runtime call of 'prif_get_team'
|
|
struct MIFGetTeamOpConversion : public mlir::OpRewritePattern<mif::GetTeamOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::GetTeamOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
|
|
mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
|
|
mlir::Type lvlTy = builder.getRefType(builder.getI32Type());
|
|
mlir::FunctionType ftype =
|
|
mlir::FunctionType::get(builder.getContext(),
|
|
/*inputs*/ {lvlTy, boxTy},
|
|
/*results*/ {});
|
|
mlir::func::FuncOp funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("get_team"), ftype);
|
|
|
|
mlir::Value level = op.getLevel();
|
|
if (!level)
|
|
level = fir::AbsentOp::create(builder, loc, lvlTy);
|
|
else {
|
|
mlir::Value cst = op.getLevel();
|
|
mlir::Type i32Ty = builder.getI32Type();
|
|
level = builder.createTemporary(loc, i32Ty);
|
|
if (cst.getType() != i32Ty)
|
|
cst = builder.createConvert(loc, i32Ty, cst);
|
|
fir::StoreOp::create(builder, loc, cst, level);
|
|
}
|
|
mlir::Type resultType = op.getResult().getType();
|
|
mlir::Type baseTy = fir::unwrapRefType(resultType);
|
|
mlir::Value team = builder.createTemporary(loc, baseTy);
|
|
fir::EmboxOp box = fir::EmboxOp::create(builder, loc, resultType, team);
|
|
|
|
llvm::SmallVector<mlir::Value> args =
|
|
fir::runtime::createArguments(builder, loc, ftype, level, box);
|
|
fir::CallOp::create(builder, loc, funcOp, args);
|
|
|
|
rewriter.replaceOp(op, box);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Convert mif.team_number operation to runtime call of 'prif_team_number'
|
|
struct MIFTeamNumberOpConversion
|
|
: public mlir::OpRewritePattern<mif::TeamNumberOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mif::TeamNumberOp op,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto mod = op->template getParentOfType<mlir::ModuleOp>();
|
|
fir::FirOpBuilder builder(rewriter, mod);
|
|
mlir::Location loc = op.getLoc();
|
|
mlir::Type i64Ty = builder.getI64Type();
|
|
mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
|
|
mlir::FunctionType ftype =
|
|
mlir::FunctionType::get(builder.getContext(),
|
|
/*inputs*/ {boxTy, builder.getRefType(i64Ty)},
|
|
/*results*/ {});
|
|
mlir::func::FuncOp funcOp =
|
|
builder.createFunction(loc, getPRIFProcName("team_number"), ftype);
|
|
|
|
mlir::Value team = op.getTeam();
|
|
if (!team)
|
|
team = fir::AbsentOp::create(builder, loc, boxTy);
|
|
|
|
mlir::Value result = builder.createTemporary(loc, i64Ty);
|
|
llvm::SmallVector<mlir::Value> args =
|
|
fir::runtime::createArguments(builder, loc, ftype, team, result);
|
|
fir::CallOp::create(builder, loc, funcOp, args);
|
|
fir::LoadOp load = fir::LoadOp::create(builder, loc, result);
|
|
rewriter.replaceOp(op, load);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
/// Pass that lowers MIF operations to PRIF (and Fortran runtime) calls using
/// a partial dialect conversion in which FIR operations and the builtin
/// module are legal.
class MIFOpConversion : public fir::impl::MIFOpConversionBase<MIFOpConversion> {
public:
  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();

    mlir::ConversionTarget target(*ctx);
    target.addLegalDialect<fir::FIROpsDialect>();
    target.addLegalOp<mlir::ModuleOp>();

    mlir::RewritePatternSet patterns(ctx);
    mif::populateMIFOpConversionPatterns(patterns);

    mlir::LogicalResult status = mlir::applyPartialConversion(
        getOperation(), target, std::move(patterns));
    if (mlir::succeeded(status))
      return;

    mlir::emitError(mlir::UnknownLoc::get(ctx),
                    "error in MIF op conversion\n");
    signalPassFailure();
  }
};
|
|
} // namespace
|
|
|
|
/// Add every MIF-to-PRIF rewrite pattern of this file to \p patterns, in a
/// fixed order.
void mif::populateMIFOpConversionPatterns(mlir::RewritePatternSet &patterns) {
  mlir::MLIRContext *ctx = patterns.getContext();
  // Initialization and image queries.
  patterns.insert<MIFInitOpConversion, MIFThisImageOpConversion,
                  MIFNumImagesOpConversion>(ctx);
  // Image control / synchronization.
  patterns.insert<MIFSyncAllOpConversion, MIFSyncImagesOpConversion,
                  MIFSyncMemoryOpConversion, MIFSyncTeamOpConversion>(ctx);
  // Collective subroutines.
  patterns.insert<MIFCoBroadcastOpConversion, MIFCoMaxOpConversion,
                  MIFCoMinOpConversion, MIFCoSumOpConversion>(ctx);
  // Teams.
  patterns.insert<MIFFormTeamOpConversion, MIFChangeTeamOpConversion,
                  MIFEndTeamOpConversion, MIFGetTeamOpConversion,
                  MIFTeamNumberOpConversion>(ctx);
}
|