[mlir][shard, mpi] Adding Shard/MPI reduce_scatter and simplification (#184189)

- introduces a simplify pass, which finds such patterns and replaces it
with the equivalent `reduce-scatter`
- promotes the test-pass `test-shard-optimizations` to a proper pass and adds
  - folding allgather+allslice into reduce_scatter
- sanitizes the `shard.reduce_scatter` op
- adds a new `mpi.reduce_scatter_block` op
- lowers `shard.reduce_scatter` to MPI
- lowers `mpi-reduce_scatter_block` to llvm

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Frank Schlimbach 2026-03-03 17:34:36 +01:00 committed by GitHub
parent 5f8f1e2afe
commit a232b5b96f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 846 additions and 363 deletions

View File

@ -333,6 +333,41 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
"(`->` type($retval)^)?";
}
//===----------------------------------------------------------------------===//
// ReduceScatterBlockOp
//===----------------------------------------------------------------------===//
def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
let summary = "Equivalent to `MPI_Reduce_scatter_block(sendbuf, recvbuf, "
"recvcount, dtype, op, comm)`";
let description = [{
MPI_Reduce_scatter_block first performs an element-wise reduction on the
sendbuf across all processes in the communicator, then scatters the result
by distributing equal-sized blocks to each process into recvbuf.
The `op` attribute specifies the reduction operation to be performed.
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
supported.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
let arguments = (
ins AnyNon0RankedMemRef : $sendbuf,
AnyNon0RankedMemRef : $recvbuf,
MPI_ReductionOpEnum : $op,
MPI_Comm : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
"attr-dict `:` type($sendbuf) `,` type($recvbuf) "
"(`->` type($retval)^)?";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

View File

@ -901,13 +901,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
`grid_axes` using the specified reduction method. The reduction is performed
element-wise across the tensor pieces from all devices in the group.
After reduction, the reduction result is scattered (split and distributed)
across the device group along `scatter_axis`.
across the device group along `scatter_dim`.
Example:
```
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
reduction = <max> scatter_axis = 0
reduction = <max> scatter_dim = 0
: tensor<2x2xf32> -> tensor<1x2xf64>
```
Input:
@ -940,17 +940,17 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction,
IndexAttr:$scatter_axis
IndexAttr:$scatter_dim
));
let results = (outs
AnyRankedTensor:$result
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
);
let assemblyFormat = [{
$input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`reduction` `=` $reduction^)?
`scatter_axis` `=` $scatter_axis
`scatter_dim` `=` $scatter_dim
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
@ -964,7 +964,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
let summary = "Scatter over a device grid.";
let description = [{
For each device group defined by `grid_axes`, the input tensor on the `root`
device is split along axis `scatter_axis` and distributed across the group.
device is split along axis `scatter_dim` and distributed across the group.
The content of the input on all other (non-root) devices is ignored.
The `root` device is defined by its in-group multi-index.
@ -972,7 +972,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
```
shard.grid @grid0(shape = 2x2)
%1 = shard.scatter %0 on @grid0 grid_axes = [0]
scatter_axis = 0
scatter_dim = 0
root = [1]
: (tensor<2x2xi8>) -> tensor<1x2xi8>
```
@ -1011,7 +1011,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
IndexAttr:$scatter_axis,
IndexAttr:$scatter_dim,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
@ -1020,7 +1020,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
);
let assemblyFormat = [{
$input `on` $grid (`grid_axes` `=` $grid_axes^)?
`scatter_axis` `=` $scatter_axis
`scatter_dim` `=` $scatter_dim
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];

View File

@ -1,4 +1,4 @@
//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
//===- Partition.h - Shard Partition ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@ -44,6 +44,23 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
];
}
def ShardSimplify : Pass<"shard-simplify"> {
let summary = "Shard simplify patterns.";
let description = [{
Applies simplification patterns on the Shard dialect operations.
This includes:
- All-reduce endomorphism simplification, e.g. transforming
`all_reduce_sum(x) + all_reduce_sum(y)` into `all_reduce_sum(x + y)`.
- Folding `AllSliceOp(AllReduceOp)` into `ReduceScatterOp` when both ops
share the same grid and grid_axes.
- Folding static grid shapes into constants.
}];
let dependentDialects = [
"arith::ArithDialect",
"shard::ShardDialect"
];
}
def Partition : InterfacePass<"shard-partition", "mlir::FunctionOpInterface"> {
let summary = "Partition a function into SPMD form.";
let description = [{

View File

@ -1,4 +1,4 @@
//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
//===- Simplify.h - Shard Simplify ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/PatternMatch.h"
@ -37,8 +37,8 @@ namespace shard {
// Will not work with some op `f(x, y, z)` where only `x` and `y` form
// the algebraic structure.
template <typename AlgebraicOp>
void populateAllReduceEndomorphismSimplificationPatterns(
RewritePatternSet &patterns, ReductionKind reduction) {
void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns,
ReductionKind reduction) {
auto getEndomorphismOpOperand = [](Operation *op) {
auto allReduceOp = llvm::cast<AllReduceOp>(op);
return &allReduceOp.getInputMutable();
@ -105,12 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
// It is invalid to change ops that declare symbols during the application of
// these patterns, because symbolTableCollection is used to cache them.
void populateSimplificationPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
void populateSimplifyPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection);
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection);
} // namespace shard
} // namespace mlir
#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H

View File

@ -16,6 +16,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@ -911,6 +912,79 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
}
};
//===----------------------------------------------------------------------===//
// ReduceScatterBlockOpLowering
//===----------------------------------------------------------------------===//
struct ReduceScatterBlockOpLowering
: public ConvertOpToLLVMPattern<mpi::ReduceScatterBlockOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(mpi::ReduceScatterBlockOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
Type elemType = op.getSendbuf().getType().getElementType();
int64_t sRank = op.getSendbuf().getType().getRank();
int64_t rRank = op.getRecvbuf().getType().getRank();
// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(context);
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
auto [sendPtr, sendSize] =
getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
auto [recvPtr, recvSize] =
getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
// If input and output are the same, request in-place operation.
if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
sendPtr = LLVM::ConstantOp::create(
rewriter, loc, i64,
reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
}
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
Value nRanks =
createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
Value totalExpected =
LLVM::MulOp::create(rewriter, loc, i32, recvSize, nRanks);
Value sizeIsValid = LLVM::ICmpOp::create(
rewriter, loc, LLVM::ICmpPredicate::eq, sendSize, totalExpected);
cf::AssertOp::create(rewriter, loc, sizeIsValid,
"Send buffer's size must be the receive buffer's size "
"times the number of ranks");
// 'int MPI_Reduce_scatter_block(const void *sendbuf, void *recvbuf,
// int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
auto funcType = LLVM::LLVMFunctionType::get(
i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
comm.getType()});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(
moduleOp, loc, rewriter, "MPI_Reduce_scatter_block", funcType);
// replace op with function call
auto funcCall = LLVM::CallOp::create(
rewriter, loc, funcDecl,
ValueRange{sendPtr, recvPtr, recvSize, dataType, mpiOp, comm});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
rewriter.eraseOp(op);
return success();
}
};
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
@ -943,7 +1017,7 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
SendOpLowering, RecvOpLowering, AllGatherOpLowering,
AllReduceOpLowering>(converter);
AllReduceOpLowering, ReduceScatterBlockOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {

View File

@ -26,7 +26,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
#include "mlir/Dialect/Shard/Transforms/Simplify.h"
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@ -618,6 +618,156 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
}
};
struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
using CommOpPattern::CommOpPattern;
// shard.reduce_scatter reduces and then scatters along a specified
// scatter-dim. mpi.reduce_scatter_block always scatters along the first
// dimension. Hence, if scatter-dim != 0, we need to rearrange the input
// data by expanding the scatter-dim into {nRanks, output_scatter_dim} and
// transposing nRanks to the first dimension.
LogicalResult
matchAndRewrite(ReduceScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto gridAxes = adaptor.getGridAxes();
int64_t scatterDim = adaptor.getScatterDimAttr().getInt();
SymbolTableCollection symbolTableCollection;
FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
if (failed(gridOp))
return failure();
ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
Value rawInput = adaptor.getInput();
auto inShapedType = cast<ShapedType>(rawInput.getType());
MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
auto elemType = outType.getElementType();
auto inputShape = inShapedType.getShape();
auto outputShape = outType.getShape();
int64_t inputDimOnAxis = inputShape[scatterDim];
int64_t outputDimOnAxis = outputShape[scatterDim];
for (size_t i = 0; i < outputShape.size(); ++i)
if (outputShape[i] != inputShape[i] &&
i != static_cast<size_t>(scatterDim))
return op.emitError(
"Result and input shapes must match along non-scatter axes.");
if (outputDimOnAxis == 0)
return op.emitError(
"Output size along the scatter axis must be non-zero.");
if (inputDimOnAxis % outputDimOnAxis != 0)
return op.emitError(
"Input size along the scatter axis must be an exact "
"multiple of the output size along the scatter axis.");
if (!memref::isStaticShapeAndContiguousRowMajor(outType))
return op.emitError("Result must be a statically shaped memref in "
"contiguous row-major layout.");
int64_t nRanks = inputDimOnAxis / outputDimOnAxis;
// Verify that nRanks matches the number of devices along the grid axes.
int64_t gridGroupSize =
collectiveProcessGroupSize(gridAxes, gridOp->getShape());
if (nRanks != gridGroupSize)
return op.emitError()
<< "Expected the scatter factor (" << nRanks
<< ") to match the number of devices along grid_axes ("
<< gridGroupSize << ").";
// Get the right communicator.
Value comm = getComm(*gridOp, gridAxes, ib);
Value mpiInput;
if (scatterDim == 0) {
// scatter_dim == 0 maps directly to MPI_Reduce_scatter_block.
// Input must be contiguous for MPI.
Value input = getAsMemref(rawInput, ib);
MemRefType inType = cast<MemRefType>(input.getType());
if (!memref::isStaticShapeAndContiguousRowMajor(inType))
return op.emitError("Input must be a statically shaped memref in "
"contiguous row-major layout.");
mpiInput = input;
} else {
// For scatter_dim != 0 we rearrange the input so the scatter factor
// becomes the first dimension.
//
// 1. Get a tensor representation of the input (avoid memref->tensor
// round-trip if the input is already a tensor).
Value tensorInput = rawInput;
if (!isa<RankedTensorType>(rawInput.getType())) {
auto inTensorType = RankedTensorType::get(inputShape, elemType);
tensorInput =
bufferization::ToTensorOp::create(ib, inTensorType, rawInput, true);
}
// 2. Expand the scatter dim from {d0, ..., d_sd, ..., dN} to
// {d0, ..., nRanks, o_sd, ..., dN}.
SmallVector<int64_t> expandedShape;
SmallVector<ReassociationIndices> expandReassociation;
int64_t expandedIdx = 0;
for (int64_t i = 0; i < static_cast<int64_t>(inputShape.size()); ++i) {
if (i == scatterDim) {
expandedShape.push_back(nRanks);
expandedShape.push_back(outputDimOnAxis);
expandReassociation.push_back({expandedIdx, expandedIdx + 1});
expandedIdx += 2;
} else {
expandedShape.push_back(inputShape[i]);
expandReassociation.push_back({expandedIdx});
expandedIdx += 1;
}
}
auto expandedType = RankedTensorType::get(expandedShape, elemType);
tensorInput = tensor::ExpandShapeOp::create(ib, expandedType, tensorInput,
expandReassociation);
// 3. Transpose to move nRanks (at position scatterDim) to position 0:
// {d0, ..., nRanks, o_sd, ..., dN} -> {nRanks, d0, ..., o_sd, ..., dN}
SmallVector<int64_t> permutation, transposedShape;
permutation.emplace_back(scatterDim);
for (int64_t i = 0; i < scatterDim; ++i)
permutation.emplace_back(i);
for (int64_t i = scatterDim + 1; i < (int64_t)expandedShape.size(); ++i)
permutation.emplace_back(i);
for (auto p : permutation)
transposedShape.emplace_back(expandedShape[p]);
Value permOutput = tensor::EmptyOp::create(ib, transposedShape, elemType);
tensorInput =
linalg::TransposeOp::create(ib, tensorInput, permOutput, permutation)
->getResult(0);
// 4. Materialize as contiguous memref for MPI by copying into a
// freshly allocated buffer.
auto mpiInType = MemRefType::get(transposedShape, elemType);
Value transposedBuf =
bufferization::ToBufferOp::create(ib, mpiInType, tensorInput);
mpiInput = memref::AllocOp::create(ib, mpiInType);
linalg::CopyOp::create(ib, transposedBuf, mpiInput);
}
// Allocate output buffer.
Value output = memref::AllocOp::create(ib, outType);
// Create the MPI ReduceScatter operation.
mpi::ReduceScatterBlockOp::create(
ib, TypeRange(), mpiInput, output,
getMPIReductionOp(adaptor.getReductionAttr()), comm);
// Deallocate the temporary input buffer if we allocated one.
if (scatterDim != 0)
memref::DeallocOp::create(ib, mpiInput);
// If the destination is a tensor, cast it to a tensor.
if (isa<RankedTensorType>(op.getType()))
output =
bufferization::ToTensorOp::create(ib, op.getType(), output, true);
rewriter.replaceOp(op, output);
return success();
}
};
struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
using CommOpPattern::CommOpPattern;
@ -1048,7 +1198,7 @@ struct ConvertShardToMPIPass
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
ConvertAllGatherOp, ConvertAllReduceOp,
ConvertAllGatherOp, ConvertAllReduceOp, ConvertReduceScatterOp,
ConvertProcessLinearIndexOp>(typeConverter, ctxt);
SymbolTableCollection stc;
populateProcessMultiIndexOpLoweringPatterns(patterns, stc);

View File

@ -15,8 +15,23 @@
using namespace mlir;
using namespace mlir::mpi;
//===----------------------------------------------------------------------===//
// Verifiers
//===----------------------------------------------------------------------===//
LogicalResult mlir::mpi::ReduceScatterBlockOp::verify() {
if (getSendbuf().getType().getElementType() !=
getRecvbuf().getType().getElementType())
return emitOpError("sendbuf and recvbuf must have the same element type");
return success();
}
namespace {
//===----------------------------------------------------------------------===//
// Canonicalization patterns
//===----------------------------------------------------------------------===//
// If input memref has dynamic shape and is a cast and if the cast's input has
// static shape, fold the cast's static input into the given operation.
template <typename OpT>

View File

@ -1416,7 +1416,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
return verifyScatterOrSliceOperandAndResultShape(
getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
getOperand(), getResult(), getScatterDim().getSExtValue(), getGridAxes(),
grid.value().getShape());
}
@ -1445,9 +1445,9 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
auto scatterAxis = getScatterAxis().getSExtValue();
auto scatterDim = getScatterDim().getSExtValue();
return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
scatterAxis, getGridAxes(),
scatterDim, getGridAxes(),
grid.value().getShape());
}

View File

@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIRShardTransforms
Simplifications.cpp
Simplify.cpp
ShardingPropagation.cpp
Partition.cpp
Transforms.cpp
@ -26,4 +26,5 @@ add_mlir_dialect_library(MLIRShardTransforms
MLIRSupport
MLIRTensorDialect
MLIRTosaShardingInterfaceImpl
MLIRTransformUtils
)

View File

@ -1,4 +1,4 @@
//===- Simplifications.cpp - Shard Simplifications -_------------*- C++ -*-===//
//===- Simplify.cpp - Shard Simplify ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,13 +6,16 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
#include "mlir/Dialect/Shard/Transforms/Simplify.h"
#include "TransformsDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <numeric>
@ -20,31 +23,8 @@
namespace mlir {
namespace shard {
void populateSimplificationPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
patterns, ReductionKind::Sum);
populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
patterns, ReductionKind::Sum);
populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
patterns, ReductionKind::Min);
populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
patterns, ReductionKind::Min);
populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
patterns, ReductionKind::Min);
populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
patterns, ReductionKind::Max);
populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
patterns, ReductionKind::Max);
populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
patterns, ReductionKind::Max);
// TODO: add simplifications for all-gather and other collectives.
populateFoldingPatterns(patterns, symbolTableCollection);
}
#define GEN_PASS_DEF_SHARDSIMPLIFY
#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
namespace {
@ -109,12 +89,97 @@ struct GridShapeFolder
}
};
// Simplify AllSliceOp(AllReduceOp) -> ReduceScatterOp when both ops share the
// same grid and grid_axes.
//
// AllReduceOp performs an element-wise reduction across all devices in the
// group, and AllSliceOp then slices (scatters) the result along a tensor
// dimension. This is exactly what ReduceScatterOp does in a single collective.
//
// With a ring algorithm over N ranks and M elements:
// AllReduce: 2*(N-1) steps of M/N each => ~2M total data transferred
// AllSlice: local slice, no communication
// ReduceScatter: (N-1) steps of M/N each => ~M total data transferred
// So this fusion roughly halves the communication volume.
//
// Memory-wise, AllReduce produces a full-sized M-element result that the
// subsequent AllSlice must keep alive until the slice is taken. ReduceScatter
// only materializes the M/N-element local slice, reducing peak memory by
// a factor of N.
struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AllSliceOp sliceOp,
PatternRewriter &rewriter) const override {
// Check if the input to AllSliceOp is produced by an AllReduceOp.
auto reduceOp = sliceOp.getInput().getDefiningOp<AllReduceOp>();
if (!reduceOp || !reduceOp->hasOneUse())
return failure();
// Both ops must operate on the same grid and grid axes.
if (reduceOp.getGrid() != sliceOp.getGrid() ||
reduceOp.getGridAxes() != sliceOp.getGridAxes())
return failure();
// Replace with a single ReduceScatterOp.
rewriter.replaceOpWithNewOp<ReduceScatterOp>(
sliceOp, sliceOp.getResult().getType(), sliceOp.getGridAttr(),
sliceOp.getGridAxesAttr(), reduceOp.getInput(),
reduceOp.getReductionAttr(), sliceOp.getSliceAxisAttr());
return success();
}
};
} // namespace
void populateSimplifyPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection) {
populateAllReduceEndomorphismSimplifyPatterns<arith::AddFOp>(
patterns, ReductionKind::Sum);
populateAllReduceEndomorphismSimplifyPatterns<arith::AddIOp>(
patterns, ReductionKind::Sum);
populateAllReduceEndomorphismSimplifyPatterns<arith::MinimumFOp>(
patterns, ReductionKind::Min);
populateAllReduceEndomorphismSimplifyPatterns<arith::MinSIOp>(
patterns, ReductionKind::Min);
populateAllReduceEndomorphismSimplifyPatterns<arith::MinUIOp>(
patterns, ReductionKind::Min);
populateAllReduceEndomorphismSimplifyPatterns<arith::MaximumFOp>(
patterns, ReductionKind::Max);
populateAllReduceEndomorphismSimplifyPatterns<arith::MaxSIOp>(
patterns, ReductionKind::Max);
populateAllReduceEndomorphismSimplifyPatterns<arith::MaxUIOp>(
patterns, ReductionKind::Max);
patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
// TODO: add simplify patterns for all-gather and other collectives.
populateFoldingPatterns(patterns, symbolTableCollection);
}
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection) {
patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
}
namespace {
struct ShardSimplifyPass : public impl::ShardSimplifyBase<ShardSimplifyPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
populateSimplifyPatterns(patterns, symbolTableCollection);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
} // namespace
} // namespace shard
} // namespace mlir

View File

@ -1,127 +1,151 @@
// RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
// COM: Test MPICH ABI
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
// CHECK-LABEL: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK-DAG: llvm.func @MPI_Finalize() -> i32
// CHECK-DAG: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
// CHECK-DAG: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
// CHECK-DAG: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
// CHECK-DAG: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
// CHECK-DAG: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
// CHECK-DAG: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
// CHECK-DAG: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
// CHECK-DAG: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
// CHECK-DAG: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
// CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK-LABEL: llvm.func @test_init_finalize_mpich
func.func @test_init_finalize_mpich() {
// CHECK: [[v0:%.*]] = llvm.mlir.zero : !llvm.ptr
// CHECK: llvm.call @MPI_Init([[v0]], [[v0]]) : (!llvm.ptr, !llvm.ptr) -> i32
%0 = mpi.init : !mpi.retval
// CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v8:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
// CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
// CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
// CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
// CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
// CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
// CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
// CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[comm_3]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
// CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
// CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
%color = arith.constant 10 : i32
// CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32
%key = arith.constant 22 : i32
// CHECK: [[v53:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x i32 : (i32) -> !llvm.ptr
// CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (i32, i32, i32, !llvm.ptr) -> i32
// CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
// CHECK: llvm.call @MPI_Comm_size
// CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
%err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
// CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v62:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v63a:%.*]] = llvm.trunc [[v62]] : i64 to i32
// CHECK: [[v63:%.*]] = llvm.mul [[v63a]]
// CHECK: [[v64:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v65:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v68a:%.*]] = llvm.trunc [[v67]] : i64 to i32
// CHECK: [[v68:%.*]] = llvm.mul [[v68a]]
// CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64
// CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
// CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
// CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
%1 = mpi.finalize : !mpi.retval
return
}
// CHECK-LABEL: llvm.func @test_comm_rank_mpich
func.func @test_comm_rank_mpich() {
// CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v1:%.*]] = llvm.trunc [[v0]] : i64 to i32
// CHECK: [[v2:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v3:%.*]] = llvm.alloca [[v2]] x i32 : (i32) -> !llvm.ptr
// CHECK: llvm.call @MPI_Comm_rank([[v1]], [[v3]]) : (i32, !llvm.ptr) -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
return
}
// CHECK-LABEL: llvm.func @test_send_mpich
func.func @test_send_mpich(%arg0: memref<100xf32>) {
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
// CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
%comm = mpi.comm_world : !mpi.comm
// CHECK: llvm.call @MPI_Comm_rank
// CHECK: [[v2:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
// COM: Test send without retval
// CHECK: [[v3:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v5:%.*]] = llvm.getelementptr [[v3]][[[v4]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v6:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v7:%.*]] = llvm.trunc [[v6]] : i64 to i32
// CHECK: [[v8:%.*]] = llvm.mul [[v7]]
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v10:%.*]] = llvm.trunc [[v1]] : i64 to i32
// CHECK: = llvm.call @MPI_Send([[v5]], [[v8]], [[v9]], [[v2]], [[v2]], [[v10]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// COM: Test send with retval
// CHECK: = llvm.call @MPI_Send({{.*}}) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
return
}
// CHECK-LABEL: llvm.func @test_recv_mpich
func.func @test_recv_mpich(%arg0: memref<100xf32>) {
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
// CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
%comm = mpi.comm_world : !mpi.comm
// CHECK: llvm.call @MPI_Comm_rank
// CHECK: [[v2:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
// COM: Test recv without retval
// CHECK: [[v3:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v5:%.*]] = llvm.getelementptr [[v3]][[[v4]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v6:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v7:%.*]] = llvm.trunc [[v6]] : i64 to i32
// CHECK: [[v8:%.*]] = llvm.mul [[v7]]
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v10:%.*]] = llvm.trunc [[v1]] : i64 to i32
// CHECK: [[v11:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v12:%.*]] = llvm.inttoptr [[v11]] : i64 to !llvm.ptr
// CHECK: = llvm.call @MPI_Recv([[v5]], [[v8]], [[v9]], [[v2]], [[v2]], [[v10]], [[v12]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// COM: Test recv with retval
// CHECK: = llvm.call @MPI_Recv({{.*}}) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
return
}
// CHECK-LABEL: llvm.func @test_comm_split_mpich
func.func @test_comm_split_mpich() {
// CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v1:%.*]] = llvm.mlir.constant(10 : i32) : i32
%color = arith.constant 10 : i32
// CHECK: [[v2:%.*]] = llvm.mlir.constant(22 : i32) : i32
%key = arith.constant 22 : i32
// CHECK: [[v3:%.*]] = llvm.trunc [[v0]] : i64 to i32
// CHECK: [[v4:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v5:%.*]] = llvm.alloca [[v4]] x i32 : (i32) -> !llvm.ptr
// CHECK: llvm.call @MPI_Comm_split([[v3]], [[v1]], [[v2]], [[v5]]) : (i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.load [[v5]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
return
}
// CHECK-LABEL: llvm.func @test_allgather_mpich
func.func @test_allgather_mpich(%arg0: memref<100xf32>) {
%comm = mpi.comm_world : !mpi.comm
// CHECK: llvm.call @MPI_Comm_size
// CHECK: llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
%err = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
return
}
// CHECK-LABEL: llvm.func @test_allreduce_mpich
func.func @test_allreduce_mpich(%arg0: memref<100xf32>) {
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
// CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
// CHECK: [[v2:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v3:%.*]] = llvm.mul
// CHECK: [[v4:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v5:%.*]] = llvm.mlir.constant(-1 : i64) : i64
// CHECK: [[v6:%.*]] = llvm.inttoptr [[v5]] : i64 to !llvm.ptr
// CHECK: [[v7:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v8:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
// CHECK: [[v9:%.*]] = llvm.trunc [[v1]] : i64 to i32
// CHECK: llvm.call @MPI_Allreduce([[v6]], [[v4]], [[v3]], [[v7]], [[v8]], [[v9]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
return
}
// CHECK-LABEL: llvm.func @test_reduce_scatter_block_mpich
func.func @test_reduce_scatter_block_mpich(%arg0: memref<100xf32>) {
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
// CHECK: llvm.mul
// CHECK: [[v1:%.*]] = llvm.mlir.constant(1 : index) : i32
// CHECK: [[v2:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v3:%.*]] = llvm.trunc [[v2]] : i64 to i32
// CHECK: [[v4:%.*]] = llvm.mul [[v3]], [[v1]] : i32
// CHECK: [[v5:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
// CHECK: llvm.cond_br {{.*}}, ^[[bb1:.*]], ^{{.*}}
// CHECK: ^[[bb1]]:
// CHECK: llvm.call @MPI_Reduce_scatter_block([[v5]], {{.*}}, [[v4]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
return
}
}
@ -129,131 +153,160 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// -----
// COM: Test OpenMPI ABI
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
// CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
// CHECK-LABEL: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
// CHECK-DAG: llvm.func @MPI_Finalize() -> i32
// CHECK-DAG: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK-DAG: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
// CHECK-DAG: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
// CHECK-DAG: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
// CHECK-DAG: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
// CHECK-DAG: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK-DAG: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK-DAG: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK-DAG: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
// CHECK-DAG: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
// CHECK-DAG: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
// CHECK-DAG: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {
// CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK-LABEL: llvm.func @test_init_finalize_openmpi
func.func @test_init_finalize_openmpi() {
// CHECK: [[v0:%.*]] = llvm.mlir.zero : !llvm.ptr
// CHECK: llvm.call @MPI_Init([[v0]], [[v0]]) : (!llvm.ptr, !llvm.ptr) -> i32
%0 = mpi.init : !mpi.retval
// CHECK: llvm.call @MPI_Finalize() : () -> i32
%1 = mpi.finalize : !mpi.retval
return
}
// CHECK-LABEL: llvm.func @test_comm_rank_openmpi
func.func @test_comm_rank_openmpi() {
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[comm:%.*]] = llvm.ptrtoint [[v8]] : !llvm.ptr to i64
// CHECK: [[comm_1:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[comm_1]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK: [[v0:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v1:%.*]] = llvm.ptrtoint [[v0]] : !llvm.ptr to i64
// CHECK: [[v2:%.*]] = llvm.inttoptr [[v1]] : i64 to !llvm.ptr
// CHECK: [[v3:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v4:%.*]] = llvm.alloca [[v3]] x i32 : (i32) -> !llvm.ptr
// CHECK: llvm.call @MPI_Comm_rank([[v2]], [[v4]]) : (!llvm.ptr, !llvm.ptr) -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
return
}
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
// CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
// CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
// CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK-LABEL: llvm.func @test_send_openmpi
func.func @test_send_openmpi(%arg0: memref<100xf32>) {
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64
// CHECK: llvm.call @MPI_Comm_rank
// CHECK: [[v3:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
// COM: Test send without retval
// CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v5:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v6:%.*]] = llvm.getelementptr [[v4]][[[v5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v7:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v8:%.*]] = llvm.trunc [[v7]] : i64 to i32
// CHECK: [[v9:%.*]] = llvm.mul [[v8]]
// CHECK: [[v10:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v11:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr
// CHECK: = llvm.call @MPI_Send([[v6]], [[v9]], [[v10]], [[v3]], [[v3]], [[v11]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
// CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
// CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
// COM: Test send with retval
// CHECK: = llvm.call @MPI_Send({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
return
}
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
// CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
// CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK-LABEL: llvm.func @test_recv_openmpi
func.func @test_recv_openmpi(%arg0: memref<100xf32>) {
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64
// CHECK: llvm.call @MPI_Comm_rank
// CHECK: [[v3:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
// COM: Test recv without retval
// CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v5:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v6:%.*]] = llvm.getelementptr [[v4]][[[v5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v7:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v8:%.*]] = llvm.trunc [[v7]] : i64 to i32
// CHECK: [[v9:%.*]] = llvm.mul [[v8]]
// CHECK: [[v10:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v11:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr
// CHECK: [[v12:%.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: [[v13:%.*]] = llvm.inttoptr [[v12]] : i64 to !llvm.ptr
// CHECK: = llvm.call @MPI_Recv([[v6]], [[v9]], [[v10]], [[v3]], [[v3]], [[v11]], [[v13]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
// CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
// CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// COM: Test recv with retval
// CHECK: = llvm.call @MPI_Recv({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
return
}
// CHECK-LABEL: llvm.func @test_comm_split_openmpi
func.func @test_comm_split_openmpi() {
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v0:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v1:%.*]] = llvm.ptrtoint [[v0]] : !llvm.ptr to i64
// CHECK: [[v2:%.*]] = llvm.mlir.constant(10 : i32) : i32
%color = arith.constant 10 : i32
// CHECK: [[v3:%.*]] = llvm.mlir.constant(22 : i32) : i32
%key = arith.constant 22 : i32
// CHECK: [[v4:%.*]] = llvm.inttoptr [[v1]] : i64 to !llvm.ptr
// CHECK: [[v5:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v6:%.*]] = llvm.alloca [[v5]] x !llvm.ptr : (i32) -> !llvm.ptr
// CHECK: llvm.call @MPI_Comm_split([[v4]], [[v2]], [[v3]], [[v6]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.load [[v6]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
return
}
// CHECK-LABEL: llvm.func @test_allgather_openmpi
func.func @test_allgather_openmpi(%arg0: memref<100xf32>) {
%comm = mpi.comm_world : !mpi.comm
// CHECK: llvm.call @MPI_Comm_size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.udiv {{.*}} : i32
// CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32>
return
}
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v53a:%.*]] = llvm.trunc [[v52]] : i64 to i32
// CHECK: [[v53:%.*]] = llvm.mul [[v53a]]
// CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v58a:%.*]] = llvm.trunc [[v57]] : i64 to i32
// CHECK: [[v58:%.*]] = llvm.mul [[v58a]]
// CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
// CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
// CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
// CHECK-LABEL: llvm.func @test_allreduce_openmpi
func.func @test_allreduce_openmpi(%arg0: memref<100xf32>) {
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
// CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64
// CHECK: [[v3:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v4:%.*]] = llvm.mul
// CHECK: [[v5:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v6:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v7:%.*]] = llvm.inttoptr [[v6]] : i64 to !llvm.ptr
// CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v9:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
// CHECK: [[v10:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr
// CHECK: llvm.call @MPI_Allreduce([[v7]], [[v5]], [[v4]], [[v8]], [[v9]], [[v10]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
return
}
// CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
%color = arith.constant 10 : i32
// CHECK: [[v72:%.*]] = llvm.mlir.constant(22 : i32) : i32
%key = arith.constant 22 : i32
// CHECK: [[v73:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v74:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v75:%.*]] = llvm.alloca [[v74]] x !llvm.ptr : (i32) -> !llvm.ptr
// CHECK: [[v76:%.*]] = llvm.call @MPI_Comm_split([[v73]], [[v71]], [[v72]], [[v75]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: [[v77:%.*]] = llvm.load [[v75]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
// CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
// CHECK-LABEL: llvm.func @test_reduce_scatter_block_openmpi
func.func @test_reduce_scatter_block_openmpi(%arg0: memref<100xf32>) {
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0]
// CHECK: llvm.mul
// CHECK: [[v1:%.*]] = llvm.mlir.constant(1 : index) : i32
// CHECK: [[v2:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v3:%.*]] = llvm.trunc [[v2]] : i64 to i32
// CHECK: [[v4:%.*]] = llvm.mul [[v3]], [[v1]] : i32
// CHECK: [[v5:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
// CHECK: llvm.cond_br {{.*}}, ^[[bb1:.*]], ^{{.*}}
// CHECK: ^[[bb1]]:
// CHECK: llvm.call @MPI_Reduce_scatter_block([[v5]], {{.*}}, [[v4]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
return
}
}
@ -261,13 +314,13 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// -----
module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:comm_world_size" = 4, "MPI:comm_world_rank" = 1> } {
// CHECK: llvm.func @mpi_test_fold
func.func @mpi_test_fold(%arg0: memref<100xf32>) {
// CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
// CHECK-LABEL: llvm.func @test_fold
func.func @test_fold(%arg0: memref<100xf32>) {
// CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
%comm = mpi.comm_world : !mpi.comm
// CHECK-NOT: llvm.call @MPI_Comm_size
// CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
// CHECK: llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
%err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
return
}

View File

@ -159,6 +159,40 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
return %0 : memref<3x4xf64>
}
// CHECK-LABEL: func.func @reduce_scatter_memref(
func.func @reduce_scatter_memref(
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
%arg0 : memref<3x4xf32>) -> memref<1x4xf32> {
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 0 : i32
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 7 : i32
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<1x4xf32>
// CHECK: mpi.reduce_scatter_block([[varg0]], [[valloc]], MPI_SUM, [[vnewcomm]]) : memref<3x4xf32>, memref<1x4xf32>
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0 : memref<3x4xf32> -> memref<1x4xf32>
// CHECK: return [[valloc]] : memref<1x4xf32>
return %0 : memref<1x4xf32>
}
// CHECK-LABEL: func.func @reduce_scatter_tensor_dim1(
func.func @reduce_scatter_tensor_dim1(
// CHECK-SAME: [[varg0:%.*]]: tensor<2x12xf32>
%arg0 : tensor<2x12xf32>) -> tensor<2x4xf32> {
// CHECK: [[vexpanded:%.*]] = tensor.expand_shape [[varg0]] {{\[\[}}0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xf32> into tensor<2x3x4xf32>
// CHECK: [[vempty:%.*]] = tensor.empty() : tensor<3x2x4xf32>
// CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[vexpanded]] : tensor<2x3x4xf32>) outs([[vempty]] : tensor<3x2x4xf32>) permutation = [1, 0, 2]
// CHECK: [[vtobuf:%.*]] = bufferization.to_buffer [[vtransposed]] : tensor<3x2x4xf32> to memref<3x2x4xf32>
// CHECK: [[valloctmp:%.*]] = memref.alloc() : memref<3x2x4xf32>
// CHECK: linalg.copy ins([[vtobuf]] : memref<3x2x4xf32>) outs([[valloctmp]] : memref<3x2x4xf32>)
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x4xf32>
// CHECK: mpi.reduce_scatter_block([[valloctmp]], [[valloc]], MPI_SUM,
// CHECK: memref.dealloc [[valloctmp]] : memref<3x2x4xf32>
// CHECK: [[vout:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<2x4xf32> to tensor<2x4xf32>
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 1 : tensor<2x12xf32> -> tensor<2x4xf32>
// CHECK: return [[vout]] : tensor<2x4xf32>
return %0 : tensor<2x4xf32>
}
// CHECK-LABEL: func @allgather_tensor_0
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
func.func @allgather_tensor_0(%arg0 : tensor<3x4xf32>) -> tensor<12x4xf32> {

View File

@ -77,6 +77,12 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK-NEXT: mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK-NEXT: [[v10:%.*]] = mpi.reduce_scatter_block([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
%err9 = mpi.reduce_scatter_block(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
// CHECK-NEXT: mpi.reduce_scatter_block([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
mpi.reduce_scatter_block(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK-NEXT: [[v7:%.*]] = mpi.finalize : !mpi.retval
%rval = mpi.finalize : !mpi.retval

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --test-grid-simplifications --cse %s | FileCheck %s
// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --shard-simplify --cse %s | FileCheck %s
shard.grid @grid_1d(shape = ?)
@ -59,15 +59,15 @@ func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid(
// CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = shard.process_multi_index on @grid_4d axes = [3, 1] : index, index
// CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = shard.grid_shape @grid_4d axes = [3, 1] : index, index
// CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
// CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
// CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
// CHECK: %[[scatter_dim_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
// CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[scatter_dim_SIZE]], %[[PROC_GROUP_SIZE]] : index
// CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index
// CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
// CHECK: %[[RESULT_SCATTER_AXIS_SIZE:.*]] = arith.divui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
// CHECK: %[[RESULT_scatter_dim_SIZE:.*]] = arith.divui %[[scatter_dim_SIZE]], %[[PROC_GROUP_SIZE]] : index
// CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #map()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1]
// CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
// CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
// CHECK: %[[scatter_dim_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_scatter_dim_SIZE]] : index
// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[scatter_dim_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_scatter_dim_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
%0 = shard.all_slice %arg0 on @grid_4d grid_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
// CHECK: return %[[RESULT]] : tensor<?x?xf16>
return %0 : tensor<?x?xf16>

View File

@ -135,7 +135,7 @@ func.func @reduce_scatter_empty_grid_axes(
// CHECK-NOT: shard.reduce_scatter
%0 = shard.reduce_scatter %arg0 on @grid0
grid_axes = []
scatter_axis = 0
scatter_dim = 0
: tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
return %0 : tensor<4xf32>
@ -148,7 +148,7 @@ func.func @reduce_scatter_empty_grid_axes_different_return_type(
%0 = shard.reduce_scatter %arg0 on @grid0
// CHECK-NOT: grid_axes
grid_axes = []
scatter_axis = 0
scatter_dim = 0
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
@ -160,7 +160,7 @@ func.func @reduce_scatter_default_reduction(
grid_axes = [0]
// CHECK-NOT: reduction
reduction = sum
scatter_axis = 0
scatter_dim = 0
: tensor<4xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
@ -172,7 +172,7 @@ func.func @scatter_empty_grid_axes(
// CHECK-NOT: shard.scatter
%0 = shard.scatter %arg0 on @grid0
grid_axes = []
scatter_axis = 0
scatter_dim = 0
root = []
: (tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[ARG]]

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
// RUN: mlir-opt -shard-simplify %s | FileCheck %s
shard.grid @grid0(shape = 4x?x2)
shard.grid @grid1(shape = 2x3)

View File

@ -707,7 +707,7 @@ shard.grid @grid0(shape = 3)
func.func @reduce_scatter_duplicate_grid_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
// expected-error@+1 {{Grid axes contains duplicate elements.}}
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_dim = 0
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
@ -719,7 +719,7 @@ shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf64> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
%0 = shard.reduce_scatter %arg0 on @grid0 scatter_axis = 0
%0 = shard.reduce_scatter %arg0 on @grid0 scatter_dim = 0
: tensor<?xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
@ -731,7 +731,7 @@ shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf64> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
: tensor<3xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
@ -743,7 +743,7 @@ shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf64> {
// expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
: tensor<4xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
@ -756,7 +756,7 @@ func.func @scatter_duplicate_grid_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
// expected-error@+1 {{Grid axes contains duplicate elements.}}
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 0]
scatter_axis = 0 root = [0, 0]
scatter_dim = 0 root = [0, 0]
: (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -769,7 +769,7 @@ func.func @scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
%0 = shard.scatter %arg0 on @grid0
scatter_axis = 0 root = []
scatter_dim = 0 root = []
: (tensor<?xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -782,7 +782,7 @@ func.func @scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [1]
scatter_dim = 0 root = [1]
: (tensor<3xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -795,7 +795,7 @@ func.func @scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
// expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [1]
scatter_dim = 0 root = [1]
: (tensor<4xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -808,7 +808,7 @@ func.func @scatter_root_dimension_out_of_bounds(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
// expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [3]
scatter_dim = 0 root = [3]
: (tensor<3xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
@ -821,7 +821,7 @@ func.func @scatter_root_wrong_number_dimensions(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
// expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [2, 2]
scatter_dim = 0 root = [2, 2]
: (tensor<3xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}

View File

@ -453,10 +453,10 @@ func.func @reduce_scatter_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
// CHECK-NEXT: shard.reduce_scatter %[[ARG]]
// CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_axis = 1
// CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_dim = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
%0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [2]
reduction = max scatter_axis = 1
reduction = max scatter_dim = 1
: tensor<3x4xf32> -> tensor<3x1xf64>
return %0 : tensor<3x1xf64>
}
@ -466,9 +466,9 @@ func.func @reduce_scatter_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
// CHECK-NEXT: shard.reduce_scatter %[[ARG]]
// CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_axis = 0
// CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_dim = 0
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
%0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0
%0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_dim = 0
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
@ -479,10 +479,10 @@ func.func @scatter_static_dimensions(
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
// CHECK-NEXT: shard.scatter %[[ARG]]
// CHECK-SAME: on @grid0 grid_axes = [2]
// CHECK-SAME: scatter_axis = 1 root = [1]
// CHECK-SAME: scatter_dim = 1 root = [1]
// CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
%0 = shard.scatter %arg0 on @grid0 grid_axes = [2]
scatter_axis = 1 root = [1]
scatter_dim = 1 root = [1]
: (tensor<3x4xf32>) -> tensor<3x1xf32>
return %0 : tensor<3x1xf32>
}
@ -493,10 +493,10 @@ func.func @scatter_dynamic_dimensions(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
// CHECK-NEXT: shard.scatter %[[ARG]]
// CHECK-SAME: on @grid3 grid_axes = [0, 1]
// CHECK-SAME: scatter_axis = 0 root = [1, 2]
// CHECK-SAME: scatter_dim = 0 root = [1, 2]
// CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
%0 = shard.scatter %arg0 on @grid3 grid_axes = [0, 1]
scatter_axis = 0 root = [1, 2]
scatter_dim = 0 root = [1, 2]
: (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -510,11 +510,11 @@ func.func @scatter_dynamic_root(
) -> tensor<1xi8> {
// CHECK-NEXT: shard.scatter %[[ARG0]]
// CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: scatter_axis = 0
// CHECK-SAME: scatter_dim = 0
// CHECK-SAME: root = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
%0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 2]
scatter_axis = 0
scatter_dim = 0
root = [1, %arg1]
: (tensor<8xi8>, index) -> tensor<1xi8>
return %0 : tensor<1xi8>

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
// RUN: mlir-opt -shard-simplify %s | FileCheck %s
shard.grid @grid0(shape = 4x2)
shard.grid @grid1(shape = 4)
@ -177,3 +177,86 @@ func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 {
%0 = arith.maxsi %extracted, %c1_i64 : i64
return %0 : i64
}
// -----
// AllReduceOp + AllSliceOp -> ReduceScatterOp tests
// -----
// Basic case: all_slice(all_reduce(x)) with matching grid and axes folds
// into reduce_scatter.
// CHECK-LABEL: func.func @all_reduce_all_slice_to_reduce_scatter
func.func @all_reduce_all_slice_to_reduce_scatter(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
%arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
%1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
// CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0
// CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf32>
// CHECK: return %[[RS]]
return %1 : tensor<1x8xf32>
}
// Verify non-default reduction kind is preserved.
// CHECK-LABEL: func.func @all_reduce_all_slice_max_reduction
func.func @all_reduce_all_slice_max_reduction(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
%arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max : tensor<4x8xf32> -> tensor<4x8xf32>
%1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
// CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] reduction = max scatter_dim = 0
// CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf32>
// CHECK: return %[[RS]]
return %1 : tensor<1x8xf32>
}
// Slice on a different tensor axis than the reduce axes.
// CHECK-LABEL: func.func @all_reduce_all_slice_different_slice_axis
func.func @all_reduce_all_slice_different_slice_axis(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
%arg0: tensor<4x8xf32>) -> tensor<4x4xf32> {
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [1] : tensor<4x8xf32> -> tensor<4x8xf32>
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<4x8xf32> -> tensor<4x4xf32>
// CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [1] scatter_dim = 1
// CHECK-SAME: : tensor<4x8xf32> -> tensor<4x4xf32>
// CHECK: return %[[RS]]
return %1 : tensor<4x4xf32>
}
// Do not fold when grids differ.
// CHECK-LABEL: func.func @all_reduce_all_slice_different_grid
func.func @all_reduce_all_slice_different_grid(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
%arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
// CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
// CHECK: %[[AS:.*]] = shard.all_slice %[[AR]] on @grid1
%1 = shard.all_slice %0 on @grid1 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
// CHECK: return %[[AS]]
return %1 : tensor<1x8xf32>
}
// Do not fold when grid_axes differ.
// CHECK-LABEL: func.func @all_reduce_all_slice_different_grid_axes
func.func @all_reduce_all_slice_different_grid_axes(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
%arg0: tensor<4x8xf32>) -> tensor<4x4xf32> {
// CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
// CHECK: %[[AS:.*]] = shard.all_slice %[[AR]] on @grid0 grid_axes = [1]
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<4x8xf32> -> tensor<4x4xf32>
// CHECK: return %[[AS]]
return %1 : tensor<4x4xf32>
}
// Verify element type conversion is preserved (all_reduce input/output types may differ).
// CHECK-LABEL: func.func @all_reduce_all_slice_type_promotion
func.func @all_reduce_all_slice_type_promotion(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
%arg0: tensor<4x8xf32>) -> tensor<1x8xf64> {
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf64>
%1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf64> -> tensor<1x8xf64>
// CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0
// CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf64>
// CHECK: return %[[RS]]
return %1 : tensor<1x8xf64>
}

View File

@ -2,7 +2,6 @@
add_mlir_library(MLIRShardTest
TestOpLowering.cpp
TestReshardingPartition.cpp
TestSimplifications.cpp
EXCLUDE_FROM_LIBMLIR
)

View File

@ -1,4 +1,4 @@
//===- TestSimplification.cpp - Test simplification -----------------------===//
//===- TestReshardingPartition.cpp - Test resharding partition ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@ -1,47 +0,0 @@
//===- TestSimplification.cpp - Test simplification -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
struct TestShardSimplificationsPass
: public PassWrapper<TestShardSimplificationsPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShardSimplificationsPass)
void runOnOperation() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, shard::ShardDialect>();
}
StringRef getArgument() const final { return "test-grid-simplifications"; }
StringRef getDescription() const final { return "Test grid simplifications"; }
};
} // namespace
void TestShardSimplificationsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
shard::populateSimplificationPatterns(patterns, symbolTableCollection);
[[maybe_unused]] LogicalResult status =
applyPatternsGreedily(getOperation(), std::move(patterns));
assert(succeeded(status) && "Rewrite patters application did not converge.");
}
namespace mlir {
namespace test {
void registerTestShardSimplificationsPass() {
PassRegistration<TestShardSimplificationsPass>();
}
} // namespace test
} // namespace mlir

View File

@ -131,7 +131,6 @@ void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMemRefToLLVMWithTransforms();
void registerTestReshardingPartitionPass();
void registerTestShardSimplificationsPass();
void registerTestMultiBuffering();
void registerTestNextAccessPass();
void registerTestNVGPULowerings();
@ -281,7 +280,6 @@ static void registerTestPasses() {
mlir::test::registerTestMemRefStrideCalculation();
mlir::test::registerTestMemRefToLLVMWithTransforms();
mlir::test::registerTestReshardingPartitionPass();
mlir::test::registerTestShardSimplificationsPass();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestNextAccessPass();
mlir::test::registerTestNVGPULowerings();