[mlir][shard, mpi] Adding Shard/MPI reduce_scatter and simplification (#184189)
- introduces a simplify pass, which finds such patterns and replaces it with the equivalent `reduce-scatter` - promotes the test-pass `test-shard-optimizations` to a proper pass and adds - folding allgather+allslice into reduce_scatter - sanitizes the `shard.reduce_scatter` op - adds a new `mpi.reduce_scatter_block` op - lowers `shard.reduce_scatter` to MPI - lowers `mpi-reduce_scatter_block` to llvm --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
5f8f1e2afe
commit
a232b5b96f
@ -333,6 +333,41 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
|
||||
"(`->` type($retval)^)?";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReduceScatterBlockOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
|
||||
let summary = "Equivalent to `MPI_Reduce_scatter_block(sendbuf, recvbuf, "
|
||||
"recvcount, dtype, op, comm)`";
|
||||
let description = [{
|
||||
MPI_Reduce_scatter_block first performs an element-wise reduction on the
|
||||
sendbuf across all processes in the communicator, then scatters the result
|
||||
by distributing equal-sized blocks to each process into recvbuf.
|
||||
|
||||
The `op` attribute specifies the reduction operation to be performed.
|
||||
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
|
||||
supported.
|
||||
|
||||
This operation can optionally return an `!mpi.retval` value that can be used
|
||||
to check for errors.
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyNon0RankedMemRef : $sendbuf,
|
||||
AnyNon0RankedMemRef : $recvbuf,
|
||||
MPI_ReductionOpEnum : $op,
|
||||
MPI_Comm : $comm
|
||||
);
|
||||
|
||||
let results = (outs Optional<MPI_Retval>:$retval);
|
||||
|
||||
let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
|
||||
"attr-dict `:` type($sendbuf) `,` type($recvbuf) "
|
||||
"(`->` type($retval)^)?";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BarrierOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -901,13 +901,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
|
||||
`grid_axes` using the specified reduction method. The reduction is performed
|
||||
element-wise across the tensor pieces from all devices in the group.
|
||||
After reduction, the reduction result is scattered (split and distributed)
|
||||
across the device group along `scatter_axis`.
|
||||
across the device group along `scatter_dim`.
|
||||
Example:
|
||||
```
|
||||
shard.grid @grid0(shape = 2x2)
|
||||
...
|
||||
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
|
||||
reduction = <max> scatter_axis = 0
|
||||
reduction = <max> scatter_dim = 0
|
||||
: tensor<2x2xf32> -> tensor<1x2xf64>
|
||||
```
|
||||
Input:
|
||||
@ -940,17 +940,17 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
|
||||
```
|
||||
}];
|
||||
let arguments = !con(commonArgs, (ins
|
||||
AnyNon0RankedTensor:$input,
|
||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
|
||||
DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction,
|
||||
IndexAttr:$scatter_axis
|
||||
IndexAttr:$scatter_dim
|
||||
));
|
||||
let results = (outs
|
||||
AnyRankedTensor:$result
|
||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$input `on` $grid (`grid_axes` `=` $grid_axes^)?
|
||||
(`reduction` `=` $reduction^)?
|
||||
`scatter_axis` `=` $scatter_axis
|
||||
`scatter_dim` `=` $scatter_dim
|
||||
attr-dict `:` type($input) `->` type($result)
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
@ -964,7 +964,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
|
||||
let summary = "Scatter over a device grid.";
|
||||
let description = [{
|
||||
For each device group defined by `grid_axes`, the input tensor on the `root`
|
||||
device is split along axis `scatter_axis` and distributed across the group.
|
||||
device is split along axis `scatter_dim` and distributed across the group.
|
||||
The content of the input on all other (non-root) devices is ignored.
|
||||
The `root` device is defined by its in-group multi-index.
|
||||
|
||||
@ -972,7 +972,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
|
||||
```
|
||||
shard.grid @grid0(shape = 2x2)
|
||||
%1 = shard.scatter %0 on @grid0 grid_axes = [0]
|
||||
scatter_axis = 0
|
||||
scatter_dim = 0
|
||||
root = [1]
|
||||
: (tensor<2x2xi8>) -> tensor<1x2xi8>
|
||||
```
|
||||
@ -1011,7 +1011,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
|
||||
}];
|
||||
let arguments = !con(commonArgs, (ins
|
||||
AnyNon0RankedTensor:$input,
|
||||
IndexAttr:$scatter_axis,
|
||||
IndexAttr:$scatter_dim,
|
||||
DenseI64ArrayAttr:$root,
|
||||
Variadic<Index>:$root_dynamic
|
||||
));
|
||||
@ -1020,7 +1020,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$input `on` $grid (`grid_axes` `=` $grid_axes^)?
|
||||
`scatter_axis` `=` $scatter_axis
|
||||
`scatter_dim` `=` $scatter_dim
|
||||
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
|
||||
//===- Partition.h - Shard Partition ----------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
||||
@ -44,6 +44,23 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
|
||||
];
|
||||
}
|
||||
|
||||
def ShardSimplify : Pass<"shard-simplify"> {
|
||||
let summary = "Shard simplify patterns.";
|
||||
let description = [{
|
||||
Applies simplification patterns on the Shard dialect operations.
|
||||
This includes:
|
||||
- All-reduce endomorphism simplification, e.g. transforming
|
||||
`all_reduce_sum(x) + all_reduce_sum(y)` into `all_reduce_sum(x + y)`.
|
||||
- Folding `AllSliceOp(AllReduceOp)` into `ReduceScatterOp` when both ops
|
||||
share the same grid and grid_axes.
|
||||
- Folding static grid shapes into constants.
|
||||
}];
|
||||
let dependentDialects = [
|
||||
"arith::ArithDialect",
|
||||
"shard::ShardDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def Partition : InterfacePass<"shard-partition", "mlir::FunctionOpInterface"> {
|
||||
let summary = "Partition a function into SPMD form.";
|
||||
let description = [{
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
|
||||
//===- Simplify.h - Shard Simplify ------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@ -6,8 +6,8 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
|
||||
#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
|
||||
#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
|
||||
#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
|
||||
|
||||
#include "mlir/Dialect/Shard/IR/ShardOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
@ -37,8 +37,8 @@ namespace shard {
|
||||
// Will not work with some op `f(x, y, z)` where only `x` and `y` form
|
||||
// the algebraic structure.
|
||||
template <typename AlgebraicOp>
|
||||
void populateAllReduceEndomorphismSimplificationPatterns(
|
||||
RewritePatternSet &patterns, ReductionKind reduction) {
|
||||
void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns,
|
||||
ReductionKind reduction) {
|
||||
auto getEndomorphismOpOperand = [](Operation *op) {
|
||||
auto allReduceOp = llvm::cast<AllReduceOp>(op);
|
||||
return &allReduceOp.getInputMutable();
|
||||
@ -105,12 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
|
||||
|
||||
// It is invalid to change ops that declare symbols during the application of
|
||||
// these patterns, because symbolTableCollection is used to cache them.
|
||||
void populateSimplificationPatterns(
|
||||
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
|
||||
void populateSimplifyPatterns(RewritePatternSet &patterns,
|
||||
SymbolTableCollection &symbolTableCollection);
|
||||
void populateFoldingPatterns(RewritePatternSet &patterns,
|
||||
SymbolTableCollection &symbolTableCollection);
|
||||
|
||||
} // namespace shard
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
|
||||
#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
|
||||
@ -16,6 +16,7 @@
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
@ -911,6 +912,79 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReduceScatterBlockOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct ReduceScatterBlockOpLowering
|
||||
: public ConvertOpToLLVMPattern<mpi::ReduceScatterBlockOp> {
|
||||
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mpi::ReduceScatterBlockOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
Type i32 = rewriter.getI32Type();
|
||||
Type i64 = rewriter.getI64Type();
|
||||
Type elemType = op.getSendbuf().getType().getElementType();
|
||||
int64_t sRank = op.getSendbuf().getType().getRank();
|
||||
int64_t rRank = op.getRecvbuf().getType().getRank();
|
||||
|
||||
// ptrType `!llvm.ptr`
|
||||
Type ptrType = LLVM::LLVMPointerType::get(context);
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
auto mpiTraits = MPIImplTraits::get(moduleOp);
|
||||
auto [sendPtr, sendSize] =
|
||||
getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
|
||||
auto [recvPtr, recvSize] =
|
||||
getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
|
||||
|
||||
// If input and output are the same, request in-place operation.
|
||||
if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
|
||||
sendPtr = LLVM::ConstantOp::create(
|
||||
rewriter, loc, i64,
|
||||
reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
|
||||
sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
|
||||
}
|
||||
|
||||
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
|
||||
Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
|
||||
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
|
||||
|
||||
Value nRanks =
|
||||
createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
|
||||
Value totalExpected =
|
||||
LLVM::MulOp::create(rewriter, loc, i32, recvSize, nRanks);
|
||||
Value sizeIsValid = LLVM::ICmpOp::create(
|
||||
rewriter, loc, LLVM::ICmpPredicate::eq, sendSize, totalExpected);
|
||||
cf::AssertOp::create(rewriter, loc, sizeIsValid,
|
||||
"Send buffer's size must be the receive buffer's size "
|
||||
"times the number of ranks");
|
||||
|
||||
// 'int MPI_Reduce_scatter_block(const void *sendbuf, void *recvbuf,
|
||||
// int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
|
||||
auto funcType = LLVM::LLVMFunctionType::get(
|
||||
i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
|
||||
comm.getType()});
|
||||
// get or create function declaration:
|
||||
LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(
|
||||
moduleOp, loc, rewriter, "MPI_Reduce_scatter_block", funcType);
|
||||
|
||||
// replace op with function call
|
||||
auto funcCall = LLVM::CallOp::create(
|
||||
rewriter, loc, funcDecl,
|
||||
ValueRange{sendPtr, recvPtr, recvSize, dataType, mpiOp, comm});
|
||||
|
||||
if (op.getRetval())
|
||||
rewriter.replaceOp(op, funcCall.getResult());
|
||||
else
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConvertToLLVMPatternInterface implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -943,7 +1017,7 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
|
||||
CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
|
||||
SendOpLowering, RecvOpLowering, AllGatherOpLowering,
|
||||
AllReduceOpLowering>(converter);
|
||||
AllReduceOpLowering, ReduceScatterBlockOpLowering>(converter);
|
||||
}
|
||||
|
||||
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry ®istry) {
|
||||
|
||||
@ -26,7 +26,7 @@
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
|
||||
#include "mlir/Dialect/Shard/IR/ShardOps.h"
|
||||
#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
|
||||
#include "mlir/Dialect/Shard/Transforms/Simplify.h"
|
||||
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
@ -618,6 +618,156 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
|
||||
using CommOpPattern::CommOpPattern;
|
||||
|
||||
// shard.reduce_scatter reduces and then scatters along a specified
|
||||
// scatter-dim. mpi.reduce_scatter_block always scatters along the first
|
||||
// dimension. Hence, if scatter-dim != 0, we need to rearrange the input
|
||||
// data by expanding the scatter-dim into {nRanks, output_scatter_dim} and
|
||||
// transposing nRanks to the first dimension.
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ReduceScatterOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto gridAxes = adaptor.getGridAxes();
|
||||
int64_t scatterDim = adaptor.getScatterDimAttr().getInt();
|
||||
|
||||
SymbolTableCollection symbolTableCollection;
|
||||
FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
|
||||
if (failed(gridOp))
|
||||
return failure();
|
||||
|
||||
ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
|
||||
Value rawInput = adaptor.getInput();
|
||||
auto inShapedType = cast<ShapedType>(rawInput.getType());
|
||||
MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
|
||||
auto elemType = outType.getElementType();
|
||||
auto inputShape = inShapedType.getShape();
|
||||
auto outputShape = outType.getShape();
|
||||
int64_t inputDimOnAxis = inputShape[scatterDim];
|
||||
int64_t outputDimOnAxis = outputShape[scatterDim];
|
||||
|
||||
for (size_t i = 0; i < outputShape.size(); ++i)
|
||||
if (outputShape[i] != inputShape[i] &&
|
||||
i != static_cast<size_t>(scatterDim))
|
||||
return op.emitError(
|
||||
"Result and input shapes must match along non-scatter axes.");
|
||||
if (outputDimOnAxis == 0)
|
||||
return op.emitError(
|
||||
"Output size along the scatter axis must be non-zero.");
|
||||
if (inputDimOnAxis % outputDimOnAxis != 0)
|
||||
return op.emitError(
|
||||
"Input size along the scatter axis must be an exact "
|
||||
"multiple of the output size along the scatter axis.");
|
||||
|
||||
if (!memref::isStaticShapeAndContiguousRowMajor(outType))
|
||||
return op.emitError("Result must be a statically shaped memref in "
|
||||
"contiguous row-major layout.");
|
||||
|
||||
int64_t nRanks = inputDimOnAxis / outputDimOnAxis;
|
||||
|
||||
// Verify that nRanks matches the number of devices along the grid axes.
|
||||
int64_t gridGroupSize =
|
||||
collectiveProcessGroupSize(gridAxes, gridOp->getShape());
|
||||
if (nRanks != gridGroupSize)
|
||||
return op.emitError()
|
||||
<< "Expected the scatter factor (" << nRanks
|
||||
<< ") to match the number of devices along grid_axes ("
|
||||
<< gridGroupSize << ").";
|
||||
|
||||
// Get the right communicator.
|
||||
Value comm = getComm(*gridOp, gridAxes, ib);
|
||||
|
||||
Value mpiInput;
|
||||
if (scatterDim == 0) {
|
||||
// scatter_dim == 0 maps directly to MPI_Reduce_scatter_block.
|
||||
// Input must be contiguous for MPI.
|
||||
Value input = getAsMemref(rawInput, ib);
|
||||
MemRefType inType = cast<MemRefType>(input.getType());
|
||||
if (!memref::isStaticShapeAndContiguousRowMajor(inType))
|
||||
return op.emitError("Input must be a statically shaped memref in "
|
||||
"contiguous row-major layout.");
|
||||
mpiInput = input;
|
||||
} else {
|
||||
// For scatter_dim != 0 we rearrange the input so the scatter factor
|
||||
// becomes the first dimension.
|
||||
//
|
||||
// 1. Get a tensor representation of the input (avoid memref->tensor
|
||||
// round-trip if the input is already a tensor).
|
||||
Value tensorInput = rawInput;
|
||||
if (!isa<RankedTensorType>(rawInput.getType())) {
|
||||
auto inTensorType = RankedTensorType::get(inputShape, elemType);
|
||||
tensorInput =
|
||||
bufferization::ToTensorOp::create(ib, inTensorType, rawInput, true);
|
||||
}
|
||||
|
||||
// 2. Expand the scatter dim from {d0, ..., d_sd, ..., dN} to
|
||||
// {d0, ..., nRanks, o_sd, ..., dN}.
|
||||
SmallVector<int64_t> expandedShape;
|
||||
SmallVector<ReassociationIndices> expandReassociation;
|
||||
int64_t expandedIdx = 0;
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(inputShape.size()); ++i) {
|
||||
if (i == scatterDim) {
|
||||
expandedShape.push_back(nRanks);
|
||||
expandedShape.push_back(outputDimOnAxis);
|
||||
expandReassociation.push_back({expandedIdx, expandedIdx + 1});
|
||||
expandedIdx += 2;
|
||||
} else {
|
||||
expandedShape.push_back(inputShape[i]);
|
||||
expandReassociation.push_back({expandedIdx});
|
||||
expandedIdx += 1;
|
||||
}
|
||||
}
|
||||
auto expandedType = RankedTensorType::get(expandedShape, elemType);
|
||||
tensorInput = tensor::ExpandShapeOp::create(ib, expandedType, tensorInput,
|
||||
expandReassociation);
|
||||
|
||||
// 3. Transpose to move nRanks (at position scatterDim) to position 0:
|
||||
// {d0, ..., nRanks, o_sd, ..., dN} -> {nRanks, d0, ..., o_sd, ..., dN}
|
||||
SmallVector<int64_t> permutation, transposedShape;
|
||||
permutation.emplace_back(scatterDim);
|
||||
for (int64_t i = 0; i < scatterDim; ++i)
|
||||
permutation.emplace_back(i);
|
||||
for (int64_t i = scatterDim + 1; i < (int64_t)expandedShape.size(); ++i)
|
||||
permutation.emplace_back(i);
|
||||
for (auto p : permutation)
|
||||
transposedShape.emplace_back(expandedShape[p]);
|
||||
|
||||
Value permOutput = tensor::EmptyOp::create(ib, transposedShape, elemType);
|
||||
tensorInput =
|
||||
linalg::TransposeOp::create(ib, tensorInput, permOutput, permutation)
|
||||
->getResult(0);
|
||||
|
||||
// 4. Materialize as contiguous memref for MPI by copying into a
|
||||
// freshly allocated buffer.
|
||||
auto mpiInType = MemRefType::get(transposedShape, elemType);
|
||||
Value transposedBuf =
|
||||
bufferization::ToBufferOp::create(ib, mpiInType, tensorInput);
|
||||
mpiInput = memref::AllocOp::create(ib, mpiInType);
|
||||
linalg::CopyOp::create(ib, transposedBuf, mpiInput);
|
||||
}
|
||||
|
||||
// Allocate output buffer.
|
||||
Value output = memref::AllocOp::create(ib, outType);
|
||||
// Create the MPI ReduceScatter operation.
|
||||
mpi::ReduceScatterBlockOp::create(
|
||||
ib, TypeRange(), mpiInput, output,
|
||||
getMPIReductionOp(adaptor.getReductionAttr()), comm);
|
||||
|
||||
// Deallocate the temporary input buffer if we allocated one.
|
||||
if (scatterDim != 0)
|
||||
memref::DeallocOp::create(ib, mpiInput);
|
||||
|
||||
// If the destination is a tensor, cast it to a tensor.
|
||||
if (isa<RankedTensorType>(op.getType()))
|
||||
output =
|
||||
bufferization::ToTensorOp::create(ib, op.getType(), output, true);
|
||||
rewriter.replaceOp(op, output);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
|
||||
using CommOpPattern::CommOpPattern;
|
||||
|
||||
@ -1048,7 +1198,7 @@ struct ConvertShardToMPIPass
|
||||
|
||||
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
|
||||
ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
|
||||
ConvertAllGatherOp, ConvertAllReduceOp,
|
||||
ConvertAllGatherOp, ConvertAllReduceOp, ConvertReduceScatterOp,
|
||||
ConvertProcessLinearIndexOp>(typeConverter, ctxt);
|
||||
SymbolTableCollection stc;
|
||||
populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
|
||||
|
||||
@ -15,8 +15,23 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::mpi;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Verifiers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult mlir::mpi::ReduceScatterBlockOp::verify() {
|
||||
if (getSendbuf().getType().getElementType() !=
|
||||
getRecvbuf().getType().getElementType())
|
||||
return emitOpError("sendbuf and recvbuf must have the same element type");
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Canonicalization patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// If input memref has dynamic shape and is a cast and if the cast's input has
|
||||
// static shape, fold the cast's static input into the given operation.
|
||||
template <typename OpT>
|
||||
|
||||
@ -1416,7 +1416,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
}
|
||||
|
||||
return verifyScatterOrSliceOperandAndResultShape(
|
||||
getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
|
||||
getOperand(), getResult(), getScatterDim().getSExtValue(), getGridAxes(),
|
||||
grid.value().getShape());
|
||||
}
|
||||
|
||||
@ -1445,9 +1445,9 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto scatterAxis = getScatterAxis().getSExtValue();
|
||||
auto scatterDim = getScatterDim().getSExtValue();
|
||||
return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
|
||||
scatterAxis, getGridAxes(),
|
||||
scatterDim, getGridAxes(),
|
||||
grid.value().getShape());
|
||||
}
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
add_mlir_dialect_library(MLIRShardTransforms
|
||||
Simplifications.cpp
|
||||
Simplify.cpp
|
||||
ShardingPropagation.cpp
|
||||
Partition.cpp
|
||||
Transforms.cpp
|
||||
@ -26,4 +26,5 @@ add_mlir_dialect_library(MLIRShardTransforms
|
||||
MLIRSupport
|
||||
MLIRTensorDialect
|
||||
MLIRTosaShardingInterfaceImpl
|
||||
MLIRTransformUtils
|
||||
)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
//===- Simplifications.cpp - Shard Simplifications -_------------*- C++ -*-===//
|
||||
//===- Simplify.cpp - Shard Simplify ----------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@ -6,13 +6,16 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
|
||||
#include "mlir/Dialect/Shard/Transforms/Simplify.h"
|
||||
#include "TransformsDetail.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
|
||||
#include "mlir/Dialect/Shard/IR/ShardOps.h"
|
||||
#include "mlir/Dialect/Shard/Transforms/Passes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include <numeric>
|
||||
@ -20,31 +23,8 @@
|
||||
namespace mlir {
|
||||
namespace shard {
|
||||
|
||||
void populateSimplificationPatterns(
|
||||
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
|
||||
populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
|
||||
patterns, ReductionKind::Sum);
|
||||
populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
|
||||
patterns, ReductionKind::Sum);
|
||||
|
||||
populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
|
||||
patterns, ReductionKind::Min);
|
||||
populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
|
||||
patterns, ReductionKind::Min);
|
||||
populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
|
||||
patterns, ReductionKind::Min);
|
||||
|
||||
populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
|
||||
patterns, ReductionKind::Max);
|
||||
populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
|
||||
patterns, ReductionKind::Max);
|
||||
populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
|
||||
patterns, ReductionKind::Max);
|
||||
|
||||
// TODO: add simplifications for all-gather and other collectives.
|
||||
|
||||
populateFoldingPatterns(patterns, symbolTableCollection);
|
||||
}
|
||||
#define GEN_PASS_DEF_SHARDSIMPLIFY
|
||||
#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
@ -109,12 +89,97 @@ struct GridShapeFolder
|
||||
}
|
||||
};
|
||||
|
||||
// Simplify AllSliceOp(AllReduceOp) -> ReduceScatterOp when both ops share the
|
||||
// same grid and grid_axes.
|
||||
//
|
||||
// AllReduceOp performs an element-wise reduction across all devices in the
|
||||
// group, and AllSliceOp then slices (scatters) the result along a tensor
|
||||
// dimension. This is exactly what ReduceScatterOp does in a single collective.
|
||||
//
|
||||
// With a ring algorithm over N ranks and M elements:
|
||||
// AllReduce: 2*(N-1) steps of M/N each => ~2M total data transferred
|
||||
// AllSlice: local slice, no communication
|
||||
// ReduceScatter: (N-1) steps of M/N each => ~M total data transferred
|
||||
// So this fusion roughly halves the communication volume.
|
||||
//
|
||||
// Memory-wise, AllReduce produces a full-sized M-element result that the
|
||||
// subsequent AllSlice must keep alive until the slice is taken. ReduceScatter
|
||||
// only materializes the M/N-element local slice, reducing peak memory by
|
||||
// a factor of N.
|
||||
struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(AllSliceOp sliceOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check if the input to AllSliceOp is produced by an AllReduceOp.
|
||||
auto reduceOp = sliceOp.getInput().getDefiningOp<AllReduceOp>();
|
||||
if (!reduceOp || !reduceOp->hasOneUse())
|
||||
return failure();
|
||||
|
||||
// Both ops must operate on the same grid and grid axes.
|
||||
if (reduceOp.getGrid() != sliceOp.getGrid() ||
|
||||
reduceOp.getGridAxes() != sliceOp.getGridAxes())
|
||||
return failure();
|
||||
|
||||
// Replace with a single ReduceScatterOp.
|
||||
rewriter.replaceOpWithNewOp<ReduceScatterOp>(
|
||||
sliceOp, sliceOp.getResult().getType(), sliceOp.getGridAttr(),
|
||||
sliceOp.getGridAxesAttr(), reduceOp.getInput(),
|
||||
reduceOp.getReductionAttr(), sliceOp.getSliceAxisAttr());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateSimplifyPatterns(RewritePatternSet &patterns,
|
||||
SymbolTableCollection &symbolTableCollection) {
|
||||
populateAllReduceEndomorphismSimplifyPatterns<arith::AddFOp>(
|
||||
patterns, ReductionKind::Sum);
|
||||
populateAllReduceEndomorphismSimplifyPatterns<arith::AddIOp>(
|
||||
patterns, ReductionKind::Sum);
|
||||
|
||||
populateAllReduceEndomorphismSimplifyPatterns<arith::MinimumFOp>(
|
||||
patterns, ReductionKind::Min);
|
||||
populateAllReduceEndomorphismSimplifyPatterns<arith::MinSIOp>(
|
||||
patterns, ReductionKind::Min);
|
||||
populateAllReduceEndomorphismSimplifyPatterns<arith::MinUIOp>(
|
||||
patterns, ReductionKind::Min);
|
||||
|
||||
populateAllReduceEndomorphismSimplifyPatterns<arith::MaximumFOp>(
|
||||
patterns, ReductionKind::Max);
|
||||
populateAllReduceEndomorphismSimplifyPatterns<arith::MaxSIOp>(
|
||||
patterns, ReductionKind::Max);
|
||||
populateAllReduceEndomorphismSimplifyPatterns<arith::MaxUIOp>(
|
||||
patterns, ReductionKind::Max);
|
||||
|
||||
patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
|
||||
|
||||
// TODO: add simplify patterns for all-gather and other collectives.
|
||||
|
||||
populateFoldingPatterns(patterns, symbolTableCollection);
|
||||
}
|
||||
|
||||
void populateFoldingPatterns(RewritePatternSet &patterns,
|
||||
SymbolTableCollection &symbolTableCollection) {
|
||||
patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct ShardSimplifyPass : public impl::ShardSimplifyBase<ShardSimplifyPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
SymbolTableCollection symbolTableCollection;
|
||||
populateSimplifyPatterns(patterns, symbolTableCollection);
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace shard
|
||||
} // namespace mlir
|
||||
@ -1,127 +1,151 @@
|
||||
// RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
|
||||
|
||||
// COM: Test MPICH ABI
|
||||
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
|
||||
// CHECK: llvm.func @MPI_Finalize() -> i32
|
||||
// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
|
||||
// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
|
||||
// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
|
||||
// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-LABEL: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
|
||||
// CHECK-DAG: llvm.func @MPI_Finalize() -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
|
||||
module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
|
||||
|
||||
// CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
|
||||
func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
|
||||
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
|
||||
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-LABEL: llvm.func @test_init_finalize_mpich
|
||||
func.func @test_init_finalize_mpich() {
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.zero : !llvm.ptr
|
||||
// CHECK: llvm.call @MPI_Init([[v0]], [[v0]]) : (!llvm.ptr, !llvm.ptr) -> i32
|
||||
%0 = mpi.init : !mpi.retval
|
||||
|
||||
// CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
|
||||
// CHECK: [[v8:%.*]] = llvm.trunc [[comm]] : i64 to i32
|
||||
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32
|
||||
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
|
||||
|
||||
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
|
||||
// CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
|
||||
// CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
|
||||
// CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
|
||||
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
|
||||
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
|
||||
|
||||
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
|
||||
// CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
|
||||
// CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
|
||||
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
|
||||
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
|
||||
|
||||
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
|
||||
// CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
|
||||
// CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
|
||||
// CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[comm_3]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
|
||||
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
|
||||
|
||||
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
|
||||
// CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
|
||||
// CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32
|
||||
// CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
|
||||
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
|
||||
|
||||
// CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
|
||||
%color = arith.constant 10 : i32
|
||||
// CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32
|
||||
%key = arith.constant 22 : i32
|
||||
// CHECK: [[v53:%.*]] = llvm.trunc [[comm]] : i64 to i32
|
||||
// CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (i32, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
|
||||
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
|
||||
|
||||
// CHECK: llvm.call @MPI_Comm_size
|
||||
// CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
|
||||
%err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
|
||||
|
||||
// CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v62:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v63a:%.*]] = llvm.trunc [[v62]] : i64 to i32
|
||||
// CHECK: [[v63:%.*]] = llvm.mul [[v63a]]
|
||||
// CHECK: [[v64:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v65:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v68a:%.*]] = llvm.trunc [[v67]] : i64 to i32
|
||||
// CHECK: [[v68:%.*]] = llvm.mul [[v68a]]
|
||||
// CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64
|
||||
// CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
|
||||
// CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32
|
||||
// CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
|
||||
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
|
||||
|
||||
// CHECK: llvm.call @MPI_Finalize() : () -> i32
|
||||
%3 = mpi.finalize : !mpi.retval
|
||||
%1 = mpi.finalize : !mpi.retval
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_comm_rank_mpich
|
||||
func.func @test_comm_rank_mpich() {
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[v1:%.*]] = llvm.trunc [[v0]] : i64 to i32
|
||||
// CHECK: [[v2:%.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[v3:%.*]] = llvm.alloca [[v2]] x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: llvm.call @MPI_Comm_rank([[v1]], [[v3]]) : (i32, !llvm.ptr) -> i32
|
||||
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_send_mpich
|
||||
func.func @test_send_mpich(%arg0: memref<100xf32>) {
|
||||
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
|
||||
// CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: llvm.call @MPI_Comm_rank
|
||||
// CHECK: [[v2:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
|
||||
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
|
||||
// COM: Test send without retval
|
||||
// CHECK: [[v3:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v5:%.*]] = llvm.getelementptr [[v3]][[[v4]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v6:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v7:%.*]] = llvm.trunc [[v6]] : i64 to i32
|
||||
// CHECK: [[v8:%.*]] = llvm.mul [[v7]]
|
||||
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[v10:%.*]] = llvm.trunc [[v1]] : i64 to i32
|
||||
// CHECK: = llvm.call @MPI_Send([[v5]], [[v8]], [[v9]], [[v2]], [[v2]], [[v10]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
|
||||
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
|
||||
// COM: Test send with retval
|
||||
// CHECK: = llvm.call @MPI_Send({{.*}}) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
|
||||
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_recv_mpich
|
||||
func.func @test_recv_mpich(%arg0: memref<100xf32>) {
|
||||
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
|
||||
// CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: llvm.call @MPI_Comm_rank
|
||||
// CHECK: [[v2:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
|
||||
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
|
||||
// COM: Test recv without retval
|
||||
// CHECK: [[v3:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v5:%.*]] = llvm.getelementptr [[v3]][[[v4]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v6:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v7:%.*]] = llvm.trunc [[v6]] : i64 to i32
|
||||
// CHECK: [[v8:%.*]] = llvm.mul [[v7]]
|
||||
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[v10:%.*]] = llvm.trunc [[v1]] : i64 to i32
|
||||
// CHECK: [[v11:%.*]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: [[v12:%.*]] = llvm.inttoptr [[v11]] : i64 to !llvm.ptr
|
||||
// CHECK: = llvm.call @MPI_Recv([[v5]], [[v8]], [[v9]], [[v2]], [[v2]], [[v10]], [[v12]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
|
||||
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
|
||||
// COM: Test recv with retval
|
||||
// CHECK: = llvm.call @MPI_Recv({{.*}}) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
|
||||
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_comm_split_mpich
|
||||
func.func @test_comm_split_mpich() {
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[v1:%.*]] = llvm.mlir.constant(10 : i32) : i32
|
||||
%color = arith.constant 10 : i32
|
||||
// CHECK: [[v2:%.*]] = llvm.mlir.constant(22 : i32) : i32
|
||||
%key = arith.constant 22 : i32
|
||||
// CHECK: [[v3:%.*]] = llvm.trunc [[v0]] : i64 to i32
|
||||
// CHECK: [[v4:%.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[v5:%.*]] = llvm.alloca [[v4]] x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: llvm.call @MPI_Comm_split([[v3]], [[v1]], [[v2]], [[v5]]) : (i32, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.load [[v5]] : !llvm.ptr -> i32
|
||||
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_allgather_mpich
|
||||
func.func @test_allgather_mpich(%arg0: memref<100xf32>) {
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: llvm.call @MPI_Comm_size
|
||||
// CHECK: llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
|
||||
%err = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_allreduce_mpich
|
||||
func.func @test_allreduce_mpich(%arg0: memref<100xf32>) {
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
|
||||
// CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
|
||||
// CHECK: [[v2:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v3:%.*]] = llvm.mul
|
||||
// CHECK: [[v4:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v5:%.*]] = llvm.mlir.constant(-1 : i64) : i64
|
||||
// CHECK: [[v6:%.*]] = llvm.inttoptr [[v5]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v7:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
|
||||
// CHECK: [[v8:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
|
||||
// CHECK: [[v9:%.*]] = llvm.trunc [[v1]] : i64 to i32
|
||||
// CHECK: llvm.call @MPI_Allreduce([[v6]], [[v4]], [[v3]], [[v7]], [[v8]], [[v9]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
|
||||
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_reduce_scatter_block_mpich
|
||||
func.func @test_reduce_scatter_block_mpich(%arg0: memref<100xf32>) {
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
|
||||
// CHECK: llvm.mul
|
||||
// CHECK: [[v1:%.*]] = llvm.mlir.constant(1 : index) : i32
|
||||
// CHECK: [[v2:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v3:%.*]] = llvm.trunc [[v2]] : i64 to i32
|
||||
// CHECK: [[v4:%.*]] = llvm.mul [[v3]], [[v1]] : i32
|
||||
// CHECK: [[v5:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
|
||||
// CHECK: llvm.cond_br {{.*}}, ^[[bb1:.*]], ^{{.*}}
|
||||
// CHECK: ^[[bb1]]:
|
||||
// CHECK: llvm.call @MPI_Reduce_scatter_block([[v5]], {{.*}}, [[v4]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
|
||||
mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -129,131 +153,160 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
|
||||
// -----
|
||||
|
||||
// COM: Test OpenMPI ABI
|
||||
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
|
||||
// CHECK: llvm.func @MPI_Finalize() -> i32
|
||||
// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
|
||||
// CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
|
||||
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
|
||||
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-LABEL: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
|
||||
// CHECK-DAG: llvm.func @MPI_Finalize() -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
|
||||
// CHECK-DAG: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
|
||||
// CHECK-DAG: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-DAG: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
|
||||
// CHECK-DAG: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
|
||||
module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
|
||||
|
||||
// CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
|
||||
func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {
|
||||
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
|
||||
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-LABEL: llvm.func @test_init_finalize_openmpi
|
||||
func.func @test_init_finalize_openmpi() {
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.zero : !llvm.ptr
|
||||
// CHECK: llvm.call @MPI_Init([[v0]], [[v0]]) : (!llvm.ptr, !llvm.ptr) -> i32
|
||||
%0 = mpi.init : !mpi.retval
|
||||
// CHECK: llvm.call @MPI_Finalize() : () -> i32
|
||||
%1 = mpi.finalize : !mpi.retval
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_comm_rank_openmpi
|
||||
func.func @test_comm_rank_openmpi() {
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[comm:%.*]] = llvm.ptrtoint [[v8]] : !llvm.ptr to i64
|
||||
// CHECK: [[comm_1:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[comm_1]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[v1:%.*]] = llvm.ptrtoint [[v0]] : !llvm.ptr to i64
|
||||
// CHECK: [[v2:%.*]] = llvm.inttoptr [[v1]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v3:%.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[v4:%.*]] = llvm.alloca [[v3]] x i32 : (i32) -> !llvm.ptr
|
||||
// CHECK: llvm.call @MPI_Comm_rank([[v2]], [[v4]]) : (!llvm.ptr, !llvm.ptr) -> i32
|
||||
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
|
||||
// CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
|
||||
// CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
|
||||
// CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK-LABEL: llvm.func @test_send_openmpi
|
||||
func.func @test_send_openmpi(%arg0: memref<100xf32>) {
|
||||
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64
|
||||
// CHECK: llvm.call @MPI_Comm_rank
|
||||
// CHECK: [[v3:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
|
||||
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
|
||||
// COM: Test send without retval
|
||||
// CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v5:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v6:%.*]] = llvm.getelementptr [[v4]][[[v5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v7:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v8:%.*]] = llvm.trunc [[v7]] : i64 to i32
|
||||
// CHECK: [[v9:%.*]] = llvm.mul [[v8]]
|
||||
// CHECK: [[v10:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v11:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr
|
||||
// CHECK: = llvm.call @MPI_Send([[v6]], [[v9]], [[v10]], [[v3]], [[v3]], [[v11]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
|
||||
|
||||
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
|
||||
// CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
|
||||
// CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
// COM: Test send with retval
|
||||
// CHECK: = llvm.call @MPI_Send({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
|
||||
// CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
|
||||
// CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
|
||||
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-LABEL: llvm.func @test_recv_openmpi
|
||||
func.func @test_recv_openmpi(%arg0: memref<100xf32>) {
|
||||
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64
|
||||
// CHECK: llvm.call @MPI_Comm_rank
|
||||
// CHECK: [[v3:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
|
||||
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
|
||||
// COM: Test recv without retval
|
||||
// CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v5:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v6:%.*]] = llvm.getelementptr [[v4]][[[v5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v7:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v8:%.*]] = llvm.trunc [[v7]] : i64 to i32
|
||||
// CHECK: [[v9:%.*]] = llvm.mul [[v8]]
|
||||
// CHECK: [[v10:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v11:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v12:%.*]] = llvm.mlir.constant(0 : i64) : i64
|
||||
// CHECK: [[v13:%.*]] = llvm.inttoptr [[v12]] : i64 to !llvm.ptr
|
||||
// CHECK: = llvm.call @MPI_Recv([[v6]], [[v9]], [[v10]], [[v3]], [[v3]], [[v11]], [[v13]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
|
||||
|
||||
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
|
||||
// CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
|
||||
// CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
|
||||
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// COM: Test recv with retval
|
||||
// CHECK: = llvm.call @MPI_Recv({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_comm_split_openmpi
|
||||
func.func @test_comm_split_openmpi() {
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[v1:%.*]] = llvm.ptrtoint [[v0]] : !llvm.ptr to i64
|
||||
// CHECK: [[v2:%.*]] = llvm.mlir.constant(10 : i32) : i32
|
||||
%color = arith.constant 10 : i32
|
||||
// CHECK: [[v3:%.*]] = llvm.mlir.constant(22 : i32) : i32
|
||||
%key = arith.constant 22 : i32
|
||||
// CHECK: [[v4:%.*]] = llvm.inttoptr [[v1]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v5:%.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[v6:%.*]] = llvm.alloca [[v5]] x !llvm.ptr : (i32) -> !llvm.ptr
|
||||
// CHECK: llvm.call @MPI_Comm_split([[v4]], [[v2]], [[v3]], [[v6]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.load [[v6]] : !llvm.ptr -> i32
|
||||
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_allgather_openmpi
|
||||
func.func @test_allgather_openmpi(%arg0: memref<100xf32>) {
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: llvm.call @MPI_Comm_size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.udiv {{.*}} : i32
|
||||
// CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK: llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
|
||||
mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v53a:%.*]] = llvm.trunc [[v52]] : i64 to i32
|
||||
// CHECK: [[v53:%.*]] = llvm.mul [[v53a]]
|
||||
// CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v58a:%.*]] = llvm.trunc [[v57]] : i64 to i32
|
||||
// CHECK: [[v58:%.*]] = llvm.mul [[v58a]]
|
||||
// CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
|
||||
// CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
|
||||
// CHECK-LABEL: llvm.func @test_allreduce_openmpi
|
||||
func.func @test_allreduce_openmpi(%arg0: memref<100xf32>) {
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
|
||||
// CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
|
||||
// CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64
|
||||
// CHECK: [[v3:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v4:%.*]] = llvm.mul
|
||||
// CHECK: [[v5:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
|
||||
// CHECK: [[v6:%.*]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: [[v7:%.*]] = llvm.inttoptr [[v6]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
|
||||
// CHECK: [[v9:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
|
||||
// CHECK: [[v10:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr
|
||||
// CHECK: llvm.call @MPI_Allreduce([[v7]], [[v5]], [[v4]], [[v8]], [[v9]], [[v10]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
|
||||
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
|
||||
%color = arith.constant 10 : i32
|
||||
// CHECK: [[v72:%.*]] = llvm.mlir.constant(22 : i32) : i32
|
||||
%key = arith.constant 22 : i32
|
||||
// CHECK: [[v73:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
|
||||
// CHECK: [[v74:%.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[v75:%.*]] = llvm.alloca [[v74]] x !llvm.ptr : (i32) -> !llvm.ptr
|
||||
// CHECK: [[v76:%.*]] = llvm.call @MPI_Comm_split([[v73]], [[v71]], [[v72]], [[v75]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
|
||||
// CHECK: [[v77:%.*]] = llvm.load [[v75]] : !llvm.ptr -> i32
|
||||
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
|
||||
|
||||
// CHECK: llvm.call @MPI_Finalize() : () -> i32
|
||||
%3 = mpi.finalize : !mpi.retval
|
||||
|
||||
// CHECK-LABEL: llvm.func @test_reduce_scatter_block_openmpi
|
||||
func.func @test_reduce_scatter_block_openmpi(%arg0: memref<100xf32>) {
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
|
||||
// CHECK: llvm.mul
|
||||
// CHECK: [[v1:%.*]] = llvm.mlir.constant(1 : index) : i32
|
||||
// CHECK: [[v2:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[v3:%.*]] = llvm.trunc [[v2]] : i64 to i32
|
||||
// CHECK: [[v4:%.*]] = llvm.mul [[v3]], [[v1]] : i32
|
||||
// CHECK: [[v5:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
|
||||
// CHECK: llvm.cond_br {{.*}}, ^[[bb1:.*]], ^{{.*}}
|
||||
// CHECK: ^[[bb1]]:
|
||||
// CHECK: llvm.call @MPI_Reduce_scatter_block([[v5]], {{.*}}, [[v4]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
|
||||
mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -261,13 +314,13 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
|
||||
// -----
|
||||
|
||||
module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:comm_world_size" = 4, "MPI:comm_world_rank" = 1> } {
|
||||
// CHECK: llvm.func @mpi_test_fold
|
||||
func.func @mpi_test_fold(%arg0: memref<100xf32>) {
|
||||
// CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
|
||||
// CHECK-LABEL: llvm.func @test_fold
|
||||
func.func @test_fold(%arg0: memref<100xf32>) {
|
||||
// CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
|
||||
%comm = mpi.comm_world : !mpi.comm
|
||||
|
||||
// CHECK-NOT: llvm.call @MPI_Comm_size
|
||||
// CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
|
||||
// CHECK: llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
|
||||
%err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
|
||||
return
|
||||
}
|
||||
|
||||
@ -159,6 +159,40 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
|
||||
return %0 : memref<3x4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @reduce_scatter_memref(
|
||||
func.func @reduce_scatter_memref(
|
||||
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
|
||||
%arg0 : memref<3x4xf32>) -> memref<1x4xf32> {
|
||||
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 0 : i32
|
||||
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 7 : i32
|
||||
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<1x4xf32>
|
||||
// CHECK: mpi.reduce_scatter_block([[varg0]], [[valloc]], MPI_SUM, [[vnewcomm]]) : memref<3x4xf32>, memref<1x4xf32>
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0 : memref<3x4xf32> -> memref<1x4xf32>
|
||||
// CHECK: return [[valloc]] : memref<1x4xf32>
|
||||
return %0 : memref<1x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @reduce_scatter_tensor_dim1(
|
||||
func.func @reduce_scatter_tensor_dim1(
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<2x12xf32>
|
||||
%arg0 : tensor<2x12xf32>) -> tensor<2x4xf32> {
|
||||
// CHECK: [[vexpanded:%.*]] = tensor.expand_shape [[varg0]] {{\[\[}}0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xf32> into tensor<2x3x4xf32>
|
||||
// CHECK: [[vempty:%.*]] = tensor.empty() : tensor<3x2x4xf32>
|
||||
// CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[vexpanded]] : tensor<2x3x4xf32>) outs([[vempty]] : tensor<3x2x4xf32>) permutation = [1, 0, 2]
|
||||
// CHECK: [[vtobuf:%.*]] = bufferization.to_buffer [[vtransposed]] : tensor<3x2x4xf32> to memref<3x2x4xf32>
|
||||
// CHECK: [[valloctmp:%.*]] = memref.alloc() : memref<3x2x4xf32>
|
||||
// CHECK: linalg.copy ins([[vtobuf]] : memref<3x2x4xf32>) outs([[valloctmp]] : memref<3x2x4xf32>)
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x4xf32>
|
||||
// CHECK: mpi.reduce_scatter_block([[valloctmp]], [[valloc]], MPI_SUM,
|
||||
// CHECK: memref.dealloc [[valloctmp]] : memref<3x2x4xf32>
|
||||
// CHECK: [[vout:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<2x4xf32> to tensor<2x4xf32>
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 1 : tensor<2x12xf32> -> tensor<2x4xf32>
|
||||
// CHECK: return [[vout]] : tensor<2x4xf32>
|
||||
return %0 : tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @allgather_tensor_0
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
|
||||
func.func @allgather_tensor_0(%arg0 : tensor<3x4xf32>) -> tensor<12x4xf32> {
|
||||
|
||||
@ -77,6 +77,12 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
|
||||
// CHECK-NEXT: mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
|
||||
mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
|
||||
|
||||
// CHECK-NEXT: [[v10:%.*]] = mpi.reduce_scatter_block([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
|
||||
%err9 = mpi.reduce_scatter_block(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
|
||||
|
||||
// CHECK-NEXT: mpi.reduce_scatter_block([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
|
||||
mpi.reduce_scatter_block(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
|
||||
|
||||
// CHECK-NEXT: [[v7:%.*]] = mpi.finalize : !mpi.retval
|
||||
%rval = mpi.finalize : !mpi.retval
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --test-grid-simplifications --cse %s | FileCheck %s
|
||||
// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --shard-simplify --cse %s | FileCheck %s
|
||||
|
||||
shard.grid @grid_1d(shape = ?)
|
||||
|
||||
@ -59,15 +59,15 @@ func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid(
|
||||
// CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = shard.process_multi_index on @grid_4d axes = [3, 1] : index, index
|
||||
// CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = shard.grid_shape @grid_4d axes = [3, 1] : index, index
|
||||
// CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
|
||||
// CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
|
||||
// CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
|
||||
// CHECK: %[[scatter_dim_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
|
||||
// CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[scatter_dim_SIZE]], %[[PROC_GROUP_SIZE]] : index
|
||||
// CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index
|
||||
// CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
|
||||
// CHECK: %[[RESULT_SCATTER_AXIS_SIZE:.*]] = arith.divui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
|
||||
// CHECK: %[[RESULT_scatter_dim_SIZE:.*]] = arith.divui %[[scatter_dim_SIZE]], %[[PROC_GROUP_SIZE]] : index
|
||||
// CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #map()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1]
|
||||
// CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
|
||||
// CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
|
||||
// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
|
||||
// CHECK: %[[scatter_dim_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_scatter_dim_SIZE]] : index
|
||||
// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[scatter_dim_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_scatter_dim_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
|
||||
%0 = shard.all_slice %arg0 on @grid_4d grid_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
|
||||
// CHECK: return %[[RESULT]] : tensor<?x?xf16>
|
||||
return %0 : tensor<?x?xf16>
|
||||
|
||||
@ -135,7 +135,7 @@ func.func @reduce_scatter_empty_grid_axes(
|
||||
// CHECK-NOT: shard.reduce_scatter
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0
|
||||
grid_axes = []
|
||||
scatter_axis = 0
|
||||
scatter_dim = 0
|
||||
: tensor<4xf32> -> tensor<4xf32>
|
||||
// CHECK: return %[[ARG]]
|
||||
return %0 : tensor<4xf32>
|
||||
@ -148,7 +148,7 @@ func.func @reduce_scatter_empty_grid_axes_different_return_type(
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0
|
||||
// CHECK-NOT: grid_axes
|
||||
grid_axes = []
|
||||
scatter_axis = 0
|
||||
scatter_dim = 0
|
||||
: tensor<4xf32> -> tensor<4xf64>
|
||||
return %0 : tensor<4xf64>
|
||||
}
|
||||
@ -160,7 +160,7 @@ func.func @reduce_scatter_default_reduction(
|
||||
grid_axes = [0]
|
||||
// CHECK-NOT: reduction
|
||||
reduction = sum
|
||||
scatter_axis = 0
|
||||
scatter_dim = 0
|
||||
: tensor<4xf32> -> tensor<2xf64>
|
||||
return %0 : tensor<2xf64>
|
||||
}
|
||||
@ -172,7 +172,7 @@ func.func @scatter_empty_grid_axes(
|
||||
// CHECK-NOT: shard.scatter
|
||||
%0 = shard.scatter %arg0 on @grid0
|
||||
grid_axes = []
|
||||
scatter_axis = 0
|
||||
scatter_dim = 0
|
||||
root = []
|
||||
: (tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: return %[[ARG]]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
|
||||
// RUN: mlir-opt -shard-simplify %s | FileCheck %s
|
||||
|
||||
shard.grid @grid0(shape = 4x?x2)
|
||||
shard.grid @grid1(shape = 2x3)
|
||||
|
||||
@ -707,7 +707,7 @@ shard.grid @grid0(shape = 3)
|
||||
func.func @reduce_scatter_duplicate_grid_axis(
|
||||
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
|
||||
// expected-error@+1 {{Grid axes contains duplicate elements.}}
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_dim = 0
|
||||
: tensor<?xf32> -> tensor<?xf64>
|
||||
return %0 : tensor<?xf64>
|
||||
}
|
||||
@ -719,7 +719,7 @@ shard.grid @grid0(shape = 3)
|
||||
func.func @reduce_scatter_invalid_dynamic_dimension(
|
||||
%arg0 : tensor<?xf32>) -> tensor<2xf64> {
|
||||
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 scatter_axis = 0
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 scatter_dim = 0
|
||||
: tensor<?xf32> -> tensor<2xf64>
|
||||
return %0 : tensor<2xf64>
|
||||
}
|
||||
@ -731,7 +731,7 @@ shard.grid @grid0(shape = 3)
|
||||
func.func @reduce_scatter_invalid_static_dimension_size(
|
||||
%arg0 : tensor<3xf32>) -> tensor<2xf64> {
|
||||
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
|
||||
: tensor<3xf32> -> tensor<2xf64>
|
||||
return %0 : tensor<2xf64>
|
||||
}
|
||||
@ -743,7 +743,7 @@ shard.grid @grid0(shape = 3)
|
||||
func.func @reduce_scatter_invalid_operand_static_dimension_size(
|
||||
%arg0 : tensor<4xf32>) -> tensor<?xf64> {
|
||||
// expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
|
||||
: tensor<4xf32> -> tensor<?xf64>
|
||||
return %0 : tensor<?xf64>
|
||||
}
|
||||
@ -756,7 +756,7 @@ func.func @scatter_duplicate_grid_axis(
|
||||
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
|
||||
// expected-error@+1 {{Grid axes contains duplicate elements.}}
|
||||
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 0]
|
||||
scatter_axis = 0 root = [0, 0]
|
||||
scatter_dim = 0 root = [0, 0]
|
||||
: (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -769,7 +769,7 @@ func.func @scatter_invalid_dynamic_dimension(
|
||||
%arg0 : tensor<?xf32>) -> tensor<2xf32> {
|
||||
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
|
||||
%0 = shard.scatter %arg0 on @grid0
|
||||
scatter_axis = 0 root = []
|
||||
scatter_dim = 0 root = []
|
||||
: (tensor<?xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -782,7 +782,7 @@ func.func @scatter_invalid_static_dimension_size(
|
||||
%arg0 : tensor<3xf32>) -> tensor<2xf32> {
|
||||
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
|
||||
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
|
||||
scatter_axis = 0 root = [1]
|
||||
scatter_dim = 0 root = [1]
|
||||
: (tensor<3xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -795,7 +795,7 @@ func.func @scatter_invalid_operand_static_dimension_size(
|
||||
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
|
||||
// expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
|
||||
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
|
||||
scatter_axis = 0 root = [1]
|
||||
scatter_dim = 0 root = [1]
|
||||
: (tensor<4xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -808,7 +808,7 @@ func.func @scatter_root_dimension_out_of_bounds(
|
||||
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
|
||||
// expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
|
||||
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
|
||||
scatter_axis = 0 root = [3]
|
||||
scatter_dim = 0 root = [3]
|
||||
: (tensor<3xi8>) -> tensor<1xi8>
|
||||
return %0 : tensor<1xi8>
|
||||
}
|
||||
@ -821,7 +821,7 @@ func.func @scatter_root_wrong_number_dimensions(
|
||||
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
|
||||
// expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
|
||||
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
|
||||
scatter_axis = 0 root = [2, 2]
|
||||
scatter_dim = 0 root = [2, 2]
|
||||
: (tensor<3xi8>) -> tensor<1xi8>
|
||||
return %0 : tensor<1xi8>
|
||||
}
|
||||
|
||||
@ -453,10 +453,10 @@ func.func @reduce_scatter_static_dimensions(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
|
||||
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
|
||||
// CHECK-NEXT: shard.reduce_scatter %[[ARG]]
|
||||
// CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_axis = 1
|
||||
// CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_dim = 1
|
||||
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
|
||||
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [2]
|
||||
reduction = max scatter_axis = 1
|
||||
reduction = max scatter_dim = 1
|
||||
: tensor<3x4xf32> -> tensor<3x1xf64>
|
||||
return %0 : tensor<3x1xf64>
|
||||
}
|
||||
@ -466,9 +466,9 @@ func.func @reduce_scatter_dynamic_dimensions(
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
|
||||
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
|
||||
// CHECK-NEXT: shard.reduce_scatter %[[ARG]]
|
||||
// CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_axis = 0
|
||||
// CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_dim = 0
|
||||
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
|
||||
%0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0
|
||||
%0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_dim = 0
|
||||
: tensor<?xf32> -> tensor<?xf64>
|
||||
return %0 : tensor<?xf64>
|
||||
}
|
||||
@ -479,10 +479,10 @@ func.func @scatter_static_dimensions(
|
||||
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
|
||||
// CHECK-NEXT: shard.scatter %[[ARG]]
|
||||
// CHECK-SAME: on @grid0 grid_axes = [2]
|
||||
// CHECK-SAME: scatter_axis = 1 root = [1]
|
||||
// CHECK-SAME: scatter_dim = 1 root = [1]
|
||||
// CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
|
||||
%0 = shard.scatter %arg0 on @grid0 grid_axes = [2]
|
||||
scatter_axis = 1 root = [1]
|
||||
scatter_dim = 1 root = [1]
|
||||
: (tensor<3x4xf32>) -> tensor<3x1xf32>
|
||||
return %0 : tensor<3x1xf32>
|
||||
}
|
||||
@ -493,10 +493,10 @@ func.func @scatter_dynamic_dimensions(
|
||||
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK-NEXT: shard.scatter %[[ARG]]
|
||||
// CHECK-SAME: on @grid3 grid_axes = [0, 1]
|
||||
// CHECK-SAME: scatter_axis = 0 root = [1, 2]
|
||||
// CHECK-SAME: scatter_dim = 0 root = [1, 2]
|
||||
// CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = shard.scatter %arg0 on @grid3 grid_axes = [0, 1]
|
||||
scatter_axis = 0 root = [1, 2]
|
||||
scatter_dim = 0 root = [1, 2]
|
||||
: (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -510,11 +510,11 @@ func.func @scatter_dynamic_root(
|
||||
) -> tensor<1xi8> {
|
||||
// CHECK-NEXT: shard.scatter %[[ARG0]]
|
||||
// CHECK-SAME: on @grid0 grid_axes = [0, 2]
|
||||
// CHECK-SAME: scatter_axis = 0
|
||||
// CHECK-SAME: scatter_dim = 0
|
||||
// CHECK-SAME: root = [1, %[[ARG1]]]
|
||||
// CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
|
||||
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 2]
|
||||
scatter_axis = 0
|
||||
scatter_dim = 0
|
||||
root = [1, %arg1]
|
||||
: (tensor<8xi8>, index) -> tensor<1xi8>
|
||||
return %0 : tensor<1xi8>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
|
||||
// RUN: mlir-opt -shard-simplify %s | FileCheck %s
|
||||
|
||||
shard.grid @grid0(shape = 4x2)
|
||||
shard.grid @grid1(shape = 4)
|
||||
@ -177,3 +177,86 @@ func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 {
|
||||
%0 = arith.maxsi %extracted, %c1_i64 : i64
|
||||
return %0 : i64
|
||||
}
|
||||
|
||||
// -----
|
||||
// AllReduceOp + AllSliceOp -> ReduceScatterOp tests
|
||||
// -----
|
||||
|
||||
// Basic case: all_slice(all_reduce(x)) with matching grid and axes folds
|
||||
// into reduce_scatter.
|
||||
// CHECK-LABEL: func.func @all_reduce_all_slice_to_reduce_scatter
|
||||
func.func @all_reduce_all_slice_to_reduce_scatter(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
|
||||
%arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
|
||||
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
|
||||
%1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
|
||||
// CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0
|
||||
// CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf32>
|
||||
// CHECK: return %[[RS]]
|
||||
return %1 : tensor<1x8xf32>
|
||||
}
|
||||
|
||||
// Verify non-default reduction kind is preserved.
|
||||
// CHECK-LABEL: func.func @all_reduce_all_slice_max_reduction
|
||||
func.func @all_reduce_all_slice_max_reduction(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
|
||||
%arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
|
||||
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max : tensor<4x8xf32> -> tensor<4x8xf32>
|
||||
%1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
|
||||
// CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] reduction = max scatter_dim = 0
|
||||
// CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf32>
|
||||
// CHECK: return %[[RS]]
|
||||
return %1 : tensor<1x8xf32>
|
||||
}
|
||||
|
||||
// Slice on a different tensor axis than the reduce axes.
|
||||
// CHECK-LABEL: func.func @all_reduce_all_slice_different_slice_axis
|
||||
func.func @all_reduce_all_slice_different_slice_axis(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
|
||||
%arg0: tensor<4x8xf32>) -> tensor<4x4xf32> {
|
||||
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [1] : tensor<4x8xf32> -> tensor<4x8xf32>
|
||||
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<4x8xf32> -> tensor<4x4xf32>
|
||||
// CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [1] scatter_dim = 1
|
||||
// CHECK-SAME: : tensor<4x8xf32> -> tensor<4x4xf32>
|
||||
// CHECK: return %[[RS]]
|
||||
return %1 : tensor<4x4xf32>
|
||||
}
|
||||
|
||||
// Do not fold when grids differ.
|
||||
// CHECK-LABEL: func.func @all_reduce_all_slice_different_grid
|
||||
func.func @all_reduce_all_slice_different_grid(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
|
||||
%arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
|
||||
// CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0
|
||||
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
|
||||
// CHECK: %[[AS:.*]] = shard.all_slice %[[AR]] on @grid1
|
||||
%1 = shard.all_slice %0 on @grid1 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
|
||||
// CHECK: return %[[AS]]
|
||||
return %1 : tensor<1x8xf32>
|
||||
}
|
||||
|
||||
// Do not fold when grid_axes differ.
|
||||
// CHECK-LABEL: func.func @all_reduce_all_slice_different_grid_axes
|
||||
func.func @all_reduce_all_slice_different_grid_axes(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
|
||||
%arg0: tensor<4x8xf32>) -> tensor<4x4xf32> {
|
||||
// CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
|
||||
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
|
||||
// CHECK: %[[AS:.*]] = shard.all_slice %[[AR]] on @grid0 grid_axes = [1]
|
||||
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<4x8xf32> -> tensor<4x4xf32>
|
||||
// CHECK: return %[[AS]]
|
||||
return %1 : tensor<4x4xf32>
|
||||
}
|
||||
|
||||
// Verify element type conversion is preserved (all_reduce input/output types may differ).
|
||||
// CHECK-LABEL: func.func @all_reduce_all_slice_type_promotion
|
||||
func.func @all_reduce_all_slice_type_promotion(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
|
||||
%arg0: tensor<4x8xf32>) -> tensor<1x8xf64> {
|
||||
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf64>
|
||||
%1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf64> -> tensor<1x8xf64>
|
||||
// CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0
|
||||
// CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf64>
|
||||
// CHECK: return %[[RS]]
|
||||
return %1 : tensor<1x8xf64>
|
||||
}
|
||||
@ -2,7 +2,6 @@
|
||||
add_mlir_library(MLIRShardTest
|
||||
TestOpLowering.cpp
|
||||
TestReshardingPartition.cpp
|
||||
TestSimplifications.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
//===- TestSimplification.cpp - Test simplification -----------------------===//
|
||||
//===- TestReshardingPartition.cpp - Test resharding partition ------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
||||
@ -1,47 +0,0 @@
|
||||
//===- TestSimplification.cpp - Test simplification -----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
|
||||
#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct TestShardSimplificationsPass
|
||||
: public PassWrapper<TestShardSimplificationsPass, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShardSimplificationsPass)
|
||||
|
||||
void runOnOperation() override;
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<arith::ArithDialect, shard::ShardDialect>();
|
||||
}
|
||||
StringRef getArgument() const final { return "test-grid-simplifications"; }
|
||||
StringRef getDescription() const final { return "Test grid simplifications"; }
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void TestShardSimplificationsPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
SymbolTableCollection symbolTableCollection;
|
||||
shard::populateSimplificationPatterns(patterns, symbolTableCollection);
|
||||
[[maybe_unused]] LogicalResult status =
|
||||
applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
assert(succeeded(status) && "Rewrite patters application did not converge.");
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestShardSimplificationsPass() {
|
||||
PassRegistration<TestShardSimplificationsPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
@ -131,7 +131,6 @@ void registerTestMemRefDependenceCheck();
|
||||
void registerTestMemRefStrideCalculation();
|
||||
void registerTestMemRefToLLVMWithTransforms();
|
||||
void registerTestReshardingPartitionPass();
|
||||
void registerTestShardSimplificationsPass();
|
||||
void registerTestMultiBuffering();
|
||||
void registerTestNextAccessPass();
|
||||
void registerTestNVGPULowerings();
|
||||
@ -281,7 +280,6 @@ static void registerTestPasses() {
|
||||
mlir::test::registerTestMemRefStrideCalculation();
|
||||
mlir::test::registerTestMemRefToLLVMWithTransforms();
|
||||
mlir::test::registerTestReshardingPartitionPass();
|
||||
mlir::test::registerTestShardSimplificationsPass();
|
||||
mlir::test::registerTestMultiBuffering();
|
||||
mlir::test::registerTestNextAccessPass();
|
||||
mlir::test::registerTestNVGPULowerings();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user