[mlir][mpi] Lowering Mpi To LLVM (#127053)
* The first set of patterns to convert the MPI dialect to LLVM. * Further conversion pattern will be added in future PRs. * Supports MPICH compatible MPI implementations and openMPI, selectable through DLTI attribute on module --------- Co-authored-by: Anton Lydike <me@antonlydike.de> Co-authored-by: Christian Ulmann <christianulmann@gmail.com>
This commit is contained in:
parent
506deb0cce
commit
ab166d4d10
29
mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
Normal file
29
mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
Normal file
@ -0,0 +1,29 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_MPITOLLVM_H
|
||||
#define MLIR_CONVERSION_MPITOLLVM_H
|
||||
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class LLVMTypeConverter;
|
||||
class RewritePatternSet;
|
||||
|
||||
namespace mpi {
|
||||
|
||||
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
void registerConvertMPIToLLVMInterface(DialectRegistry ®istry);
|
||||
|
||||
} // namespace mpi
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_MPITOLLVM_H
|
@ -102,13 +102,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
|
||||
let arguments = (
|
||||
ins AnyMemRef : $ref,
|
||||
I32 : $tag,
|
||||
I32 : $rank
|
||||
I32 : $dest
|
||||
);
|
||||
|
||||
let results = (outs Optional<MPI_Retval>:$retval);
|
||||
|
||||
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
|
||||
"type($ref) `,` type($tag) `,` type($rank)"
|
||||
let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
|
||||
"type($ref) `,` type($tag) `,` type($dest)"
|
||||
"(`->` type($retval)^)?";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
@ -154,11 +154,11 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MPI_RecvOp : MPI_Op<"recv", []> {
|
||||
let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
|
||||
let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
|
||||
"MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
|
||||
let description = [{
|
||||
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
|
||||
from rank `dest`. The `tag` value and communicator enables the library to
|
||||
from rank `source`. The `tag` value and communicator enables the library to
|
||||
determine the matching of multiple sends and receives between the same
|
||||
ranks.
|
||||
|
||||
@ -172,13 +172,13 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
|
||||
|
||||
let arguments = (
|
||||
ins AnyMemRef : $ref,
|
||||
I32 : $tag, I32 : $rank
|
||||
I32 : $tag, I32 : $source
|
||||
);
|
||||
|
||||
let results = (outs Optional<MPI_Retval>:$retval);
|
||||
|
||||
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
|
||||
"type($ref) `,` type($tag) `,` type($rank)"
|
||||
let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
|
||||
"type($ref) `,` type($tag) `,` type($source)"
|
||||
"(`->` type($retval)^)?";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MPI_Retval : MPI_Type<"Retval", "retval"> {
|
||||
let summary = "MPI function call return value";
|
||||
let summary = "MPI function call return value (!mpi.retval)";
|
||||
let description = [{
|
||||
This type represents a return value from an MPI function call.
|
||||
This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
|
||||
|
@ -14,6 +14,7 @@
|
||||
#ifndef MLIR_INITALLEXTENSIONS_H_
|
||||
#define MLIR_INITALLEXTENSIONS_H_
|
||||
|
||||
#include "Conversion/MPIToLLVM/MPIToLLVM.h"
|
||||
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
||||
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
|
||||
registerConvertFuncToLLVMInterface(registry);
|
||||
index::registerConvertIndexToLLVMInterface(registry);
|
||||
registerConvertMathToLLVMInterface(registry);
|
||||
mpi::registerConvertMPIToLLVMInterface(registry);
|
||||
registerConvertMemRefToLLVMInterface(registry);
|
||||
registerConvertNVVMToLLVMInterface(registry);
|
||||
registerConvertOpenMPToLLVMInterface(registry);
|
||||
|
@ -42,6 +42,7 @@ add_subdirectory(MemRefToEmitC)
|
||||
add_subdirectory(MemRefToLLVM)
|
||||
add_subdirectory(MemRefToSPIRV)
|
||||
add_subdirectory(MeshToMPI)
|
||||
add_subdirectory(MPIToLLVM)
|
||||
add_subdirectory(NVGPUToNVVM)
|
||||
add_subdirectory(NVVMToLLVM)
|
||||
add_subdirectory(OpenACCToSCF)
|
||||
|
18
mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
Normal file
18
mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
Normal file
@ -0,0 +1,18 @@
|
||||
add_mlir_conversion_library(MLIRMPIToLLVM
|
||||
MPIToLLVM.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRDLTIDialect
|
||||
MLIRLLVMCommonConversion
|
||||
MLIRLLVMDialect
|
||||
MLIRMPIDialect
|
||||
)
|
502
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
Normal file
502
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
Normal file
@ -0,0 +1,502 @@
|
||||
//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//
|
||||
// Copyright (C) by Argonne National Laboratory
|
||||
// See COPYRIGHT in top-level directory
|
||||
// of MPICH source repository.
|
||||
//
|
||||
|
||||
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MPI/IR/MPI.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Op, typename... Args>
|
||||
static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
|
||||
ConversionPatternRewriter &rewriter, StringRef name,
|
||||
Args &&...args) {
|
||||
Op ret;
|
||||
if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
|
||||
const Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
StringRef name,
|
||||
LLVM::LLVMFunctionType type) {
|
||||
return getOrDefineGlobal<LLVM::LLVMFuncOp>(
|
||||
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
|
||||
}
|
||||
|
||||
/// When lowering the mpi dialect to functions calls certain details
|
||||
/// differ between various MPI implementations. This class will provide
|
||||
/// these in a generic way, depending on the MPI implementation that got
|
||||
/// selected by the DLTI attribute on the module.
|
||||
class MPIImplTraits {
|
||||
ModuleOp &moduleOp;
|
||||
|
||||
public:
|
||||
/// Instantiate a new MPIImplTraits object according to the DLTI attribute
|
||||
/// on the given module. Default to MPICH if no attribute is present or
|
||||
/// the value is unknown.
|
||||
static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
|
||||
|
||||
explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
|
||||
|
||||
ModuleOp &getModuleOp() { return moduleOp; }
|
||||
|
||||
/// Gets or creates MPI_COMM_WORLD as a Value.
|
||||
virtual Value getCommWorld(const Location loc,
|
||||
ConversionPatternRewriter &rewriter) = 0;
|
||||
|
||||
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
|
||||
virtual intptr_t getStatusIgnore() = 0;
|
||||
|
||||
/// Gets or creates an MPI datatype as a value which corresponds to the given
|
||||
/// type.
|
||||
virtual Value getDataType(const Location loc,
|
||||
ConversionPatternRewriter &rewriter, Type type) = 0;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Implementation details for MPICH ABI compatible MPI implementations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class MPICHImplTraits : public MPIImplTraits {
|
||||
static constexpr int MPI_FLOAT = 0x4c00040a;
|
||||
static constexpr int MPI_DOUBLE = 0x4c00080b;
|
||||
static constexpr int MPI_INT8_T = 0x4c000137;
|
||||
static constexpr int MPI_INT16_T = 0x4c000238;
|
||||
static constexpr int MPI_INT32_T = 0x4c000439;
|
||||
static constexpr int MPI_INT64_T = 0x4c00083a;
|
||||
static constexpr int MPI_UINT8_T = 0x4c00013b;
|
||||
static constexpr int MPI_UINT16_T = 0x4c00023c;
|
||||
static constexpr int MPI_UINT32_T = 0x4c00043d;
|
||||
static constexpr int MPI_UINT64_T = 0x4c00083e;
|
||||
|
||||
public:
|
||||
using MPIImplTraits::MPIImplTraits;
|
||||
|
||||
Value getCommWorld(const Location loc,
|
||||
ConversionPatternRewriter &rewriter) override {
|
||||
static constexpr int MPI_COMM_WORLD = 0x44000000;
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
|
||||
MPI_COMM_WORLD);
|
||||
}
|
||||
|
||||
intptr_t getStatusIgnore() override { return 1; }
|
||||
|
||||
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
|
||||
Type type) override {
|
||||
int32_t mtype = 0;
|
||||
if (type.isF32())
|
||||
mtype = MPI_FLOAT;
|
||||
else if (type.isF64())
|
||||
mtype = MPI_DOUBLE;
|
||||
else if (type.isInteger(64) && !type.isUnsignedInteger())
|
||||
mtype = MPI_INT64_T;
|
||||
else if (type.isInteger(64))
|
||||
mtype = MPI_UINT64_T;
|
||||
else if (type.isInteger(32) && !type.isUnsignedInteger())
|
||||
mtype = MPI_INT32_T;
|
||||
else if (type.isInteger(32))
|
||||
mtype = MPI_UINT32_T;
|
||||
else if (type.isInteger(16) && !type.isUnsignedInteger())
|
||||
mtype = MPI_INT16_T;
|
||||
else if (type.isInteger(16))
|
||||
mtype = MPI_UINT16_T;
|
||||
else if (type.isInteger(8) && !type.isUnsignedInteger())
|
||||
mtype = MPI_INT8_T;
|
||||
else if (type.isInteger(8))
|
||||
mtype = MPI_UINT8_T;
|
||||
else
|
||||
assert(false && "unsupported type");
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Implementation details for OpenMPI
|
||||
//===----------------------------------------------------------------------===//
|
||||
class OMPIImplTraits : public MPIImplTraits {
|
||||
LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
StringRef name,
|
||||
LLVM::LLVMStructType type) {
|
||||
|
||||
return getOrDefineGlobal<LLVM::GlobalOp>(
|
||||
getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
|
||||
LLVM::Linkage::External, name,
|
||||
/*value=*/Attribute(), /*alignment=*/0, 0);
|
||||
}
|
||||
|
||||
public:
|
||||
using MPIImplTraits::MPIImplTraits;
|
||||
|
||||
Value getCommWorld(const Location loc,
|
||||
ConversionPatternRewriter &rewriter) override {
|
||||
auto context = rewriter.getContext();
|
||||
// get external opaque struct pointer type
|
||||
auto commStructT =
|
||||
LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
|
||||
StringRef name = "ompi_mpi_comm_world";
|
||||
|
||||
// make sure global op definition exists
|
||||
getOrDefineExternalStruct(loc, rewriter, name, commStructT);
|
||||
|
||||
// get address of symbol
|
||||
return rewriter.create<LLVM::AddressOfOp>(
|
||||
loc, LLVM::LLVMPointerType::get(context),
|
||||
SymbolRefAttr::get(context, name));
|
||||
}
|
||||
|
||||
intptr_t getStatusIgnore() override { return 0; }
|
||||
|
||||
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
|
||||
Type type) override {
|
||||
StringRef mtype;
|
||||
if (type.isF32())
|
||||
mtype = "ompi_mpi_float";
|
||||
else if (type.isF64())
|
||||
mtype = "ompi_mpi_double";
|
||||
else if (type.isInteger(64) && !type.isUnsignedInteger())
|
||||
mtype = "ompi_mpi_int64_t";
|
||||
else if (type.isInteger(64))
|
||||
mtype = "ompi_mpi_uint64_t";
|
||||
else if (type.isInteger(32) && !type.isUnsignedInteger())
|
||||
mtype = "ompi_mpi_int32_t";
|
||||
else if (type.isInteger(32))
|
||||
mtype = "ompi_mpi_uint32_t";
|
||||
else if (type.isInteger(16) && !type.isUnsignedInteger())
|
||||
mtype = "ompi_mpi_int16_t";
|
||||
else if (type.isInteger(16))
|
||||
mtype = "ompi_mpi_uint16_t";
|
||||
else if (type.isInteger(8) && !type.isUnsignedInteger())
|
||||
mtype = "ompi_mpi_int8_t";
|
||||
else if (type.isInteger(8))
|
||||
mtype = "ompi_mpi_uint8_t";
|
||||
else
|
||||
assert(false && "unsupported type");
|
||||
|
||||
auto context = rewriter.getContext();
|
||||
// get external opaque struct pointer type
|
||||
auto commStructT =
|
||||
LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
|
||||
// make sure global op definition exists
|
||||
getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
|
||||
// get address of symbol
|
||||
return rewriter.create<LLVM::AddressOfOp>(
|
||||
loc, LLVM::LLVMPointerType::get(context),
|
||||
SymbolRefAttr::get(context, mtype));
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
|
||||
auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
|
||||
if (failed(attr))
|
||||
return std::make_unique<MPICHImplTraits>(moduleOp);
|
||||
auto strAttr = dyn_cast<StringAttr>(attr.value());
|
||||
if (strAttr && strAttr.getValue() == "OpenMPI")
|
||||
return std::make_unique<OMPIImplTraits>(moduleOp);
|
||||
if (!strAttr || strAttr.getValue() != "MPICH")
|
||||
moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
|
||||
<< strAttr.getValue() << "), defaulting to MPICH";
|
||||
return std::make_unique<MPICHImplTraits>(moduleOp);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InitOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
|
||||
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
|
||||
// ptrType `!llvm.ptr`
|
||||
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
|
||||
|
||||
// instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
|
||||
auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
|
||||
Value llvmnull = nullPtrOp.getRes();
|
||||
|
||||
// grab a reference to the global module op:
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
|
||||
// LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
|
||||
auto initFuncType =
|
||||
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
|
||||
// get or create function declaration:
|
||||
LLVM::LLVMFuncOp initDecl =
|
||||
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
|
||||
|
||||
// replace init with function call
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
|
||||
ValueRange{llvmnull, llvmnull});
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FinalizeOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
|
||||
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// get loc
|
||||
Location loc = op.getLoc();
|
||||
|
||||
// grab a reference to the global module op:
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
|
||||
// LLVM Function type representing `i32 MPI_Finalize()`
|
||||
auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
|
||||
// get or create function declaration:
|
||||
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
|
||||
moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
|
||||
|
||||
// replace init with function call
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CommRankOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
|
||||
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// get some helper vars
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
Type i32 = rewriter.getI32Type();
|
||||
|
||||
// ptrType `!llvm.ptr`
|
||||
Type ptrType = LLVM::LLVMPointerType::get(context);
|
||||
|
||||
// grab a reference to the global module op:
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
|
||||
auto mpiTraits = MPIImplTraits::get(moduleOp);
|
||||
// get MPI_COMM_WORLD
|
||||
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
|
||||
|
||||
// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
|
||||
auto rankFuncType =
|
||||
LLVM::LLVMFunctionType::get(i32, {commWorld.getType(), ptrType});
|
||||
// get or create function declaration:
|
||||
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
|
||||
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
|
||||
|
||||
// replace init with function call
|
||||
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
|
||||
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
|
||||
auto callOp = rewriter.create<LLVM::CallOp>(
|
||||
loc, initDecl, ValueRange{commWorld, rankptr.getRes()});
|
||||
|
||||
// load the rank into a register
|
||||
auto loadedRank =
|
||||
rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
|
||||
|
||||
// if retval is checked, replace uses of retval with the results from the
|
||||
// call op
|
||||
SmallVector<Value> replacements;
|
||||
if (op.getRetval())
|
||||
replacements.push_back(callOp.getResult());
|
||||
|
||||
// replace all uses, then erase op
|
||||
replacements.push_back(loadedRank.getRes());
|
||||
rewriter.replaceOp(op, replacements);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SendOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
|
||||
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// get some helper vars
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
Type i32 = rewriter.getI32Type();
|
||||
Type i64 = rewriter.getI64Type();
|
||||
Value memRef = adaptor.getRef();
|
||||
Type elemType = op.getRef().getType().getElementType();
|
||||
|
||||
// ptrType `!llvm.ptr`
|
||||
Type ptrType = LLVM::LLVMPointerType::get(context);
|
||||
|
||||
// grab a reference to the global module op:
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
|
||||
// get MPI_COMM_WORLD, dataType and pointer
|
||||
Value dataPtr =
|
||||
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
|
||||
Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
|
||||
dataPtr =
|
||||
rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
|
||||
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
|
||||
ArrayRef<int64_t>{3, 0});
|
||||
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
|
||||
auto mpiTraits = MPIImplTraits::get(moduleOp);
|
||||
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
|
||||
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
|
||||
|
||||
// LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
|
||||
// tag, comm)`
|
||||
auto funcType = LLVM::LLVMFunctionType::get(
|
||||
i32, {ptrType, i32, dataType.getType(), i32, i32, commWorld.getType()});
|
||||
// get or create function declaration:
|
||||
LLVM::LLVMFuncOp funcDecl =
|
||||
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
|
||||
|
||||
// replace op with function call
|
||||
auto funcCall = rewriter.create<LLVM::CallOp>(
|
||||
loc, funcDecl,
|
||||
ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
|
||||
commWorld});
|
||||
if (op.getRetval())
|
||||
rewriter.replaceOp(op, funcCall.getResult());
|
||||
else
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RecvOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
|
||||
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// get some helper vars
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
Type i32 = rewriter.getI32Type();
|
||||
Type i64 = rewriter.getI64Type();
|
||||
Value memRef = adaptor.getRef();
|
||||
Type elemType = op.getRef().getType().getElementType();
|
||||
|
||||
// ptrType `!llvm.ptr`
|
||||
Type ptrType = LLVM::LLVMPointerType::get(context);
|
||||
|
||||
// grab a reference to the global module op:
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
|
||||
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
|
||||
Value dataPtr =
|
||||
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
|
||||
Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
|
||||
dataPtr =
|
||||
rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
|
||||
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
|
||||
ArrayRef<int64_t>{3, 0});
|
||||
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
|
||||
auto mpiTraits = MPIImplTraits::get(moduleOp);
|
||||
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
|
||||
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
|
||||
Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i64, mpiTraits->getStatusIgnore());
|
||||
statusIgnore =
|
||||
rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
|
||||
|
||||
// LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
|
||||
// tag, comm)`
|
||||
auto funcType =
|
||||
LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
|
||||
i32, commWorld.getType(), ptrType});
|
||||
// get or create function declaration:
|
||||
LLVM::LLVMFuncOp funcDecl =
|
||||
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
|
||||
|
||||
// replace op with function call
|
||||
auto funcCall = rewriter.create<LLVM::CallOp>(
|
||||
loc, funcDecl,
|
||||
ValueRange{dataPtr, size, dataType, adaptor.getSource(),
|
||||
adaptor.getTag(), commWorld, statusIgnore});
|
||||
if (op.getRetval())
|
||||
rewriter.replaceOp(op, funcCall.getResult());
|
||||
else
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConvertToLLVMPatternInterface implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Implement the interface to convert Func to LLVM.
|
||||
struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
|
||||
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
|
||||
/// Hook for derived dialect interface to provide conversion patterns
|
||||
/// and mark dialect legal for the conversion target.
|
||||
void populateConvertToLLVMConversionPatterns(
|
||||
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) const final {
|
||||
mpi::populateMPIToLLVMConversionPatterns(typeConverter, patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern Population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
|
||||
SendOpLowering, RecvOpLowering>(converter);
|
||||
}
|
||||
|
||||
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
|
||||
dialect->addInterfaces<FuncToLLVMDialectInterface>();
|
||||
});
|
||||
}
|
165
mlir/test/Conversion/MPIToLLVM/ops.mlir
Normal file
165
mlir/test/Conversion/MPIToLLVM/ops.mlir
Normal file
@ -0,0 +1,165 @@
|
||||
// RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
|
||||
|
||||
// COM: Test MPICH ABI
|
||||
// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH">} {
|
||||
// CHECK: llvm.func @MPI_Finalize() -> i32
|
||||
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
|
||||
// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
|
||||
module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
|
||||
|
||||
// CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
|
||||
func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
|
||||
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
|
||||
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
|
||||
%0 = mpi.init : !mpi.retval
|
||||
|
||||
// CHECK: [[v8:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
|
||||
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32
|
||||
%retval, %rank = mpi.comm_rank : !mpi.retval, i32
|
||||
|
||||
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
|
||||
// CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
|
||||
// CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[v19:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
|
||||
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
|
||||
mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
|
||||
|
||||
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
|
||||
// CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[v27:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
|
||||
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
|
||||
%1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
|
||||
|
||||
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
|
||||
// CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[v35:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
|
||||
// CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
|
||||
mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
|
||||
|
||||
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
|
||||
// CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[v45:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
|
||||
// CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
|
||||
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
|
||||
|
||||
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
|
||||
%3 = mpi.finalize : !mpi.retval
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// COM: Test OpenMPI ABI
|
||||
// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
|
||||
// CHECK: llvm.func @MPI_Finalize() -> i32
|
||||
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
|
||||
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
|
||||
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
|
||||
module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
|
||||
|
||||
// CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
|
||||
func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {
|
||||
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
|
||||
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
|
||||
%0 = mpi.init : !mpi.retval
|
||||
|
||||
// CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
|
||||
%retval, %rank = mpi.comm_rank : !mpi.retval, i32
|
||||
|
||||
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
|
||||
// CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
|
||||
// CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v19:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
|
||||
|
||||
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
|
||||
// CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v27:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
%1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
|
||||
|
||||
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
|
||||
// CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v35:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
|
||||
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
|
||||
|
||||
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
|
||||
// CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v45:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
|
||||
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
|
||||
|
||||
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
|
||||
%3 = mpi.finalize : !mpi.retval
|
||||
|
||||
return
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user