diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 7e68b152fdf7..fb0192ba748a 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -333,6 +333,41 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { "(`->` type($retval)^)?"; } +//===----------------------------------------------------------------------===// +// ReduceScatterBlockOp +//===----------------------------------------------------------------------===// + +def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> { + let summary = "Equivalent to `MPI_Reduce_scatter_block(sendbuf, recvbuf, " + "recvcount, dtype, op, comm)`"; + let description = [{ + MPI_Reduce_scatter_block first performs an element-wise reduction on the + sendbuf across all processes in the communicator, then scatters the result + by distributing equal-sized blocks to each process into recvbuf. + + The `op` attribute specifies the reduction operation to be performed. + Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are + supported. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. 
+ }]; + + let arguments = ( + ins AnyNon0RankedMemRef : $sendbuf, + AnyNon0RankedMemRef : $recvbuf, + MPI_ReductionOpEnum : $op, + MPI_Comm : $comm + ); + + let results = (outs Optional:$retval); + + let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` " + "attr-dict `:` type($sendbuf) `,` type($recvbuf) " + "(`->` type($retval)^)?"; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // BarrierOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td index 6ef7c72d305e..60f6d6fc1ffe 100644 --- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td +++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td @@ -901,13 +901,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter" `grid_axes` using the specified reduction method. The reduction is performed element-wise across the tensor pieces from all devices in the group. After reduction, the reduction result is scattered (split and distributed) - across the device group along `scatter_axis`. + across the device group along `scatter_dim`. Example: ``` shard.grid @grid0(shape = 2x2) ... %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1] - reduction = scatter_axis = 0 + reduction = scatter_dim = 0 : tensor<2x2xf32> -> tensor<1x2xf64> ``` Input: @@ -940,17 +940,17 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter" ``` }]; let arguments = !con(commonArgs, (ins - AnyNon0RankedTensor:$input, + AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input, DefaultValuedAttr:$reduction, - IndexAttr:$scatter_axis + IndexAttr:$scatter_dim )); let results = (outs - AnyRankedTensor:$result + AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result ); let assemblyFormat = [{ $input `on` $grid (`grid_axes` `=` $grid_axes^)? (`reduction` `=` $reduction^)? 
- `scatter_axis` `=` $scatter_axis + `scatter_dim` `=` $scatter_dim attr-dict `:` type($input) `->` type($result) }]; let hasCanonicalizer = 1; @@ -964,7 +964,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ let summary = "Scatter over a device grid."; let description = [{ For each device group defined by `grid_axes`, the input tensor on the `root` - device is split along axis `scatter_axis` and distributed across the group. + device is split along axis `scatter_dim` and distributed across the group. The content of the input on all other (non-root) devices is ignored. The `root` device is defined by its in-group multi-index. @@ -972,7 +972,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ ``` shard.grid @grid0(shape = 2x2) %1 = shard.scatter %0 on @grid0 grid_axes = [0] - scatter_axis = 0 + scatter_dim = 0 root = [1] : (tensor<2x2xi8>) -> tensor<1x2xi8> ``` @@ -1011,7 +1011,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ }]; let arguments = !con(commonArgs, (ins AnyNon0RankedTensor:$input, - IndexAttr:$scatter_axis, + IndexAttr:$scatter_dim, DenseI64ArrayAttr:$root, Variadic:$root_dynamic )); @@ -1020,7 +1020,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ ); let assemblyFormat = [{ $input `on` $grid (`grid_axes` `=` $grid_axes^)? 
- `scatter_axis` `=` $scatter_axis + `scatter_dim` `=` $scatter_dim `root` `=` custom($root_dynamic, $root) attr-dict `:` functional-type(operands, results) }]; diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h index 37903765903d..ba12002b9f1e 100644 --- a/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h +++ b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h @@ -1,4 +1,4 @@ -//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===// +//===- Partition.h - Shard Partition ----------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td index bbc6a1977b13..575c176217e6 100644 --- a/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td @@ -44,6 +44,23 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO ]; } +def ShardSimplify : Pass<"shard-simplify"> { + let summary = "Shard simplify patterns."; + let description = [{ + Applies simplification patterns on the Shard dialect operations. + This includes: + - All-reduce endomorphism simplification, e.g. transforming + `all_reduce_sum(x) + all_reduce_sum(y)` into `all_reduce_sum(x + y)`. + - Folding `AllSliceOp(AllReduceOp)` into `ReduceScatterOp` when both ops + share the same grid and grid_axes. + - Folding static grid shapes into constants. 
+ }]; + let dependentDialects = [ + "arith::ArithDialect", + "shard::ShardDialect" + ]; +} + def Partition : InterfacePass<"shard-partition", "mlir::FunctionOpInterface"> { let summary = "Partition a function into SPMD form."; let description = [{ diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h similarity index 89% rename from mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h rename to mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h index 45ae758ec14c..f3f4feffd8a7 100644 --- a/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h +++ b/mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h @@ -1,4 +1,4 @@ -//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===// +//===- Simplify.h - Shard Simplify ------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H -#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H +#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H +#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H #include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/IR/PatternMatch.h" @@ -37,8 +37,8 @@ namespace shard { // Will not work with some op `f(x, y, z)` where only `x` and `y` form // the algebraic structure. 
template -void populateAllReduceEndomorphismSimplificationPatterns( - RewritePatternSet &patterns, ReductionKind reduction) { +void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns, + ReductionKind reduction) { auto getEndomorphismOpOperand = [](Operation *op) { auto allReduceOp = llvm::cast(op); return &allReduceOp.getInputMutable(); @@ -105,12 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns( // It is invalid to change ops that declare symbols during the application of // these patterns, because symbolTableCollection is used to cache them. -void populateSimplificationPatterns( - RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection); +void populateSimplifyPatterns(RewritePatternSet &patterns, + SymbolTableCollection &symbolTableCollection); void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection); } // namespace shard } // namespace mlir -#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H +#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 0dbc0a126a5c..50817cf5e00d 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" @@ -911,6 +912,79 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern { } }; +//===----------------------------------------------------------------------===// +// ReduceScatterBlockOpLowering +//===----------------------------------------------------------------------===// + +struct ReduceScatterBlockOpLowering + : public 
ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(mpi::ReduceScatterBlockOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *context = rewriter.getContext(); + Type i32 = rewriter.getI32Type(); + Type i64 = rewriter.getI64Type(); + Type elemType = op.getSendbuf().getType().getElementType(); + int64_t sRank = op.getSendbuf().getType().getRank(); + int64_t rRank = op.getRecvbuf().getType().getRank(); + + // ptrType `!llvm.ptr` + Type ptrType = LLVM::LLVMPointerType::get(context); + auto moduleOp = op->getParentOfType(); + auto mpiTraits = MPIImplTraits::get(moduleOp); + auto [sendPtr, sendSize] = + getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType); + auto [recvPtr, recvSize] = + getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType); + + // If input and output are the same, request in-place operation. + if (adaptor.getSendbuf() == adaptor.getRecvbuf()) { + sendPtr = LLVM::ConstantOp::create( + rewriter, loc, i64, + reinterpret_cast(mpiTraits->getInPlace())); + sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr); + } + + Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); + Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp()); + Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); + + Value nRanks = + createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm()); + Value totalExpected = + LLVM::MulOp::create(rewriter, loc, i32, recvSize, nRanks); + Value sizeIsValid = LLVM::ICmpOp::create( + rewriter, loc, LLVM::ICmpPredicate::eq, sendSize, totalExpected); + cf::AssertOp::create(rewriter, loc, sizeIsValid, + "Send buffer's size must be the receive buffer's size " + "times the number of ranks"); + + // 'int MPI_Reduce_scatter_block(const void *sendbuf, void *recvbuf, + // int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm 
comm)' + auto funcType = LLVM::LLVMFunctionType::get( + i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(), + comm.getType()}); + // get or create function declaration: + LLVM::LLVMFuncOp funcDecl = getOrDefineFunction( + moduleOp, loc, rewriter, "MPI_Reduce_scatter_block", funcType); + + // replace op with function call + auto funcCall = LLVM::CallOp::create( + rewriter, loc, funcDecl, + ValueRange{sendPtr, recvPtr, recvSize, dataType, mpiOp, comm}); + + if (op.getRetval()) + rewriter.replaceOp(op, funcCall.getResult()); + else + rewriter.eraseOp(op); + + return success(); + } +}; + //===----------------------------------------------------------------------===// // ConvertToLLVMPatternInterface implementation //===----------------------------------------------------------------------===// @@ -943,7 +1017,7 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, patterns.add(converter); + AllReduceOpLowering, ReduceScatterBlockOpLowering>(converter); } void mpi::registerConvertMPIToLLVMInterface(DialectRegistry ®istry) { diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index 1db14e60c5a7..830a9333ade4 100644 --- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -26,7 +26,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Shard/IR/ShardDialect.h" #include "mlir/Dialect/Shard/IR/ShardOps.h" -#include "mlir/Dialect/Shard/Transforms/Simplifications.h" +#include "mlir/Dialect/Shard/Transforms/Simplify.h" #include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" @@ -618,6 +618,156 @@ struct ConvertAllReduceOp : public CommOpPattern { } }; +struct ConvertReduceScatterOp : public CommOpPattern { + using CommOpPattern::CommOpPattern; + + // shard.reduce_scatter reduces and then scatters along a specified + // scatter-dim. 
mpi.reduce_scatter_block always scatters along the first + // dimension. Hence, if scatter-dim != 0, we need to rearrange the input + // data by expanding the scatter-dim into {nRanks, output_scatter_dim} and + // transposing nRanks to the first dimension. + + LogicalResult + matchAndRewrite(ReduceScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto gridAxes = adaptor.getGridAxes(); + int64_t scatterDim = adaptor.getScatterDimAttr().getInt(); + + SymbolTableCollection symbolTableCollection; + FailureOr gridOp = checkGrid(op, symbolTableCollection); + if (failed(gridOp)) + return failure(); + + ImplicitLocOpBuilder ib(op.getLoc(), rewriter); + Value rawInput = adaptor.getInput(); + auto inShapedType = cast(rawInput.getType()); + MemRefType outType = getMemrefType(cast(op.getType())); + auto elemType = outType.getElementType(); + auto inputShape = inShapedType.getShape(); + auto outputShape = outType.getShape(); + int64_t inputDimOnAxis = inputShape[scatterDim]; + int64_t outputDimOnAxis = outputShape[scatterDim]; + + for (size_t i = 0; i < outputShape.size(); ++i) + if (outputShape[i] != inputShape[i] && + i != static_cast(scatterDim)) + return op.emitError( + "Result and input shapes must match along non-scatter axes."); + if (outputDimOnAxis == 0) + return op.emitError( + "Output size along the scatter axis must be non-zero."); + if (inputDimOnAxis % outputDimOnAxis != 0) + return op.emitError( + "Input size along the scatter axis must be an exact " + "multiple of the output size along the scatter axis."); + + if (!memref::isStaticShapeAndContiguousRowMajor(outType)) + return op.emitError("Result must be a statically shaped memref in " + "contiguous row-major layout."); + + int64_t nRanks = inputDimOnAxis / outputDimOnAxis; + + // Verify that nRanks matches the number of devices along the grid axes. 
+ int64_t gridGroupSize = + collectiveProcessGroupSize(gridAxes, gridOp->getShape()); + if (nRanks != gridGroupSize) + return op.emitError() + << "Expected the scatter factor (" << nRanks + << ") to match the number of devices along grid_axes (" + << gridGroupSize << ")."; + + // Get the right communicator. + Value comm = getComm(*gridOp, gridAxes, ib); + + Value mpiInput; + if (scatterDim == 0) { + // scatter_dim == 0 maps directly to MPI_Reduce_scatter_block. + // Input must be contiguous for MPI. + Value input = getAsMemref(rawInput, ib); + MemRefType inType = cast(input.getType()); + if (!memref::isStaticShapeAndContiguousRowMajor(inType)) + return op.emitError("Input must be a statically shaped memref in " + "contiguous row-major layout."); + mpiInput = input; + } else { + // For scatter_dim != 0 we rearrange the input so the scatter factor + // becomes the first dimension. + // + // 1. Get a tensor representation of the input (avoid memref->tensor + // round-trip if the input is already a tensor). + Value tensorInput = rawInput; + if (!isa(rawInput.getType())) { + auto inTensorType = RankedTensorType::get(inputShape, elemType); + tensorInput = + bufferization::ToTensorOp::create(ib, inTensorType, rawInput, true); + } + + // 2. Expand the scatter dim from {d0, ..., d_sd, ..., dN} to + // {d0, ..., nRanks, o_sd, ..., dN}. 
+ SmallVector expandedShape; + SmallVector expandReassociation; + int64_t expandedIdx = 0; + for (int64_t i = 0; i < static_cast(inputShape.size()); ++i) { + if (i == scatterDim) { + expandedShape.push_back(nRanks); + expandedShape.push_back(outputDimOnAxis); + expandReassociation.push_back({expandedIdx, expandedIdx + 1}); + expandedIdx += 2; + } else { + expandedShape.push_back(inputShape[i]); + expandReassociation.push_back({expandedIdx}); + expandedIdx += 1; + } + } + auto expandedType = RankedTensorType::get(expandedShape, elemType); + tensorInput = tensor::ExpandShapeOp::create(ib, expandedType, tensorInput, + expandReassociation); + + // 3. Transpose to move nRanks (at position scatterDim) to position 0: + // {d0, ..., nRanks, o_sd, ..., dN} -> {nRanks, d0, ..., o_sd, ..., dN} + SmallVector permutation, transposedShape; + permutation.emplace_back(scatterDim); + for (int64_t i = 0; i < scatterDim; ++i) + permutation.emplace_back(i); + for (int64_t i = scatterDim + 1; i < (int64_t)expandedShape.size(); ++i) + permutation.emplace_back(i); + for (auto p : permutation) + transposedShape.emplace_back(expandedShape[p]); + + Value permOutput = tensor::EmptyOp::create(ib, transposedShape, elemType); + tensorInput = + linalg::TransposeOp::create(ib, tensorInput, permOutput, permutation) + ->getResult(0); + + // 4. Materialize as contiguous memref for MPI by copying into a + // freshly allocated buffer. + auto mpiInType = MemRefType::get(transposedShape, elemType); + Value transposedBuf = + bufferization::ToBufferOp::create(ib, mpiInType, tensorInput); + mpiInput = memref::AllocOp::create(ib, mpiInType); + linalg::CopyOp::create(ib, transposedBuf, mpiInput); + } + + // Allocate output buffer. + Value output = memref::AllocOp::create(ib, outType); + // Create the MPI ReduceScatter operation. 
+ mpi::ReduceScatterBlockOp::create( + ib, TypeRange(), mpiInput, output, + getMPIReductionOp(adaptor.getReductionAttr()), comm); + + // Deallocate the temporary input buffer if we allocated one. + if (scatterDim != 0) + memref::DeallocOp::create(ib, mpiInput); + + // If the destination is a tensor, cast it to a tensor. + if (isa(op.getType())) + output = + bufferization::ToTensorOp::create(ib, op.getType(), output, true); + rewriter.replaceOp(op, output); + return success(); + } +}; + struct ConvertAllGatherOp : public CommOpPattern { using CommOpPattern::CommOpPattern; @@ -1048,7 +1198,7 @@ struct ConvertShardToMPIPass patterns.add(typeConverter, ctxt); SymbolTableCollection stc; populateProcessMultiIndexOpLoweringPatterns(patterns, stc); diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp index 6cca853071dc..e5e09e28998b 100644 --- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp +++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp @@ -15,8 +15,23 @@ using namespace mlir; using namespace mlir::mpi; +//===----------------------------------------------------------------------===// +// Verifiers +//===----------------------------------------------------------------------===// + +LogicalResult mlir::mpi::ReduceScatterBlockOp::verify() { + if (getSendbuf().getType().getElementType() != + getRecvbuf().getType().getElementType()) + return emitOpError("sendbuf and recvbuf must have the same element type"); + return success(); +} + namespace { +//===----------------------------------------------------------------------===// +// Canonicalization patterns +//===----------------------------------------------------------------------===// + // If input memref has dynamic shape and is a cast and if the cast's input has // static shape, fold the cast's static input into the given operation. 
template diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 98234bada09e..a173da3db1d1 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -1416,7 +1416,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } return verifyScatterOrSliceOperandAndResultShape( - getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(), + getOperand(), getResult(), getScatterDim().getSExtValue(), getGridAxes(), grid.value().getShape()); } @@ -1445,9 +1445,9 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return failure(); } - auto scatterAxis = getScatterAxis().getSExtValue(); + auto scatterDim = getScatterDim().getSExtValue(); return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(), - scatterAxis, getGridAxes(), + scatterDim, getGridAxes(), grid.value().getShape()); } diff --git a/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt index a884764e70e9..4e3fb6db9966 100644 --- a/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_dialect_library(MLIRShardTransforms - Simplifications.cpp + Simplify.cpp ShardingPropagation.cpp Partition.cpp Transforms.cpp @@ -26,4 +26,5 @@ add_mlir_dialect_library(MLIRShardTransforms MLIRSupport MLIRTensorDialect MLIRTosaShardingInterfaceImpl + MLIRTransformUtils ) diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp similarity index 51% rename from mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp rename to mlir/lib/Dialect/Shard/Transforms/Simplify.cpp index a17671e5408c..525ff007bc2f 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp @@ -1,4 +1,4 @@ -//===- Simplifications.cpp - Shard Simplifications 
-_------------*- C++ -*-===// +//===- Simplify.cpp - Shard Simplify ----------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,13 +6,16 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Shard/Transforms/Simplifications.h" +#include "mlir/Dialect/Shard/Transforms/Simplify.h" #include "TransformsDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" #include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Transforms/Passes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include @@ -20,31 +23,8 @@ namespace mlir { namespace shard { -void populateSimplificationPatterns( - RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { - populateAllReduceEndomorphismSimplificationPatterns( - patterns, ReductionKind::Sum); - populateAllReduceEndomorphismSimplificationPatterns( - patterns, ReductionKind::Sum); - - populateAllReduceEndomorphismSimplificationPatterns( - patterns, ReductionKind::Min); - populateAllReduceEndomorphismSimplificationPatterns( - patterns, ReductionKind::Min); - populateAllReduceEndomorphismSimplificationPatterns( - patterns, ReductionKind::Min); - - populateAllReduceEndomorphismSimplificationPatterns( - patterns, ReductionKind::Max); - populateAllReduceEndomorphismSimplificationPatterns( - patterns, ReductionKind::Max); - populateAllReduceEndomorphismSimplificationPatterns( - patterns, ReductionKind::Max); - - // TODO: add simplifications for all-gather and other collectives. 
- - populateFoldingPatterns(patterns, symbolTableCollection); -} +#define GEN_PASS_DEF_SHARDSIMPLIFY +#include "mlir/Dialect/Shard/Transforms/Passes.h.inc" namespace { @@ -109,12 +89,97 @@ struct GridShapeFolder } }; +// Simplify AllSliceOp(AllReduceOp) -> ReduceScatterOp when both ops share the +// same grid and grid_axes. +// +// AllReduceOp performs an element-wise reduction across all devices in the +// group, and AllSliceOp then slices (scatters) the result along a tensor +// dimension. This is exactly what ReduceScatterOp does in a single collective. +// +// With a ring algorithm over N ranks and M elements: +// AllReduce: 2*(N-1) steps of M/N each => ~2M total data transferred +// AllSlice: local slice, no communication +// ReduceScatter: (N-1) steps of M/N each => ~M total data transferred +// So this fusion roughly halves the communication volume. +// +// Memory-wise, AllReduce produces a full-sized M-element result that the +// subsequent AllSlice must keep alive until the slice is taken. ReduceScatter +// only materializes the M/N-element local slice, reducing peak memory by +// a factor of N. +struct AllReduceAllSliceSimplification : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllSliceOp sliceOp, + PatternRewriter &rewriter) const override { + // Check if the input to AllSliceOp is produced by an AllReduceOp. + auto reduceOp = sliceOp.getInput().getDefiningOp(); + if (!reduceOp || !reduceOp->hasOneUse()) + return failure(); + + // Both ops must operate on the same grid and grid axes. + if (reduceOp.getGrid() != sliceOp.getGrid() || + reduceOp.getGridAxes() != sliceOp.getGridAxes()) + return failure(); + + // Replace with a single ReduceScatterOp. 
+ rewriter.replaceOpWithNewOp( + sliceOp, sliceOp.getResult().getType(), sliceOp.getGridAttr(), + sliceOp.getGridAxesAttr(), reduceOp.getInput(), + reduceOp.getReductionAttr(), sliceOp.getSliceAxisAttr()); + + return success(); + } +}; + } // namespace +void populateSimplifyPatterns(RewritePatternSet &patterns, + SymbolTableCollection &symbolTableCollection) { + populateAllReduceEndomorphismSimplifyPatterns( + patterns, ReductionKind::Sum); + populateAllReduceEndomorphismSimplifyPatterns( + patterns, ReductionKind::Sum); + + populateAllReduceEndomorphismSimplifyPatterns( + patterns, ReductionKind::Min); + populateAllReduceEndomorphismSimplifyPatterns( + patterns, ReductionKind::Min); + populateAllReduceEndomorphismSimplifyPatterns( + patterns, ReductionKind::Min); + + populateAllReduceEndomorphismSimplifyPatterns( + patterns, ReductionKind::Max); + populateAllReduceEndomorphismSimplifyPatterns( + patterns, ReductionKind::Max); + populateAllReduceEndomorphismSimplifyPatterns( + patterns, ReductionKind::Max); + + patterns.add(patterns.getContext()); + + // TODO: add simplify patterns for all-gather and other collectives. 
+ + populateFoldingPatterns(patterns, symbolTableCollection); +} + void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { patterns.add(symbolTableCollection, patterns.getContext()); } +namespace { + +struct ShardSimplifyPass : public impl::ShardSimplifyBase { + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + SymbolTableCollection symbolTableCollection; + populateSimplifyPatterns(patterns, symbolTableCollection); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + } // namespace shard } // namespace mlir diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir index 9ec81c53b41f..73ad2d8f9299 100644 --- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir +++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir @@ -1,127 +1,151 @@ // RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s // COM: Test MPICH ABI -// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { -// CHECK: llvm.func @MPI_Finalize() -> i32 -// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32 -// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32 -// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32 -// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32 -// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 -// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32 -// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 -// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32 +// CHECK-LABEL: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { +// CHECK-DAG: llvm.func @MPI_Finalize() -> i32 +// CHECK-DAG: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32 +// CHECK-DAG: 
llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32 +// CHECK-DAG: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32 +// CHECK-DAG: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32 +// CHECK-DAG: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32 module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { - // CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) { - func.func @mpi_test_mpich(%arg0: memref<100xf32>) { - - // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr - // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32 + // CHECK-LABEL: llvm.func @test_init_finalize_mpich + func.func @test_init_finalize_mpich() { + // CHECK: [[v0:%.*]] = llvm.mlir.zero : !llvm.ptr + // CHECK: llvm.call @MPI_Init([[v0]], [[v0]]) : (!llvm.ptr, !llvm.ptr) -> 
i32 %0 = mpi.init : !mpi.retval - - // CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64 - %comm = mpi.comm_world : !mpi.comm - - // CHECK: [[v8:%.*]] = llvm.trunc [[comm]] : i64 to i32 - // CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr - // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32 - %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32 - - // CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32 - // CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32 - // CHECK: [[v17:%.*]] = llvm.mul [[v17a]] - // CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 - // CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32 - // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32 - mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 - - // CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32 - // CHECK: [[v25:%.*]] = 
llvm.mul [[v25a]] - // CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 - // CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32 - // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32 - %1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval - - // CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32 - // CHECK: [[v33:%.*]] = llvm.mul [[v33a]] - // CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 - // CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32 - // CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64 - // CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr - // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[comm_3]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 - - // CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32 - // CHECK: [[v43:%.*]] = llvm.mul 
[[v43a]] - // CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 - // CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32 - // CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64 - // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr - // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval - - // CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32 - %color = arith.constant 10 : i32 - // CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32 - %key = arith.constant 22 : i32 - // CHECK: [[v53:%.*]] = llvm.trunc [[comm]] : i64 to i32 - // CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x i32 : (i32) -> !llvm.ptr - // CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (i32, i32, i32, !llvm.ptr) -> i32 - // CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32 - %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm - - // CHECK: llvm.call @MPI_Comm_size - // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32 - %err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval - - // CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v62:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v63a:%.*]] = llvm.trunc [[v62]] : i64 to i32 - // CHECK: [[v63:%.*]] = llvm.mul [[v63a]] - // CHECK: [[v64:%.*]] = llvm.extractvalue 
[[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v65:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v68a:%.*]] = llvm.trunc [[v67]] : i64 to i32 - // CHECK: [[v68:%.*]] = llvm.mul [[v68a]] - // CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64 - // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr - // CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 - // CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32 - // CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32 - // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32 - mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> - // CHECK: llvm.call @MPI_Finalize() : () -> i32 - %3 = mpi.finalize : !mpi.retval + %1 = mpi.finalize : !mpi.retval + return + } + // CHECK-LABEL: llvm.func @test_comm_rank_mpich + func.func @test_comm_rank_mpich() { + // CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64 + %comm = mpi.comm_world : !mpi.comm + // CHECK: [[v1:%.*]] = llvm.trunc [[v0]] : i64 to i32 + // CHECK: [[v2:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: [[v3:%.*]] = llvm.alloca [[v2]] x i32 : (i32) -> !llvm.ptr + // CHECK: llvm.call @MPI_Comm_rank([[v1]], [[v3]]) : (i32, !llvm.ptr) -> i32 + %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32 + return + } + + // CHECK-LABEL: llvm.func @test_send_mpich + func.func @test_send_mpich(%arg0: memref<100xf32>) { + // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0] + // CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64 + %comm = mpi.comm_world : !mpi.comm 
+ // CHECK: llvm.call @MPI_Comm_rank + // CHECK: [[v2:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32 + %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32 + // COM: Test send without retval + // CHECK: [[v3:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v5:%.*]] = llvm.getelementptr [[v3]][[[v4]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + // CHECK: [[v6:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v7:%.*]] = llvm.trunc [[v6]] : i64 to i32 + // CHECK: [[v8:%.*]] = llvm.mul [[v7]] + // CHECK: [[v9:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 + // CHECK: [[v10:%.*]] = llvm.trunc [[v1]] : i64 to i32 + // CHECK: = llvm.call @MPI_Send([[v5]], [[v8]], [[v9]], [[v2]], [[v2]], [[v10]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32 + mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 + // COM: Test send with retval + // CHECK: = llvm.call @MPI_Send({{.*}}) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32 + %1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval + return + } + + // CHECK-LABEL: llvm.func @test_recv_mpich + func.func @test_recv_mpich(%arg0: memref<100xf32>) { + // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0] + // CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64 + %comm = mpi.comm_world : !mpi.comm + // CHECK: llvm.call @MPI_Comm_rank + // CHECK: [[v2:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32 + %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32 + // COM: Test recv without retval + // CHECK: [[v3:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v5:%.*]] = 
llvm.getelementptr [[v3]][[[v4]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + // CHECK: [[v6:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v7:%.*]] = llvm.trunc [[v6]] : i64 to i32 + // CHECK: [[v8:%.*]] = llvm.mul [[v7]] + // CHECK: [[v9:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 + // CHECK: [[v10:%.*]] = llvm.trunc [[v1]] : i64 to i32 + // CHECK: [[v11:%.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: [[v12:%.*]] = llvm.inttoptr [[v11]] : i64 to !llvm.ptr + // CHECK: = llvm.call @MPI_Recv([[v5]], [[v8]], [[v9]], [[v2]], [[v2]], [[v10]], [[v12]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 + mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 + // COM: Test recv with retval + // CHECK: = llvm.call @MPI_Recv({{.*}}) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 + %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval + return + } + + // CHECK-LABEL: llvm.func @test_comm_split_mpich + func.func @test_comm_split_mpich() { + // CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64 + %comm = mpi.comm_world : !mpi.comm + // CHECK: [[v1:%.*]] = llvm.mlir.constant(10 : i32) : i32 + %color = arith.constant 10 : i32 + // CHECK: [[v2:%.*]] = llvm.mlir.constant(22 : i32) : i32 + %key = arith.constant 22 : i32 + // CHECK: [[v3:%.*]] = llvm.trunc [[v0]] : i64 to i32 + // CHECK: [[v4:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: [[v5:%.*]] = llvm.alloca [[v4]] x i32 : (i32) -> !llvm.ptr + // CHECK: llvm.call @MPI_Comm_split([[v3]], [[v1]], [[v2]], [[v5]]) : (i32, i32, i32, !llvm.ptr) -> i32 + // CHECK: llvm.load [[v5]] : !llvm.ptr -> i32 + %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm + return + } + + // CHECK-LABEL: llvm.func @test_allgather_mpich + func.func @test_allgather_mpich(%arg0: memref<100xf32>) { + %comm = mpi.comm_world : !mpi.comm + // CHECK: llvm.call @MPI_Comm_size + // CHECK: llvm.call 
@MPI_Allgather({{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32 + %err = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval + return + } + + // CHECK-LABEL: llvm.func @test_allreduce_mpich + func.func @test_allreduce_mpich(%arg0: memref<100xf32>) { + %comm = mpi.comm_world : !mpi.comm + // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0] + // CHECK: [[v1:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64 + // CHECK: [[v2:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 + // CHECK: [[v3:%.*]] = llvm.mul + // CHECK: [[v4:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 + // CHECK: [[v5:%.*]] = llvm.mlir.constant(-1 : i64) : i64 + // CHECK: [[v6:%.*]] = llvm.inttoptr [[v5]] : i64 to !llvm.ptr + // CHECK: [[v7:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 + // CHECK: [[v8:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32 + // CHECK: [[v9:%.*]] = llvm.trunc [[v1]] : i64 to i32 + // CHECK: llvm.call @MPI_Allreduce([[v6]], [[v4]], [[v3]], [[v7]], [[v8]], [[v9]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32 + mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> + return + } + + // CHECK-LABEL: llvm.func @test_reduce_scatter_block_mpich + func.func @test_reduce_scatter_block_mpich(%arg0: memref<100xf32>) { + %comm = mpi.comm_world : !mpi.comm + // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0] + // CHECK: llvm.mul + // CHECK: [[v1:%.*]] = llvm.mlir.constant(1 : index) : i32 + // CHECK: [[v2:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v3:%.*]] = llvm.trunc [[v2]] : i64 to i32 + // CHECK: [[v4:%.*]] = llvm.mul [[v3]], [[v1]] : i32 + // CHECK: [[v5:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr + // CHECK: llvm.cond_br {{.*}}, ^[[bb1:.*]], ^{{.*}} + // CHECK: ^[[bb1]]: + // CHECK: llvm.call @MPI_Reduce_scatter_block([[v5]], {{.*}}, [[v4]], {{.*}}) : (!llvm.ptr, 
!llvm.ptr, i32, i32, i32, i32) -> i32 + mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> return } } @@ -129,131 +153,160 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { // ----- // COM: Test OpenMPI ABI -// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} { -// CHECK: llvm.func @MPI_Finalize() -> i32 -// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32 -// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32 -// CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque> -// CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32 -// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32 -// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32 -// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32 -// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque> -// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32 -// CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque> -// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32 +// CHECK-LABEL: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} { +// CHECK-DAG: llvm.func @MPI_Finalize() -> i32 +// CHECK-DAG: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : 
!llvm.struct<"ompi_predefined_op_t", opaque> +// CHECK-DAG: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque> +// CHECK-DAG: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32 +// CHECK-DAG: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque> +// CHECK-DAG: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32 module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } { - // CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) { - func.func @mpi_test_openmpi(%arg0: memref<100xf32>) { - - // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr - // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> 
i32 + // CHECK-LABEL: llvm.func @test_init_finalize_openmpi + func.func @test_init_finalize_openmpi() { + // CHECK: [[v0:%.*]] = llvm.mlir.zero : !llvm.ptr + // CHECK: llvm.call @MPI_Init([[v0]], [[v0]]) : (!llvm.ptr, !llvm.ptr) -> i32 %0 = mpi.init : !mpi.retval + // CHECK: llvm.call @MPI_Finalize() : () -> i32 + %1 = mpi.finalize : !mpi.retval + return + } + // CHECK-LABEL: llvm.func @test_comm_rank_openmpi + func.func @test_comm_rank_openmpi() { %comm = mpi.comm_world : !mpi.comm - // CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr - // CHECK: [[comm:%.*]] = llvm.ptrtoint [[v8]] : !llvm.ptr to i64 - // CHECK: [[comm_1:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr - // CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr - // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[comm_1]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32 + // CHECK: [[v0:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[v1:%.*]] = llvm.ptrtoint [[v0]] : !llvm.ptr to i64 + // CHECK: [[v2:%.*]] = llvm.inttoptr [[v1]] : i64 to !llvm.ptr + // CHECK: [[v3:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: [[v4:%.*]] = llvm.alloca [[v3]] x i32 : (i32) -> !llvm.ptr + // CHECK: llvm.call @MPI_Comm_rank([[v2]], [[v4]]) : (!llvm.ptr, !llvm.ptr) -> i32 %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32 + return + } - // CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32 - // CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : 
i64 to i32 - // CHECK: [[v17:%.*]] = llvm.mul [[v17a]] - // CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr - // CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr - // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32 + // CHECK-LABEL: llvm.func @test_send_openmpi + func.func @test_send_openmpi(%arg0: memref<100xf32>) { + // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0] + %comm = mpi.comm_world : !mpi.comm + // CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64 + // CHECK: llvm.call @MPI_Comm_rank + // CHECK: [[v3:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32 + %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32 + // COM: Test send without retval + // CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v5:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v6:%.*]] = llvm.getelementptr [[v4]][[[v5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + // CHECK: [[v7:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v8:%.*]] = llvm.trunc [[v7]] : i64 to i32 + // CHECK: [[v9:%.*]] = llvm.mul [[v8]] + // CHECK: [[v10:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr + // CHECK: [[v11:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr + // CHECK: = llvm.call @MPI_Send([[v6]], [[v9]], [[v10]], [[v3]], [[v3]], [[v11]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32 mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 - - // CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x 
i64>, array<1 x i64>)> - // CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32 - // CHECK: [[v25:%.*]] = llvm.mul [[v25a]] - // CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr - // CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr - // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32 + // COM: Test send with retval + // CHECK: = llvm.call @MPI_Send({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32 %1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval + return + } - // CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32 - // CHECK: [[v33:%.*]] = llvm.mul [[v33a]] - // CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr - // CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr - // CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64 - // CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr - // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32 + // CHECK-LABEL: llvm.func @test_recv_openmpi + func.func @test_recv_openmpi(%arg0: memref<100xf32>) { + // CHECK: 
[[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0] + %comm = mpi.comm_world : !mpi.comm + // CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64 + // CHECK: llvm.call @MPI_Comm_rank + // CHECK: [[v3:%.*]] = llvm.load {{.*}} : !llvm.ptr -> i32 + %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32 + // COM: Test recv without retval + // CHECK: [[v4:%.*]] = llvm.extractvalue [[v0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v5:%.*]] = llvm.extractvalue [[v0]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v6:%.*]] = llvm.getelementptr [[v4]][[[v5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + // CHECK: [[v7:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v8:%.*]] = llvm.trunc [[v7]] : i64 to i32 + // CHECK: [[v9:%.*]] = llvm.mul [[v8]] + // CHECK: [[v10:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr + // CHECK: [[v11:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr + // CHECK: [[v12:%.*]] = llvm.mlir.constant(0 : i64) : i64 + // CHECK: [[v13:%.*]] = llvm.inttoptr [[v12]] : i64 to !llvm.ptr + // CHECK: = llvm.call @MPI_Recv([[v6]], [[v9]], [[v10]], [[v3]], [[v3]], [[v11]], [[v13]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32 mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 - - // CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32 - // CHECK: [[v43:%.*]] = 
llvm.mul [[v43a]] - // CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr - // CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr - // CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64 - // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr - // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32 + // COM: Test recv with retval + // CHECK: = llvm.call @MPI_Recv({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32 %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval - + return + } + + // CHECK-LABEL: llvm.func @test_comm_split_openmpi + func.func @test_comm_split_openmpi() { + %comm = mpi.comm_world : !mpi.comm + // CHECK: [[v0:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[v1:%.*]] = llvm.ptrtoint [[v0]] : !llvm.ptr to i64 + // CHECK: [[v2:%.*]] = llvm.mlir.constant(10 : i32) : i32 + %color = arith.constant 10 : i32 + // CHECK: [[v3:%.*]] = llvm.mlir.constant(22 : i32) : i32 + %key = arith.constant 22 : i32 + // CHECK: [[v4:%.*]] = llvm.inttoptr [[v1]] : i64 to !llvm.ptr + // CHECK: [[v5:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: [[v6:%.*]] = llvm.alloca [[v5]] x !llvm.ptr : (i32) -> !llvm.ptr + // CHECK: llvm.call @MPI_Comm_split([[v4]], [[v2]], [[v3]], [[v6]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32 + // CHECK: llvm.load [[v6]] : !llvm.ptr -> i32 + %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm + return + } + + // CHECK-LABEL: llvm.func @test_allgather_openmpi + func.func @test_allgather_openmpi(%arg0: memref<100xf32>) { + %comm = mpi.comm_world : !mpi.comm // CHECK: llvm.call @MPI_Comm_size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 // CHECK: llvm.udiv {{.*}} : i32 - // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32 + // CHECK: 
llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32 mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> + return + } - // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v53a:%.*]] = llvm.trunc [[v52]] : i64 to i32 - // CHECK: [[v53:%.*]] = llvm.mul [[v53a]] - // CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v58a:%.*]] = llvm.trunc [[v57]] : i64 to i32 - // CHECK: [[v58:%.*]] = llvm.mul [[v58a]] - // CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64 - // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr - // CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr - // CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr - // CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr - // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32 + // CHECK-LABEL: llvm.func @test_allreduce_openmpi + func.func @test_allreduce_openmpi(%arg0: memref<100xf32>) { + %comm = mpi.comm_world : !mpi.comm + // CHECK: [[v0:%.*]] = 
llvm.insertvalue {{.*}}[4, 0] + // CHECK: [[v1:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[v2:%.*]] = llvm.ptrtoint [[v1]] : !llvm.ptr to i64 + // CHECK: [[v3:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 + // CHECK: [[v4:%.*]] = llvm.mul + // CHECK: [[v5:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32 + // CHECK: [[v6:%.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: [[v7:%.*]] = llvm.inttoptr [[v6]] : i64 to !llvm.ptr + // CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr + // CHECK: [[v9:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr + // CHECK: [[v10:%.*]] = llvm.inttoptr [[v2]] : i64 to !llvm.ptr + // CHECK: llvm.call @MPI_Allreduce([[v7]], [[v5]], [[v4]], [[v8]], [[v9]], [[v10]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32 mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> + return + } - // CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32 - %color = arith.constant 10 : i32 - // CHECK: [[v72:%.*]] = llvm.mlir.constant(22 : i32) : i32 - %key = arith.constant 22 : i32 - // CHECK: [[v73:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr - // CHECK: [[v74:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: [[v75:%.*]] = llvm.alloca [[v74]] x !llvm.ptr : (i32) -> !llvm.ptr - // CHECK: [[v76:%.*]] = llvm.call @MPI_Comm_split([[v73]], [[v71]], [[v72]], [[v75]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32 - // CHECK: [[v77:%.*]] = llvm.load [[v75]] : !llvm.ptr -> i32 - %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm - - // CHECK: llvm.call @MPI_Finalize() : () -> i32 - %3 = mpi.finalize : !mpi.retval - + // CHECK-LABEL: llvm.func @test_reduce_scatter_block_openmpi + func.func @test_reduce_scatter_block_openmpi(%arg0: memref<100xf32>) { + %comm = mpi.comm_world : !mpi.comm + // CHECK: [[v0:%.*]] = llvm.insertvalue {{.*}}[4, 0] + // CHECK: llvm.mul + // CHECK: [[v1:%.*]] = llvm.mlir.constant(1 : 
index) : i32 + // CHECK: [[v2:%.*]] = llvm.extractvalue [[v0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v3:%.*]] = llvm.trunc [[v2]] : i64 to i32 + // CHECK: [[v4:%.*]] = llvm.mul [[v3]], [[v1]] : i32 + // CHECK: [[v5:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr + // CHECK: llvm.cond_br {{.*}}, ^[[bb1:.*]], ^{{.*}} + // CHECK: ^[[bb1]]: + // CHECK: llvm.call @MPI_Reduce_scatter_block([[v5]], {{.*}}, [[v4]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32 + mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> return } } @@ -261,13 +314,13 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } { // ----- module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:comm_world_size" = 4, "MPI:comm_world_rank" = 1> } { - // CHECK: llvm.func @mpi_test_fold - func.func @mpi_test_fold(%arg0: memref<100xf32>) { - // CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64 + // CHECK-LABEL: llvm.func @test_fold + func.func @test_fold(%arg0: memref<100xf32>) { + // CHECK: [[v0:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64 %comm = mpi.comm_world : !mpi.comm // CHECK-NOT: llvm.call @MPI_Comm_size - // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32 + // CHECK: llvm.call @MPI_Allgather({{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32 %err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval return } diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir index f3da09d05e3b..08c3897e4e65 100644 --- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir +++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir @@ -159,6 +159,40 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { return %0 : memref<3x4xf64> } + // 
CHECK-LABEL: func.func @reduce_scatter_memref( + func.func @reduce_scatter_memref( + // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32> + %arg0 : memref<3x4xf32>) -> memref<1x4xf32> { + // CHECK-DAG: [[vc0_i32:%.*]] = arith.constant 0 : i32 + // CHECK-DAG: [[vc7_i32:%.*]] = arith.constant 7 : i32 + // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc7_i32]], [[vc0_i32]]) : !mpi.comm + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<1x4xf32> + // CHECK: mpi.reduce_scatter_block([[varg0]], [[valloc]], MPI_SUM, [[vnewcomm]]) : memref<3x4xf32>, memref<1x4xf32> + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0 : memref<3x4xf32> -> memref<1x4xf32> + // CHECK: return [[valloc]] : memref<1x4xf32> + return %0 : memref<1x4xf32> + } + + // CHECK-LABEL: func.func @reduce_scatter_tensor_dim1( + func.func @reduce_scatter_tensor_dim1( + // CHECK-SAME: [[varg0:%.*]]: tensor<2x12xf32> + %arg0 : tensor<2x12xf32>) -> tensor<2x4xf32> { + // CHECK: [[vexpanded:%.*]] = tensor.expand_shape [[varg0]] {{\[\[}}0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xf32> into tensor<2x3x4xf32> + // CHECK: [[vempty:%.*]] = tensor.empty() : tensor<3x2x4xf32> + // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[vexpanded]] : tensor<2x3x4xf32>) outs([[vempty]] : tensor<3x2x4xf32>) permutation = [1, 0, 2] + // CHECK: [[vtobuf:%.*]] = bufferization.to_buffer [[vtransposed]] : tensor<3x2x4xf32> to memref<3x2x4xf32> + // CHECK: [[valloctmp:%.*]] = memref.alloc() : memref<3x2x4xf32> + // CHECK: linalg.copy ins([[vtobuf]] : memref<3x2x4xf32>) outs([[valloctmp]] : memref<3x2x4xf32>) + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x4xf32> + // CHECK: mpi.reduce_scatter_block([[valloctmp]], [[valloc]], MPI_SUM, + // CHECK: memref.dealloc [[valloctmp]] : memref<3x2x4xf32> + // CHECK: [[vout:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<2x4xf32> to tensor<2x4xf32> + %0 = shard.reduce_scatter %arg0 on @grid0
grid_axes = [0] scatter_dim = 1 : tensor<2x12xf32> -> tensor<2x4xf32> + // CHECK: return [[vout]] : tensor<2x4xf32> + return %0 : tensor<2x4xf32> + } + // CHECK-LABEL: func @allgather_tensor_0 // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32> func.func @allgather_tensor_0(%arg0 : tensor<3x4xf32>) -> tensor<12x4xf32> { diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir index 87a5647ee91d..ee979d33c699 100644 --- a/mlir/test/Dialect/MPI/mpiops.mlir +++ b/mlir/test/Dialect/MPI/mpiops.mlir @@ -77,6 +77,12 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK-NEXT: mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> + // CHECK-NEXT: [[v10:%.*]] = mpi.reduce_scatter_block([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval + %err9 = mpi.reduce_scatter_block(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval + + // CHECK-NEXT: mpi.reduce_scatter_block([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> + mpi.reduce_scatter_block(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> + // CHECK-NEXT: [[v7:%.*]] = mpi.finalize : !mpi.retval %rval = mpi.finalize : !mpi.retval diff --git a/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir index bc911215851a..5b11078c32bb 100644 --- a/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir +++ b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --test-grid-simplifications --cse %s | FileCheck %s +// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --shard-simplify --cse %s | FileCheck %s shard.grid @grid_1d(shape = ?) 
@@ -59,15 +59,15 @@ func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid( // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = shard.process_multi_index on @grid_4d axes = [3, 1] : index, index // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = shard.grid_shape @grid_4d axes = [3, 1] : index, index // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index - // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor - // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index + // CHECK: %[[scatter_dim_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor + // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[scatter_dim_SIZE]], %[[PROC_GROUP_SIZE]] : index // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index // CHECK: cf.assert %[[AXIS_SIZE_CHECK]] - // CHECK: %[[RESULT_SCATTER_AXIS_SIZE:.*]] = arith.divui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index + // CHECK: %[[RESULT_scatter_dim_SIZE:.*]] = arith.divui %[[scatter_dim_SIZE]], %[[PROC_GROUP_SIZE]] : index // CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #map()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1] // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor - // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index - // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor to tensor + // CHECK: %[[scatter_dim_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_scatter_dim_SIZE]] : index + // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[scatter_dim_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_scatter_dim_SIZE]]] [1, 1] : tensor to tensor %0 = shard.all_slice %arg0 on @grid_4d grid_axes = [3, 
1] slice_axis = 1 : tensor -> tensor // CHECK: return %[[RESULT]] : tensor return %0 : tensor diff --git a/mlir/test/Dialect/Shard/canonicalization.mlir b/mlir/test/Dialect/Shard/canonicalization.mlir index ed40dfb7237d..a3a9c592ff0a 100644 --- a/mlir/test/Dialect/Shard/canonicalization.mlir +++ b/mlir/test/Dialect/Shard/canonicalization.mlir @@ -135,7 +135,7 @@ func.func @reduce_scatter_empty_grid_axes( // CHECK-NOT: shard.reduce_scatter %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [] - scatter_axis = 0 + scatter_dim = 0 : tensor<4xf32> -> tensor<4xf32> // CHECK: return %[[ARG]] return %0 : tensor<4xf32> @@ -148,7 +148,7 @@ func.func @reduce_scatter_empty_grid_axes_different_return_type( %0 = shard.reduce_scatter %arg0 on @grid0 // CHECK-NOT: grid_axes grid_axes = [] - scatter_axis = 0 + scatter_dim = 0 : tensor<4xf32> -> tensor<4xf64> return %0 : tensor<4xf64> } @@ -160,7 +160,7 @@ func.func @reduce_scatter_default_reduction( grid_axes = [0] // CHECK-NOT: reduction reduction = sum - scatter_axis = 0 + scatter_dim = 0 : tensor<4xf32> -> tensor<2xf64> return %0 : tensor<2xf64> } @@ -172,7 +172,7 @@ func.func @scatter_empty_grid_axes( // CHECK-NOT: shard.scatter %0 = shard.scatter %arg0 on @grid0 grid_axes = [] - scatter_axis = 0 + scatter_dim = 0 root = [] : (tensor<4xf32>) -> tensor<4xf32> // CHECK: return %[[ARG]] diff --git a/mlir/test/Dialect/Shard/folding.mlir b/mlir/test/Dialect/Shard/folding.mlir index 5a0f35b53a12..7b6ba33f84fd 100644 --- a/mlir/test/Dialect/Shard/folding.mlir +++ b/mlir/test/Dialect/Shard/folding.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s +// RUN: mlir-opt -shard-simplify %s | FileCheck %s shard.grid @grid0(shape = 4x?x2) shard.grid @grid1(shape = 2x3) diff --git a/mlir/test/Dialect/Shard/invalid.mlir b/mlir/test/Dialect/Shard/invalid.mlir index 6acac971164e..c92932a725d1 100644 --- a/mlir/test/Dialect/Shard/invalid.mlir +++ b/mlir/test/Dialect/Shard/invalid.mlir @@ -707,7 +707,7 @@ 
shard.grid @grid0(shape = 3) func.func @reduce_scatter_duplicate_grid_axis( %arg0 : tensor) -> tensor { // expected-error@+1 {{Grid axes contains duplicate elements.}} - %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_dim = 0 : tensor -> tensor return %0 : tensor } @@ -719,7 +719,7 @@ shard.grid @grid0(shape = 3) func.func @reduce_scatter_invalid_dynamic_dimension( %arg0 : tensor) -> tensor<2xf64> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}} - %0 = shard.reduce_scatter %arg0 on @grid0 scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 scatter_dim = 0 : tensor -> tensor<2xf64> return %0 : tensor<2xf64> } @@ -731,7 +731,7 @@ shard.grid @grid0(shape = 3) func.func @reduce_scatter_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf64> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}} - %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0 : tensor<3xf32> -> tensor<2xf64> return %0 : tensor<2xf64> } @@ -743,7 +743,7 @@ shard.grid @grid0(shape = 3) func.func @reduce_scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor { // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} - %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0 : tensor<4xf32> -> tensor return %0 : tensor } @@ -756,7 +756,7 @@ func.func @scatter_duplicate_grid_axis( %arg0 : tensor) -> tensor { // expected-error@+1 {{Grid axes contains duplicate elements.}} %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 0] - scatter_axis = 0 root = [0, 0] + scatter_dim = 0 root = [0, 0] : (tensor) -> tensor return %0 
: tensor } @@ -769,7 +769,7 @@ func.func @scatter_invalid_dynamic_dimension( %arg0 : tensor) -> tensor<2xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}} %0 = shard.scatter %arg0 on @grid0 - scatter_axis = 0 root = [] + scatter_dim = 0 root = [] : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -782,7 +782,7 @@ func.func @scatter_invalid_static_dimension_size( %arg0 : tensor<3xf32>) -> tensor<2xf32> { // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}} %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] - scatter_axis = 0 root = [1] + scatter_dim = 0 root = [1] : (tensor<3xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -795,7 +795,7 @@ func.func @scatter_invalid_operand_static_dimension_size( %arg0 : tensor<4xf32>) -> tensor { // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}} %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] - scatter_axis = 0 root = [1] + scatter_dim = 0 root = [1] : (tensor<4xf32>) -> tensor return %0 : tensor } @@ -808,7 +808,7 @@ func.func @scatter_root_dimension_out_of_bounds( %arg0 : tensor<3xi8>) -> tensor<1xi8> { // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}} %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] - scatter_axis = 0 root = [3] + scatter_dim = 0 root = [3] : (tensor<3xi8>) -> tensor<1xi8> return %0 : tensor<1xi8> } @@ -821,7 +821,7 @@ func.func @scatter_root_wrong_number_dimensions( %arg0 : tensor<3xi8>) -> tensor<1xi8> { // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. 
Expected 1.}} %0 = shard.scatter %arg0 on @grid0 grid_axes = [0] - scatter_axis = 0 root = [2, 2] + scatter_dim = 0 root = [2, 2] : (tensor<3xi8>) -> tensor<1xi8> return %0 : tensor<1xi8> } diff --git a/mlir/test/Dialect/Shard/ops.mlir b/mlir/test/Dialect/Shard/ops.mlir index 5265dadd2a84..5d9ea064bed4 100644 --- a/mlir/test/Dialect/Shard/ops.mlir +++ b/mlir/test/Dialect/Shard/ops.mlir @@ -453,10 +453,10 @@ func.func @reduce_scatter_static_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> { // CHECK-NEXT: shard.reduce_scatter %[[ARG]] - // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_axis = 1 + // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_dim = 1 // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64> %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [2] - reduction = max scatter_axis = 1 + reduction = max scatter_dim = 1 : tensor<3x4xf32> -> tensor<3x1xf64> return %0 : tensor<3x1xf64> } @@ -466,9 +466,9 @@ func.func @reduce_scatter_dynamic_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor %arg0 : tensor) -> tensor { // CHECK-NEXT: shard.reduce_scatter %[[ARG]] - // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_axis = 0 + // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_dim = 0 // CHECK-SAME: : tensor -> tensor - %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0 + %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_dim = 0 : tensor -> tensor return %0 : tensor } @@ -479,10 +479,10 @@ func.func @scatter_static_dimensions( %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> { // CHECK-NEXT: shard.scatter %[[ARG]] // CHECK-SAME: on @grid0 grid_axes = [2] - // CHECK-SAME: scatter_axis = 1 root = [1] + // CHECK-SAME: scatter_dim = 1 root = [1] // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32> %0 = shard.scatter %arg0 on @grid0 grid_axes = [2] - scatter_axis = 1 root = [1] + scatter_dim = 1 root = [1] : (tensor<3x4xf32>) -> tensor<3x1xf32> 
return %0 : tensor<3x1xf32> } @@ -493,10 +493,10 @@ func.func @scatter_dynamic_dimensions( %arg0 : tensor) -> tensor { // CHECK-NEXT: shard.scatter %[[ARG]] // CHECK-SAME: on @grid3 grid_axes = [0, 1] - // CHECK-SAME: scatter_axis = 0 root = [1, 2] + // CHECK-SAME: scatter_dim = 0 root = [1, 2] // CHECK-SAME: : (tensor) -> tensor %0 = shard.scatter %arg0 on @grid3 grid_axes = [0, 1] - scatter_axis = 0 root = [1, 2] + scatter_dim = 0 root = [1, 2] : (tensor) -> tensor return %0 : tensor } @@ -510,11 +510,11 @@ func.func @scatter_dynamic_root( ) -> tensor<1xi8> { // CHECK-NEXT: shard.scatter %[[ARG0]] // CHECK-SAME: on @grid0 grid_axes = [0, 2] - // CHECK-SAME: scatter_axis = 0 + // CHECK-SAME: scatter_dim = 0 // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8> %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 2] - scatter_axis = 0 + scatter_dim = 0 root = [1, %arg1] : (tensor<8xi8>, index) -> tensor<1xi8> return %0 : tensor<1xi8> diff --git a/mlir/test/Dialect/Shard/simplifications.mlir b/mlir/test/Dialect/Shard/simplify.mlir similarity index 66% rename from mlir/test/Dialect/Shard/simplifications.mlir rename to mlir/test/Dialect/Shard/simplify.mlir index 33cd490be744..e5693a288fda 100644 --- a/mlir/test/Dialect/Shard/simplifications.mlir +++ b/mlir/test/Dialect/Shard/simplify.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s +// RUN: mlir-opt -shard-simplify %s | FileCheck %s shard.grid @grid0(shape = 4x2) shard.grid @grid1(shape = 4) @@ -177,3 +177,86 @@ func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 { %0 = arith.maxsi %extracted, %c1_i64 : i64 return %0 : i64 } + +// ----- +// AllReduceOp + AllSliceOp -> ReduceScatterOp tests +// ----- + +// Basic case: all_slice(all_reduce(x)) with matching grid and axes folds +// into reduce_scatter. 
+// CHECK-LABEL: func.func @all_reduce_all_slice_to_reduce_scatter +func.func @all_reduce_all_slice_to_reduce_scatter( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<1x8xf32> { + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32> + %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32> + // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0 + // CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf32> + // CHECK: return %[[RS]] + return %1 : tensor<1x8xf32> +} + +// Verify non-default reduction kind is preserved. +// CHECK-LABEL: func.func @all_reduce_all_slice_max_reduction +func.func @all_reduce_all_slice_max_reduction( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<1x8xf32> { + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max : tensor<4x8xf32> -> tensor<4x8xf32> + %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32> + // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] reduction = max scatter_dim = 0 + // CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf32> + // CHECK: return %[[RS]] + return %1 : tensor<1x8xf32> +} + +// Slice on a different tensor axis than the reduce axes. 
+// CHECK-LABEL: func.func @all_reduce_all_slice_different_slice_axis +func.func @all_reduce_all_slice_different_slice_axis( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<4x4xf32> { + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [1] : tensor<4x8xf32> -> tensor<4x8xf32> + %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<4x8xf32> -> tensor<4x4xf32> + // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [1] scatter_dim = 1 + // CHECK-SAME: : tensor<4x8xf32> -> tensor<4x4xf32> + // CHECK: return %[[RS]] + return %1 : tensor<4x4xf32> +} + +// Do not fold when grids differ. +// CHECK-LABEL: func.func @all_reduce_all_slice_different_grid +func.func @all_reduce_all_slice_different_grid( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<1x8xf32> { + // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32> + // CHECK: %[[AS:.*]] = shard.all_slice %[[AR]] on @grid1 + %1 = shard.all_slice %0 on @grid1 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32> + // CHECK: return %[[AS]] + return %1 : tensor<1x8xf32> +} + +// Do not fold when grid_axes differ. 
+// CHECK-LABEL: func.func @all_reduce_all_slice_different_grid_axes +func.func @all_reduce_all_slice_different_grid_axes( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<4x4xf32> { + // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32> + // CHECK: %[[AS:.*]] = shard.all_slice %[[AR]] on @grid0 grid_axes = [1] + %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<4x8xf32> -> tensor<4x4xf32> + // CHECK: return %[[AS]] + return %1 : tensor<4x4xf32> +} + +// Verify element type conversion is preserved (all_reduce input/output types may differ). +// CHECK-LABEL: func.func @all_reduce_all_slice_type_promotion +func.func @all_reduce_all_slice_type_promotion( + // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32> + %arg0: tensor<4x8xf32>) -> tensor<1x8xf64> { + %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf64> + %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf64> -> tensor<1x8xf64> + // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0 + // CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf64> + // CHECK: return %[[RS]] + return %1 : tensor<1x8xf64> +} diff --git a/mlir/test/lib/Dialect/Shard/CMakeLists.txt b/mlir/test/lib/Dialect/Shard/CMakeLists.txt index f91c54721e03..a97839b6b1ff 100644 --- a/mlir/test/lib/Dialect/Shard/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Shard/CMakeLists.txt @@ -2,7 +2,6 @@ add_mlir_library(MLIRShardTest TestOpLowering.cpp TestReshardingPartition.cpp - TestSimplifications.cpp EXCLUDE_FROM_LIBMLIR ) diff --git a/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp index 23fdad1bd624..1d1812e4aea3 100644 --- a/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp +++ 
b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp @@ -1,4 +1,4 @@ -//===- TestSimplification.cpp - Test simplification -----------------------===// +//===- TestReshardingPartition.cpp - Test resharding partition ------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp deleted file mode 100644 index 28852153f37f..000000000000 --- a/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp +++ /dev/null @@ -1,47 +0,0 @@ -//===- TestSimplification.cpp - Test simplification -----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Shard/IR/ShardDialect.h" -#include "mlir/Dialect/Shard/Transforms/Simplifications.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace { -struct TestShardSimplificationsPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShardSimplificationsPass) - - void runOnOperation() override; - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - StringRef getArgument() const final { return "test-grid-simplifications"; } - StringRef getDescription() const final { return "Test grid simplifications"; } -}; -} // namespace - -void TestShardSimplificationsPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - SymbolTableCollection symbolTableCollection; - shard::populateSimplificationPatterns(patterns, 
symbolTableCollection); - [[maybe_unused]] LogicalResult status = - applyPatternsGreedily(getOperation(), std::move(patterns)); - assert(succeeded(status) && "Rewrite patters application did not converge."); -} - -namespace mlir { -namespace test { -void registerTestShardSimplificationsPass() { - PassRegistration(); -} -} // namespace test -} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index a427132247e6..564bb700b53e 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -131,7 +131,6 @@ void registerTestMemRefDependenceCheck(); void registerTestMemRefStrideCalculation(); void registerTestMemRefToLLVMWithTransforms(); void registerTestReshardingPartitionPass(); -void registerTestShardSimplificationsPass(); void registerTestMultiBuffering(); void registerTestNextAccessPass(); void registerTestNVGPULowerings(); @@ -281,7 +280,6 @@ static void registerTestPasses() { mlir::test::registerTestMemRefStrideCalculation(); mlir::test::registerTestMemRefToLLVMWithTransforms(); mlir::test::registerTestReshardingPartitionPass(); - mlir::test::registerTestShardSimplificationsPass(); mlir::test::registerTestMultiBuffering(); mlir::test::registerTestNextAccessPass(); mlir::test::registerTestNVGPULowerings();