
* inplace allreduce needs special MPI token MPI_IN_PLACE as send buffer * 0d tensors have no sizes/strides in LLVM memref struct
812 lines
30 KiB
C++
812 lines
30 KiB
C++
//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
// Copyright (C) by Argonne National Laboratory
|
|
// See COPYRIGHT in top-level directory
|
|
// of MPICH source repository.
|
|
//
|
|
|
|
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
|
|
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/DLTI/DLTI.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
|
#include "mlir/Dialect/MPI/IR/MPI.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include <memory>
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
template <typename Op, typename... Args>
|
|
static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
|
|
ConversionPatternRewriter &rewriter, StringRef name,
|
|
Args &&...args) {
|
|
Op ret;
|
|
if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
|
|
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
|
ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
|
|
const Location loc,
|
|
ConversionPatternRewriter &rewriter,
|
|
StringRef name,
|
|
LLVM::LLVMFunctionType type) {
|
|
return getOrDefineGlobal<LLVM::LLVMFuncOp>(
|
|
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
|
|
}
|
|
|
|
std::pair<Value, Value> getRawPtrAndSize(const Location loc,
|
|
ConversionPatternRewriter &rewriter,
|
|
Value memRef, Type elType) {
|
|
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
|
|
Value dataPtr =
|
|
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
|
|
Value offset = rewriter.create<LLVM::ExtractValueOp>(
|
|
loc, rewriter.getI64Type(), memRef, 2);
|
|
Value resPtr =
|
|
rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
|
|
Value size;
|
|
if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
|
|
size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
|
|
ArrayRef<int64_t>{3, 0});
|
|
size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
|
|
} else {
|
|
size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32);
|
|
}
|
|
return {resPtr, size};
|
|
}
|
|
|
|
/// When lowering the mpi dialect to functions calls certain details
|
|
/// differ between various MPI implementations. This class will provide
|
|
/// these in a generic way, depending on the MPI implementation that got
|
|
/// selected by the DLTI attribute on the module.
|
|
class MPIImplTraits {
|
|
ModuleOp &moduleOp;
|
|
|
|
public:
|
|
/// Instantiate a new MPIImplTraits object according to the DLTI attribute
|
|
/// on the given module. Default to MPICH if no attribute is present or
|
|
/// the value is unknown.
|
|
static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
|
|
|
|
explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
|
|
|
|
virtual ~MPIImplTraits() = default;
|
|
|
|
ModuleOp &getModuleOp() { return moduleOp; }
|
|
|
|
/// Gets or creates MPI_COMM_WORLD as a Value.
|
|
/// Different MPI implementations have different communicator types.
|
|
/// Using i64 as a portable, intermediate type.
|
|
/// Appropriate cast needs to take place before calling MPI functions.
|
|
virtual Value getCommWorld(const Location loc,
|
|
ConversionPatternRewriter &rewriter) = 0;
|
|
|
|
/// Type converter provides i64 type for communicator type.
|
|
/// Converts to native type, which might be ptr or int or whatever.
|
|
virtual Value castComm(const Location loc,
|
|
ConversionPatternRewriter &rewriter, Value comm) = 0;
|
|
|
|
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
|
|
virtual intptr_t getStatusIgnore() = 0;
|
|
|
|
/// Get the MPI_IN_PLACE value (void *).
|
|
virtual void *getInPlace() = 0;
|
|
|
|
/// Gets or creates an MPI datatype as a value which corresponds to the given
|
|
/// type.
|
|
virtual Value getDataType(const Location loc,
|
|
ConversionPatternRewriter &rewriter, Type type) = 0;
|
|
|
|
/// Gets or creates an MPI_Op value which corresponds to the given
|
|
/// enum value.
|
|
virtual Value getMPIOp(const Location loc,
|
|
ConversionPatternRewriter &rewriter,
|
|
mpi::MPI_OpClassEnum opAttr) = 0;
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Implementation details for MPICH ABI compatible MPI implementations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
class MPICHImplTraits : public MPIImplTraits {
|
|
static constexpr int MPI_FLOAT = 0x4c00040a;
|
|
static constexpr int MPI_DOUBLE = 0x4c00080b;
|
|
static constexpr int MPI_INT8_T = 0x4c000137;
|
|
static constexpr int MPI_INT16_T = 0x4c000238;
|
|
static constexpr int MPI_INT32_T = 0x4c000439;
|
|
static constexpr int MPI_INT64_T = 0x4c00083a;
|
|
static constexpr int MPI_UINT8_T = 0x4c00013b;
|
|
static constexpr int MPI_UINT16_T = 0x4c00023c;
|
|
static constexpr int MPI_UINT32_T = 0x4c00043d;
|
|
static constexpr int MPI_UINT64_T = 0x4c00083e;
|
|
static constexpr int MPI_MAX = 0x58000001;
|
|
static constexpr int MPI_MIN = 0x58000002;
|
|
static constexpr int MPI_SUM = 0x58000003;
|
|
static constexpr int MPI_PROD = 0x58000004;
|
|
static constexpr int MPI_LAND = 0x58000005;
|
|
static constexpr int MPI_BAND = 0x58000006;
|
|
static constexpr int MPI_LOR = 0x58000007;
|
|
static constexpr int MPI_BOR = 0x58000008;
|
|
static constexpr int MPI_LXOR = 0x58000009;
|
|
static constexpr int MPI_BXOR = 0x5800000a;
|
|
static constexpr int MPI_MINLOC = 0x5800000b;
|
|
static constexpr int MPI_MAXLOC = 0x5800000c;
|
|
static constexpr int MPI_REPLACE = 0x5800000d;
|
|
static constexpr int MPI_NO_OP = 0x5800000e;
|
|
|
|
public:
|
|
using MPIImplTraits::MPIImplTraits;
|
|
|
|
~MPICHImplTraits() override = default;
|
|
|
|
Value getCommWorld(const Location loc,
|
|
ConversionPatternRewriter &rewriter) override {
|
|
static constexpr int MPI_COMM_WORLD = 0x44000000;
|
|
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
|
|
MPI_COMM_WORLD);
|
|
}
|
|
|
|
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
|
|
Value comm) override {
|
|
return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm);
|
|
}
|
|
|
|
intptr_t getStatusIgnore() override { return 1; }
|
|
|
|
void *getInPlace() override { return reinterpret_cast<void *>(-1); }
|
|
|
|
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
|
|
Type type) override {
|
|
int32_t mtype = 0;
|
|
if (type.isF32())
|
|
mtype = MPI_FLOAT;
|
|
else if (type.isF64())
|
|
mtype = MPI_DOUBLE;
|
|
else if (type.isInteger(64) && !type.isUnsignedInteger())
|
|
mtype = MPI_INT64_T;
|
|
else if (type.isInteger(64))
|
|
mtype = MPI_UINT64_T;
|
|
else if (type.isInteger(32) && !type.isUnsignedInteger())
|
|
mtype = MPI_INT32_T;
|
|
else if (type.isInteger(32))
|
|
mtype = MPI_UINT32_T;
|
|
else if (type.isInteger(16) && !type.isUnsignedInteger())
|
|
mtype = MPI_INT16_T;
|
|
else if (type.isInteger(16))
|
|
mtype = MPI_UINT16_T;
|
|
else if (type.isInteger(8) && !type.isUnsignedInteger())
|
|
mtype = MPI_INT8_T;
|
|
else if (type.isInteger(8))
|
|
mtype = MPI_UINT8_T;
|
|
else
|
|
assert(false && "unsupported type");
|
|
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
|
|
}
|
|
|
|
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
|
|
mpi::MPI_OpClassEnum opAttr) override {
|
|
int32_t op = MPI_NO_OP;
|
|
switch (opAttr) {
|
|
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
|
|
op = MPI_NO_OP;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_MAX:
|
|
op = MPI_MAX;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_MIN:
|
|
op = MPI_MIN;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_SUM:
|
|
op = MPI_SUM;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_PROD:
|
|
op = MPI_PROD;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_LAND:
|
|
op = MPI_LAND;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_BAND:
|
|
op = MPI_BAND;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_LOR:
|
|
op = MPI_LOR;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_BOR:
|
|
op = MPI_BOR;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_LXOR:
|
|
op = MPI_LXOR;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_BXOR:
|
|
op = MPI_BXOR;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_MINLOC:
|
|
op = MPI_MINLOC;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
|
|
op = MPI_MAXLOC;
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_REPLACE:
|
|
op = MPI_REPLACE;
|
|
break;
|
|
}
|
|
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Implementation details for OpenMPI
|
|
//===----------------------------------------------------------------------===//
|
|
class OMPIImplTraits : public MPIImplTraits {
|
|
LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
|
|
ConversionPatternRewriter &rewriter,
|
|
StringRef name,
|
|
LLVM::LLVMStructType type) {
|
|
|
|
return getOrDefineGlobal<LLVM::GlobalOp>(
|
|
getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
|
|
LLVM::Linkage::External, name,
|
|
/*value=*/Attribute(), /*alignment=*/0, 0);
|
|
}
|
|
|
|
public:
|
|
using MPIImplTraits::MPIImplTraits;
|
|
|
|
~OMPIImplTraits() override = default;
|
|
|
|
Value getCommWorld(const Location loc,
|
|
ConversionPatternRewriter &rewriter) override {
|
|
auto context = rewriter.getContext();
|
|
// get external opaque struct pointer type
|
|
auto commStructT =
|
|
LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
|
|
StringRef name = "ompi_mpi_comm_world";
|
|
|
|
// make sure global op definition exists
|
|
getOrDefineExternalStruct(loc, rewriter, name, commStructT);
|
|
|
|
// get address of symbol
|
|
auto comm = rewriter.create<LLVM::AddressOfOp>(
|
|
loc, LLVM::LLVMPointerType::get(context),
|
|
SymbolRefAttr::get(context, name));
|
|
return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm);
|
|
}
|
|
|
|
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
|
|
Value comm) override {
|
|
return rewriter.create<LLVM::IntToPtrOp>(
|
|
loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
|
|
}
|
|
|
|
intptr_t getStatusIgnore() override { return 0; }
|
|
|
|
void *getInPlace() override { return reinterpret_cast<void *>(1); }
|
|
|
|
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
|
|
Type type) override {
|
|
StringRef mtype;
|
|
if (type.isF32())
|
|
mtype = "ompi_mpi_float";
|
|
else if (type.isF64())
|
|
mtype = "ompi_mpi_double";
|
|
else if (type.isInteger(64) && !type.isUnsignedInteger())
|
|
mtype = "ompi_mpi_int64_t";
|
|
else if (type.isInteger(64))
|
|
mtype = "ompi_mpi_uint64_t";
|
|
else if (type.isInteger(32) && !type.isUnsignedInteger())
|
|
mtype = "ompi_mpi_int32_t";
|
|
else if (type.isInteger(32))
|
|
mtype = "ompi_mpi_uint32_t";
|
|
else if (type.isInteger(16) && !type.isUnsignedInteger())
|
|
mtype = "ompi_mpi_int16_t";
|
|
else if (type.isInteger(16))
|
|
mtype = "ompi_mpi_uint16_t";
|
|
else if (type.isInteger(8) && !type.isUnsignedInteger())
|
|
mtype = "ompi_mpi_int8_t";
|
|
else if (type.isInteger(8))
|
|
mtype = "ompi_mpi_uint8_t";
|
|
else
|
|
assert(false && "unsupported type");
|
|
|
|
auto context = rewriter.getContext();
|
|
// get external opaque struct pointer type
|
|
auto typeStructT =
|
|
LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
|
|
// make sure global op definition exists
|
|
getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
|
|
// get address of symbol
|
|
return rewriter.create<LLVM::AddressOfOp>(
|
|
loc, LLVM::LLVMPointerType::get(context),
|
|
SymbolRefAttr::get(context, mtype));
|
|
}
|
|
|
|
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
|
|
mpi::MPI_OpClassEnum opAttr) override {
|
|
StringRef op;
|
|
switch (opAttr) {
|
|
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
|
|
op = "ompi_mpi_no_op";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_MAX:
|
|
op = "ompi_mpi_max";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_MIN:
|
|
op = "ompi_mpi_min";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_SUM:
|
|
op = "ompi_mpi_sum";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_PROD:
|
|
op = "ompi_mpi_prod";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_LAND:
|
|
op = "ompi_mpi_land";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_BAND:
|
|
op = "ompi_mpi_band";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_LOR:
|
|
op = "ompi_mpi_lor";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_BOR:
|
|
op = "ompi_mpi_bor";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_LXOR:
|
|
op = "ompi_mpi_lxor";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_BXOR:
|
|
op = "ompi_mpi_bxor";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_MINLOC:
|
|
op = "ompi_mpi_minloc";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
|
|
op = "ompi_mpi_maxloc";
|
|
break;
|
|
case mpi::MPI_OpClassEnum::MPI_REPLACE:
|
|
op = "ompi_mpi_replace";
|
|
break;
|
|
}
|
|
auto context = rewriter.getContext();
|
|
// get external opaque struct pointer type
|
|
auto opStructT =
|
|
LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
|
|
// make sure global op definition exists
|
|
getOrDefineExternalStruct(loc, rewriter, op, opStructT);
|
|
// get address of symbol
|
|
return rewriter.create<LLVM::AddressOfOp>(
|
|
loc, LLVM::LLVMPointerType::get(context),
|
|
SymbolRefAttr::get(context, op));
|
|
}
|
|
};
|
|
|
|
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
|
|
auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
|
|
if (failed(attr))
|
|
return std::make_unique<MPICHImplTraits>(moduleOp);
|
|
auto strAttr = dyn_cast<StringAttr>(attr.value());
|
|
if (strAttr && strAttr.getValue() == "OpenMPI")
|
|
return std::make_unique<OMPIImplTraits>(moduleOp);
|
|
if (!strAttr || strAttr.getValue() != "MPICH")
|
|
moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
|
|
<< strAttr.getValue() << "), defaulting to MPICH";
|
|
return std::make_unique<MPICHImplTraits>(moduleOp);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// InitOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
|
|
// ptrType `!llvm.ptr`
|
|
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
|
|
|
|
// instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
|
|
auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
|
|
Value llvmnull = nullPtrOp.getRes();
|
|
|
|
// grab a reference to the global module op:
|
|
auto moduleOp = op->getParentOfType<ModuleOp>();
|
|
|
|
// LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
|
|
auto initFuncType =
|
|
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
|
|
// get or create function declaration:
|
|
LLVM::LLVMFuncOp initDecl =
|
|
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
|
|
|
|
// replace init with function call
|
|
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
|
|
ValueRange{llvmnull, llvmnull});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FinalizeOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// get loc
|
|
Location loc = op.getLoc();
|
|
|
|
// grab a reference to the global module op:
|
|
auto moduleOp = op->getParentOfType<ModuleOp>();
|
|
|
|
// LLVM Function type representing `i32 MPI_Finalize()`
|
|
auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
|
|
// get or create function declaration:
|
|
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
|
|
moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
|
|
|
|
// replace init with function call
|
|
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CommWorldOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// grab a reference to the global module op:
|
|
auto moduleOp = op->getParentOfType<ModuleOp>();
|
|
auto mpiTraits = MPIImplTraits::get(moduleOp);
|
|
// get MPI_COMM_WORLD
|
|
rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CommSplitOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// grab a reference to the global module op:
|
|
auto moduleOp = op->getParentOfType<ModuleOp>();
|
|
auto mpiTraits = MPIImplTraits::get(moduleOp);
|
|
Type i32 = rewriter.getI32Type();
|
|
Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
|
|
Location loc = op.getLoc();
|
|
|
|
// get communicator
|
|
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
|
|
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
|
|
auto outPtr =
|
|
rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one);
|
|
|
|
// int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
|
|
auto funcType =
|
|
LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
|
|
// get or create function declaration:
|
|
LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
|
|
"MPI_Comm_split", funcType);
|
|
|
|
auto callOp = rewriter.create<LLVM::CallOp>(
|
|
loc, funcDecl,
|
|
ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
|
|
outPtr.getRes()});
|
|
|
|
// load the communicator into a register
|
|
Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
|
|
res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res);
|
|
|
|
// if retval is checked, replace uses of retval with the results from the
|
|
// call op
|
|
SmallVector<Value> replacements;
|
|
if (op.getRetval())
|
|
replacements.push_back(callOp.getResult());
|
|
|
|
// replace op
|
|
replacements.push_back(res);
|
|
rewriter.replaceOp(op, replacements);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CommRankOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// get some helper vars
|
|
Location loc = op.getLoc();
|
|
MLIRContext *context = rewriter.getContext();
|
|
Type i32 = rewriter.getI32Type();
|
|
|
|
// ptrType `!llvm.ptr`
|
|
Type ptrType = LLVM::LLVMPointerType::get(context);
|
|
|
|
// grab a reference to the global module op:
|
|
auto moduleOp = op->getParentOfType<ModuleOp>();
|
|
|
|
auto mpiTraits = MPIImplTraits::get(moduleOp);
|
|
// get communicator
|
|
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
|
|
|
|
// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
|
|
auto rankFuncType =
|
|
LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
|
|
// get or create function declaration:
|
|
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
|
|
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
|
|
|
|
// replace with function call
|
|
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
|
|
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
|
|
auto callOp = rewriter.create<LLVM::CallOp>(
|
|
loc, initDecl, ValueRange{comm, rankptr.getRes()});
|
|
|
|
// load the rank into a register
|
|
auto loadedRank =
|
|
rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
|
|
|
|
// if retval is checked, replace uses of retval with the results from the
|
|
// call op
|
|
SmallVector<Value> replacements;
|
|
if (op.getRetval())
|
|
replacements.push_back(callOp.getResult());
|
|
|
|
// replace all uses, then erase op
|
|
replacements.push_back(loadedRank.getRes());
|
|
rewriter.replaceOp(op, replacements);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SendOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// get some helper vars
|
|
Location loc = op.getLoc();
|
|
MLIRContext *context = rewriter.getContext();
|
|
Type i32 = rewriter.getI32Type();
|
|
Type elemType = op.getRef().getType().getElementType();
|
|
|
|
// ptrType `!llvm.ptr`
|
|
Type ptrType = LLVM::LLVMPointerType::get(context);
|
|
|
|
// grab a reference to the global module op:
|
|
auto moduleOp = op->getParentOfType<ModuleOp>();
|
|
|
|
// get MPI_COMM_WORLD, dataType and pointer
|
|
auto [dataPtr, size] =
|
|
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
|
|
auto mpiTraits = MPIImplTraits::get(moduleOp);
|
|
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
|
|
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
|
|
|
|
// LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
|
|
// tag, comm)`
|
|
auto funcType = LLVM::LLVMFunctionType::get(
|
|
i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
|
|
// get or create function declaration:
|
|
LLVM::LLVMFuncOp funcDecl =
|
|
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
|
|
|
|
// replace op with function call
|
|
auto funcCall = rewriter.create<LLVM::CallOp>(
|
|
loc, funcDecl,
|
|
ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
|
|
comm});
|
|
if (op.getRetval())
|
|
rewriter.replaceOp(op, funcCall.getResult());
|
|
else
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RecvOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// get some helper vars
|
|
Location loc = op.getLoc();
|
|
MLIRContext *context = rewriter.getContext();
|
|
Type i32 = rewriter.getI32Type();
|
|
Type i64 = rewriter.getI64Type();
|
|
Type elemType = op.getRef().getType().getElementType();
|
|
|
|
// ptrType `!llvm.ptr`
|
|
Type ptrType = LLVM::LLVMPointerType::get(context);
|
|
|
|
// grab a reference to the global module op:
|
|
auto moduleOp = op->getParentOfType<ModuleOp>();
|
|
|
|
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
|
|
auto [dataPtr, size] =
|
|
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
|
|
auto mpiTraits = MPIImplTraits::get(moduleOp);
|
|
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
|
|
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
|
|
Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
|
|
loc, i64, mpiTraits->getStatusIgnore());
|
|
statusIgnore =
|
|
rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
|
|
|
|
// LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
|
|
// tag, comm)`
|
|
auto funcType =
|
|
LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
|
|
i32, comm.getType(), ptrType});
|
|
// get or create function declaration:
|
|
LLVM::LLVMFuncOp funcDecl =
|
|
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
|
|
|
|
// replace op with function call
|
|
auto funcCall = rewriter.create<LLVM::CallOp>(
|
|
loc, funcDecl,
|
|
ValueRange{dataPtr, size, dataType, adaptor.getSource(),
|
|
adaptor.getTag(), comm, statusIgnore});
|
|
if (op.getRetval())
|
|
rewriter.replaceOp(op, funcCall.getResult());
|
|
else
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AllReduceOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
|
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
MLIRContext *context = rewriter.getContext();
|
|
Type i32 = rewriter.getI32Type();
|
|
Type i64 = rewriter.getI64Type();
|
|
Type elemType = op.getSendbuf().getType().getElementType();
|
|
|
|
// ptrType `!llvm.ptr`
|
|
Type ptrType = LLVM::LLVMPointerType::get(context);
|
|
auto moduleOp = op->getParentOfType<ModuleOp>();
|
|
auto mpiTraits = MPIImplTraits::get(moduleOp);
|
|
auto [sendPtr, sendSize] =
|
|
getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
|
|
auto [recvPtr, recvSize] =
|
|
getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
|
|
|
|
// If input and output are the same, request in-place operation.
|
|
if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
|
|
sendPtr = rewriter.create<LLVM::ConstantOp>(
|
|
loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
|
|
sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
|
|
}
|
|
|
|
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
|
|
Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
|
|
Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
|
|
|
|
// 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
|
|
// MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
|
|
auto funcType = LLVM::LLVMFunctionType::get(
|
|
i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
|
|
commWorld.getType()});
|
|
// get or create function declaration:
|
|
LLVM::LLVMFuncOp funcDecl =
|
|
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
|
|
|
|
// replace op with function call
|
|
auto funcCall = rewriter.create<LLVM::CallOp>(
|
|
loc, funcDecl,
|
|
ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
|
|
|
|
if (op.getRetval())
|
|
rewriter.replaceOp(op, funcCall.getResult());
|
|
else
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConvertToLLVMPatternInterface implementation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Implement the interface to convert Func to LLVM.
|
|
struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
|
|
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
|
|
/// Hook for derived dialect interface to provide conversion patterns
|
|
/// and mark dialect legal for the conversion target.
|
|
void populateConvertToLLVMConversionPatterns(
|
|
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns) const final {
|
|
mpi::populateMPIToLLVMConversionPatterns(typeConverter, patterns);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pattern Population
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
|
RewritePatternSet &patterns) {
|
|
// Using i64 as a portable, intermediate type for !mpi.comm.
|
|
// It would be nicer to somehow get the right type directly, but TLDI is not
|
|
// available here.
|
|
converter.addConversion([](mpi::CommType type) {
|
|
return IntegerType::get(type.getContext(), 64);
|
|
});
|
|
patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
|
|
FinalizeOpLowering, InitOpLowering, SendOpLowering,
|
|
RecvOpLowering, AllReduceOpLowering>(converter);
|
|
}
|
|
|
|
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry ®istry) {
|
|
registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
|
|
dialect->addInterfaces<FuncToLLVMDialectInterface>();
|
|
});
|
|
}
|