diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index d9e47ea3f6bf..7e68b152fdf7 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -103,7 +103,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> { // CommSplitOp //===----------------------------------------------------------------------===// -def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> { +def MPI_CommSplitOp : MPI_Op<"comm_split"> { let summary = "Partition the group associated with the given communicator into " "disjoint subgroups"; let description = [{ diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index c765ad5a579c..1db14e60c5a7 100644 --- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -16,7 +16,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -624,9 +624,8 @@ struct ConvertAllGatherOp : public CommOpPattern { // shard.allgather concatenates along a specified gather-axis. // mpi.allgather always concatenates along the first dimension and // there is no MPI operation that allows gathering along an arbitrary axis. - // Hence, if gather-axis!=0, we need to create a temporary buffer - // where we gather along the first dimension and then copy from that - // buffer to the final output along the specified gather-axis. + // Hence, if gather-axis != 0, we need to permute the output buffer + // accordingly. LogicalResult matchAndRewrite(AllGatherOp op, OpAdaptor adaptor, @@ -635,104 +634,124 @@ struct ConvertAllGatherOp : public CommOpPattern { FailureOr gridOp = checkGrid(op, symbolTableCollection); if (failed(gridOp)) return failure(); - ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); - Value input = getAsMemref(adaptor.getInput(), iBuilder); + + ImplicitLocOpBuilder ib(op.getLoc(), rewriter); + Value input = getAsMemref(adaptor.getInput(), ib); MemRefType inType = cast(input.getType()); - if (!memref::isStaticShapeAndContiguousRowMajor(inType)) - return op.emitError( - "Expected static shaped memref in contiguous row-major layout."); MemRefType outType = getMemrefType(cast(op.getType())); - if (!memref::isStaticShapeAndContiguousRowMajor(outType)) - return op.emitError( - "Expected static shaped memref in contiguous row-major layout."); + auto inputShape = inType.getShape(); + auto outputShape = outType.getShape(); int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt(); - auto ctx = op->getContext(); + int64_t inputDimOnAxis = inputShape[gatherAxis]; + int64_t outputDimOnAxis = outputShape[gatherAxis]; - // Get the right communicator - Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder); - - Value nRanks = - mpi::CommSizeOp::create(iBuilder, iBuilder.getI32Type(), comm) - .getSize(); - nRanks = - arith::IndexCastOp::create(iBuilder, iBuilder.getIndexType(), nRanks); - - Value tmpOutput, gatherDimSz; - if (gatherAxis == 0) { - tmpOutput = memref::AllocOp::create(iBuilder, outType); - } else { - // MPI's allgather always concatenates along the first dimension. - // Create a memref type for the output buffer with adjusted (expanded) - // shape. - SmallVector gatherShape(1, ShapedType::kDynamic); - llvm::append_range(gatherShape, outType.getShape()); - gatherShape[gatherAxis + 1] = ShapedType::kDynamic; - MemRefType gatherType = - MemRefType::get(gatherShape, outType.getElementType()); - gatherDimSz = arith::ConstantIndexOp::create( - iBuilder, outType.getDimSize(gatherAxis)); - gatherDimSz = arith::DivSIOp::create(iBuilder, iBuilder.getIndexType(), - gatherDimSz, nRanks); - // Allocate output buffer - tmpOutput = - memref::AllocOp::create(iBuilder, gatherType, {nRanks, gatherDimSz}); + for (size_t i = 0; i < outputShape.size(); ++i) + if (outputShape[i] != inputShape[i] && i != (size_t)gatherAxis) + return op.emitError( + "Result and input shapes must match along non-gather axes."); + if (inputDimOnAxis == 0) + return op.emitError("Input size along the gather axis must be non-zero."); + if (inputDimOnAxis == 1) { + assert(outputDimOnAxis == inputDimOnAxis); + rewriter.replaceOp(op, adaptor.getInput()); + return success(); } + if (outputDimOnAxis % inputDimOnAxis != 0) + return op.emitError("Result size along the gather axis must be an exact " + "multiple of the input size along the gather axis."); + + if (!memref::isStaticShapeAndContiguousRowMajor(inType) || + !memref::isStaticShapeAndContiguousRowMajor(outType)) + return op.emitError("Input/result must be statically shaped memrefs in " + "contiguous row-major layout."); + + // Get the right communicator. + Value comm = getComm(*gridOp, adaptor.getGridAxes(), ib); + Value nRanksV = + mpi::CommSizeOp::create(ib, ib.getI32Type(), comm).getSize(); + nRanksV = arith::IndexCastOp::create(ib, ib.getIndexType(), nRanksV); + int64_t nRanks = outputDimOnAxis / inputDimOnAxis; + Value nRanksC = arith::ConstantIndexOp::create(ib, nRanks); + Value notError = + arith::CmpIOp::create(ib, arith::CmpIPredicate::eq, nRanksV, nRanksC); + cf::AssertOp::create(ib, notError, + "Expected number of ranks in the communicator to " + "match the output size along the gather axis divided " + "by the input size along the gather axis."); + + // mpi.allgather always concatenates along the first dimension, so + // get a output buffer of shape {nRanks, dim0, ...}. + SmallVector gatherShape; + gatherShape.emplace_back(nRanks); + gatherShape.append(inputShape.begin(), inputShape.end()); + auto gatherType = MemRefType::get(gatherShape, outType.getElementType()); + Value finalOutput = memref::AllocOp::create(ib, gatherType); // Create the MPI AllGather operation. - mpi::AllGatherOp::create(iBuilder, TypeRange(), input, tmpOutput, comm); + mpi::AllGatherOp::create(ib, TypeRange(), input, finalOutput, comm); - // If gather-axis!=0, copy from gathered buffer to output with the right - // layout. - Value finalOutput = tmpOutput; - if (gatherAxis != 0) { - int64_t nSrcDims = cast(tmpOutput.getType()).getRank(); - assert(nSrcDims == outType.getRank() + 1 && - "Expected gathered type to have rank one more than output type."); + if (gatherAxis == 0) { + // If gather axis == 0, simply collapse the first 2 dims from {nRanks, + // dim0, ...} to {nRanks*dim0, ...}. + SmallVector reassociation; + reassociation.push_back({0, 1}); + int64_t numGatherDims = gatherShape.size(); + for (int64_t i = 2; i < numGatherDims; ++i) + reassociation.push_back({i}); + finalOutput = memref::CollapseShapeOp::create(ib, outType, finalOutput, + reassociation); - // Create affine map for copying from gathered buffer to output. - SmallVector dims; - dims.reserve(nSrcDims); - for (unsigned i = 0; i < nSrcDims; ++i) - dims.emplace_back(getAffineDimExpr(i, ctx)); - AffineExpr s = getAffineSymbolExpr(0, ctx); - SmallVector results; - results.reserve(nSrcDims); - for (unsigned i = 0; i < nSrcDims - 1; ++i) { - if (i == gatherAxis) - results.emplace_back(dims[0] * s + dims[gatherAxis + 1]); - else - results.emplace_back(dims[i + 1]); + // If the op's result is a tensor, cast it to a tensor. + if (isa(op.getType())) + finalOutput = bufferization::ToTensorOp::create(ib, op.getType(), + finalOutput, true); + } else { + // 1. Enter tensor-land. + auto inType = + RankedTensorType::get(gatherShape, outType.getElementType()); + finalOutput = + bufferization::ToTensorOp::create(ib, inType, finalOutput, true); + + // 2. Permute the output buffer from {nRanks, dim0, ..., gatherAxis, ...} + // to {dim0, ..., nRanks, dim1,...}. + SmallVector outShapePermuted, permutation; + for (int i = 1; i <= gatherAxis; ++i) { + outShapePermuted.emplace_back(gatherShape[i]); + permutation.emplace_back(i); } - auto affineMap = AffineMap::get(nSrcDims, /*symbols=*/1, results, ctx); + outShapePermuted.emplace_back(gatherShape[0]); + permutation.emplace_back(0); + for (size_t i = gatherAxis + 1; i < gatherShape.size(); ++i) { + outShapePermuted.emplace_back(gatherShape[i]); + permutation.emplace_back(i); + } + Value permOutput = tensor::EmptyOp::create(ib, outShapePermuted, + outType.getElementType()); + finalOutput = + linalg::TransposeOp::create(ib, finalOutput, permOutput, permutation) + ->getResult(0); - finalOutput = memref::AllocOp::create(iBuilder, outType); + // 3. Collapse the output buffer from {dim0, ..., nRanks, gatherAxis, ...} + // to {dim0, ..., nRanks*gatherAxis, ...}. + SmallVector reassociation; + for (int64_t i = 0; i < gatherAxis; ++i) { + reassociation.push_back({i}); + } + reassociation.push_back({gatherAxis, gatherAxis + 1}); + for (int64_t i = gatherAxis + 2; i < (int64_t)outShapePermuted.size(); + ++i) { + reassociation.push_back({i}); + } + auto outTType = + RankedTensorType::get(outputShape, outType.getElementType()); + finalOutput = tensor::CollapseShapeOp::create(ib, outTType, finalOutput, + reassociation); - // Now build a loop nest to copy from gathered buffer to finalOutput - // It would be nicer to just use a memref.transpose/collapse_shape op but - // these currently only support simpler cases. - Value zero = arith::ConstantIndexOp::create(iBuilder, 0); - SmallVector lbs(nSrcDims, zero); - SmallVector ubs; - for (int64_t d = 0; d < nSrcDims; ++d) - ubs.emplace_back(memref::DimOp::create(iBuilder, tmpOutput, d)); - SmallVector steps(nSrcDims, 1); - auto emitCopy = [&](OpBuilder &builder, Location loc, ValueRange ivs) { - Value v = memref::LoadOp::create(iBuilder, tmpOutput, ivs); - // set symbol value - SmallVector ivss(ivs.begin(), ivs.end()); - ivss.emplace_back(gatherDimSz); - affine::AffineStoreOp::create(iBuilder, v, finalOutput, affineMap, - ivss); - }; - affine::buildAffineLoopNest(iBuilder, op->getLoc(), lbs, ubs, steps, - emitCopy); - - memref::DeallocOp::create(iBuilder, tmpOutput); + // 4. Cast back to memref if needed. + if (isa(op.getType())) + finalOutput = + bufferization::ToBufferOp::create(ib, outType, finalOutput); } - // If the destination is a tensor, cast it to a tensor - if (isa(op.getType())) - finalOutput = bufferization::ToTensorOp::create(iBuilder, op.getType(), - finalOutput, true); rewriter.replaceOp(op, finalOutput); return success(); } diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 62dc8f5917ab..e619c7073a8c 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -436,24 +436,32 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, targetSharding); } -// Handles only resharding on a 1D shard. -// Currently the sharded tensor axes must be exactly divisible by the single -// grid axis size. +// In most cases the sharded tensor axes must be exactly divisible by the single +// grid axis size. Only halo size changes can deal with non-divisible cases. static TypedValue -reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, - const Sharding &sourceSharding, const Sharding &targetSharding, - TypedValue sourceUnshardedValue, - TypedValue sourceShard) { +reshard(ImplicitLocOpBuilder &builder, GridOp grid, + const Sharding &sourceSharding, const Sharding &targetSharding, + TypedValue sourceUnshardedValue, + TypedValue sourceShard) { + // If source and destination sharding are the same, no need to do anything. + if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) && + isFullReplication(targetSharding))) { + return sourceShard; + } + + // Tries to handle the case where the resharding is needed because the halo + // sizes are different. Supports arbitrary grid dimensionality. + if (auto tryRes = tryUpdateHaloInResharding( + builder, grid, sourceSharding, targetSharding, + sourceUnshardedValue.getType(), sourceShard)) { + return std::get<0>(tryRes.value()); // targetShard + } + assert(sourceShard.getType() == shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding)); [[maybe_unused]] ShapedType targetShardType = shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding); assert(sourceShard.getType().getRank() == targetShardType.getRank()); - assert(grid.getRank() == 1 && "Only 1D grides are currently supported."); - - if (sourceSharding == targetSharding) { - return sourceShard; - } TypedValue targetShard; Sharding actualTargetSharding; @@ -475,38 +483,13 @@ reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, std::tie(targetShard, actualTargetSharding) = tryRes.value(); } } + assert(targetShard && "Did not find any pattern to apply."); assert(actualTargetSharding == targetSharding); assert(targetShard.getType() == targetShardType); return targetShard; } -static TypedValue -reshard(ImplicitLocOpBuilder &builder, GridOp grid, - const Sharding &sourceSharding, const Sharding &targetSharding, - TypedValue sourceUnshardedValue, - TypedValue sourceShard) { - // If source and destination sharding are the same, no need to do anything. - if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) && - isFullReplication(targetSharding))) { - return sourceShard; - } - - // Tries to handle the case where the resharding is needed because the halo - // sizes are different. Supports arbitrary grid dimensionality. - if (auto tryRes = tryUpdateHaloInResharding( - builder, grid, sourceSharding, targetSharding, - sourceUnshardedValue.getType(), sourceShard)) { - return std::get<0>(tryRes.value()); // targetShard - } - - // Resort to handling only 1D grids since the general case is complicated if - // it needs to be communication efficient in terms of minimizing the data - // transfered between devices. - return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding, - sourceUnshardedValue, sourceShard); -} - TypedValue reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue sourceShardValue) { diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir index 4ac4a69dd5b1..6161c131c8f5 100644 --- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir +++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir @@ -148,6 +148,28 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { return %0 : memref<3x4xf64> } + // CHECK-LABEL: func @allgather_tensor_0 + // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32> + func.func @allgather_tensor_0(%arg0 : tensor<3x4xf32>) -> tensor<12x4xf32> { + // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32 + // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index + // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32> + // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm + // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32 + // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index + // CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc4]] : index + // CHECK: cf.assert [[v3]] + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<4x3x4xf32> + // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<4x3x4xf32> + // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1], [2]] : memref<4x3x4xf32> into memref<12x4xf32> + // CHECK: [[v4:%.*]] = bufferization.to_tensor [[vcollapse_shape]] restrict : memref<12x4xf32> to tensor<12x4xf32> + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 0 : tensor<3x4xf32> -> tensor<12x4xf32> + // CHECK: return [[v4]] : tensor<12x4xf32> + return %0 : tensor<12x4xf32> + } + // CHECK-LABEL: func @allgather_tensor func.func @allgather_tensor( // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32> @@ -155,28 +177,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { %arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> { // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32 // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32 - // CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index + // CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32> // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32 // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index - // CHECK: [[v3:%.*]] = arith.divsi [[vc20]], [[v2]] : index - // CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref - // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref - // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32> - // CHECK: affine.for [[varg1:%.*]] = 0 to [[v2]] { - // CHECK: affine.for [[varg2:%.*]] = 0 to 3 { - // CHECK: affine.for [[varg3:%.*]] = 0 to [[v3]] { - // CHECK: [[v5:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref - // CHECK: affine.store [[v5]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v3]]) + [[varg3]]] : memref<3x20xf32> - // CHECK: } - // CHECK: } - // CHECK: } - // CHECK: memref.dealloc [[valloc]] : memref - // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<3x20xf32> to tensor<3x20xf32> + // CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc5]] : index + // CHECK: cf.assert [[v3]] + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<5x3x4xf32> + // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<5x3x4xf32> + // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<5x3x4xf32> to tensor<5x3x4xf32> + // CHECK: [[v5:%.*]] = tensor.empty() : tensor<3x5x4xf32> + // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v4]] : tensor<5x3x4xf32>) outs([[v5]] : tensor<3x5x4xf32>) permutation = [1, 0, 2] + // CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32> %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x20xf32> - // CHECK: return [[v4]] : tensor<3x20xf32> + // CHECK: return [[vcollapsed]] : tensor<3x20xf32> return %0 : tensor<3x20xf32> } @@ -185,28 +201,24 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32> // CHECK-SAME: -> memref<3x20xf32> %arg0 : memref<3x4xf32>) -> memref<3x20xf32> { - // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32 // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32 - // CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index + // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32 + // CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32 // CHECK: [[v1:%.*]] = arith.index_cast [[vsize]] : i32 to index - // CHECK: [[v2:%.*]] = arith.divsi [[vc20]], [[v1]] : index - // CHECK: [[valloc:%.*]] = memref.alloc([[v1]], [[v2]]) : memref - // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref - // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32> - // CHECK: affine.for [[varg1:%.*]] = 0 to [[v1]] { - // CHECK: affine.for [[varg2:%.*]] = 0 to 3 { - // CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] { - // CHECK: [[v3:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref - // CHECK: affine.store [[v3]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v2]]) + [[varg3]]] : memref<3x20xf32> - // CHECK: } - // CHECK: } - // CHECK: } - // CHECK: memref.dealloc [[valloc]] : memref + // CHECK: [[v2:%.*]] = arith.cmpi eq, [[v1]], [[vc5]] : index + // CHECK: cf.assert [[v2]] + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<5x3x4xf32> + // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<5x3x4xf32> + // CHECK: [[v3:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<5x3x4xf32> to tensor<5x3x4xf32> + // CHECK: [[v4:%.*]] = tensor.empty() : tensor<3x5x4xf32> + // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v3]] : tensor<5x3x4xf32>) outs([[v4]] : tensor<3x5x4xf32>) permutation = [1, 0, 2] + // CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32> + // CHECK: [[v5:%.*]] = bufferization.to_buffer [[vcollapsed]] : tensor<3x20xf32> to memref<3x20xf32> %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32> - // CHECK: return [[valloc_0]] : memref<3x20xf32> + // CHECK: return [[v5]] : memref<3x20xf32> return %0 : memref<3x20xf32> } } @@ -377,9 +389,9 @@ shard.grid @grid0(shape = 2x2x4) // CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor, tensor, tensor) { func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.sharding) { %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] : !shard.sharding - // CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16> - // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> - // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16 + // CHECK-DAG: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16> + // CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> + // CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16 // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> @@ -397,10 +409,10 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.s // CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor, tensor, tensor) { func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !shard.sharding) { %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !shard.sharding - // CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64> - // CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16> - // CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> - // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16 + // CHECK-DAG: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64> + // CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16> + // CHECK-DAG: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> + // CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16 // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> @@ -417,12 +429,12 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !s // CHECK-SAME: [[varg0:%.*]]: tensor) -> (tensor, tensor, tensor, tensor) { func.func @return_sharding_offs(%arg0: tensor) -> (tensor, !shard.sharding) { %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !shard.sharding - // CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64> - // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64> - // CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64 - // CHECK: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16> - // CHECK: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> - // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16 + // CHECK-DAG: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64> + // CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64> + // CHECK-DAG: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64 + // CHECK-DAG: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16> + // CHECK-DAG: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16> + // CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16 // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16> // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16> // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16> @@ -444,39 +456,38 @@ shard.grid @grid_1d_4(shape = 4) // CHECK-LABEL: func.func @mlp_1dgrid( // CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32> func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %arg2: tensor<256x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} { - // CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32 %cst = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: [[vc0:%.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index + // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<512x512xf32> to memref<512x512xf32> // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm - // CHECK: [[vsize:%.*]] = mpi.comm_size + // CHECK: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32 // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index - // CHECK: [[v3:%.*]] = arith.divsi - // CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref - // CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref - // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32> - // CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] { - // CHECK: affine.for [[varg4:%.*]] = 0 to 512 { - // CHECK: affine.for [[varg5:%.*]] = 0 to [[v3]] { - // CHECK: [[v19:%.*]] = memref.load [[valloc]][[[varg3]], [[varg4]], [[varg5]]] : memref - // CHECK: affine.store [[v19]], [[valloc_0]][[[varg4]], [[varg3]] * symbol([[v3]]) + [[varg5]]] : memref<512x2048xf32> - // CHECK: memref.dealloc [[valloc]] : memref - // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32> + // CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc4]] : index + // CHECK: cf.assert [[v3]] + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<4x512x512xf32> + // CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<4x512x512xf32> + // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<4x512x512xf32> to tensor<4x512x512xf32> + // CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x4x512xf32> + // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v4]] : tensor<4x512x512xf32>) outs([[v5]] : tensor<512x4x512xf32>) permutation = [1, 0, 2] + // CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<512x4x512xf32> into tensor<512x2048xf32> %all_gather = shard.all_gather %arg0 on @grid_1d_4 grid_axes = [0] gather_axis = 1 : tensor<512x512xf32> -> tensor<512x2048xf32> - // CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x256xf32> + // CHECK: [[v6:%.*]] = tensor.empty() : tensor<512x256xf32> %0 = tensor.empty() : tensor<512x256xf32> - // CHECK: [[v6:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v5]] : tensor<512x256xf32>) -> tensor<512x256xf32> + // CHECK: [[v7:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32> %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x256xf32>) -> tensor<512x256xf32> - // CHECK: [[v7:%.*]] = linalg.matmul ins([[v4]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32> + // CHECK: [[v8:%.*]] = linalg.matmul ins([[vcollapsed]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v7]] : tensor<512x256xf32>) -> tensor<512x256xf32> %2 = linalg.matmul ins(%all_gather, %arg1 : tensor<512x2048xf32>, tensor<2048x256xf32>) outs(%1 : tensor<512x256xf32>) -> tensor<512x256xf32> - // CHECK: [[v8:%.*]] = tosa.sigmoid [[v7]] : (tensor<512x256xf32>) -> tensor<512x256xf32> + // CHECK: [[v9:%.*]] = tosa.sigmoid [[v8]] : (tensor<512x256xf32>) -> tensor<512x256xf32> %3 = tosa.sigmoid %2 : (tensor<512x256xf32>) -> tensor<512x256xf32> %4 = tensor.empty() : tensor<512x2048xf32> %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<512x2048xf32>) -> tensor<512x2048xf32> %proc_linear_idx = shard.process_multi_index on @grid_1d_4 axes = [0] : index %grid_shape = shard.grid_shape @grid_1d_4 axes = [0] : index %6 = arith.cmpi eq, %proc_linear_idx, %c0 : index - // CHECK: [[v14:%.*]] = scf.if + // CHECK: [[v15:%.*]] = scf.if %7 = scf.if %6 -> (tensor<512x2048xf32>) { scf.yield %5 : tensor<512x2048xf32> } else { @@ -484,15 +495,15 @@ func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, % %10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<512x2048xf32>) -> tensor<512x2048xf32> scf.yield %10 : tensor<512x2048xf32> } - // CHECK: [[v15:%.*]] = linalg.matmul ins([[v8]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v14]] : tensor<512x2048xf32>) -> tensor<512x2048xf32> + // CHECK: [[v16:%.*]] = linalg.matmul ins([[v9]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v15]] : tensor<512x2048xf32>) -> tensor<512x2048xf32> %8 = linalg.matmul ins(%3, %arg2 : tensor<512x256xf32>, tensor<256x2048xf32>) outs(%7 : tensor<512x2048xf32>) -> tensor<512x2048xf32> - // CHECK: [[v16:%.*]] = bufferization.to_buffer - // CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<512x2048xf32> - // CHECK: linalg.copy ins([[v16]] : memref<512x2048xf32>) outs([[valloc_1]] : memref<512x2048xf32>) - // CHECK: [[v17:%.*]] = mpi.comm_world : !mpi.comm - // CHECK: mpi.allreduce([[valloc_1]], [[valloc_1]], MPI_SUM, [[v17]]) : memref<512x2048xf32>, memref<512x2048xf32> - // CHECK: [[v18:%.*]] = bufferization.to_tensor [[valloc_1]] restrict : memref<512x2048xf32> to tensor<512x2048xf32> + // CHECK: [[v17:%.*]] = bufferization.to_buffer [[v16]] : tensor<512x2048xf32> to memref<512x2048xf32> + // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32> + // CHECK: linalg.copy ins([[v17]] : memref<512x2048xf32>) outs([[valloc_0]] : memref<512x2048xf32>) + // CHECK: [[v18:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: mpi.allreduce([[valloc_0]], [[valloc_0]], MPI_SUM, [[v18]]) : memref<512x2048xf32>, memref<512x2048xf32> + // CHECK: [[v19:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32> %all_reduce = shard.all_reduce %8 on @grid_1d_4 grid_axes = [0] : tensor<512x2048xf32> -> tensor<512x2048xf32> - // CHECK: return [[v18]] : tensor<512x2048xf32> + // CHECK: return [[v19]] : tensor<512x2048xf32> return %all_reduce : tensor<512x2048xf32> } diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir index cd9fa2215e0e..4c8271aefcaf 100644 --- a/mlir/test/Dialect/Shard/partition.mlir +++ b/mlir/test/Dialect/Shard/partition.mlir @@ -4,6 +4,7 @@ shard.grid @grid_1d(shape = 2) shard.grid @grid_1d_4(shape = 4) +shard.grid @grid_2d_16(shape = 4x4) // CHECK-LABEL: func @return_sharding func.func @return_sharding( @@ -318,9 +319,9 @@ func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) { return %sharded_ret : tensor<6xi32> } -// CHECK-LABEL: func.func @mlp_1dgrid +// CHECK-LABEL: func.func @mlp_1d_weight_stationary // CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32> -func.func @mlp_1dgrid(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} { +func.func @mlp_1d_weight_stationary(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} { // CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32 %sharding = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding %sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0], []] : !shard.sharding