[mlir][shard,mpi] Allowing 2d-grids and simplifying lowering of shard.all_gather (#180243)

- fixing incorrect assertion and related function name
- mpi.comm_split is not pure (removing the `Pure` trait from the op)
- simplifying/standardizing permutation in all_gather

---------

Co-authored-by: Rolf Morel <rolfmorel@gmail.com>
This commit is contained in:
Frank Schlimbach 2026-02-10 16:04:22 +01:00 committed by GitHub
parent 5e0e389360
commit a6929f7937
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 215 additions and 201 deletions

View File

@ -103,7 +103,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
// CommSplitOp
//===----------------------------------------------------------------------===//
def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
def MPI_CommSplitOp : MPI_Op<"comm_split"> {
let summary = "Partition the group associated with the given communicator into "
"disjoint subgroups";
let description = [{

View File

@ -16,7 +16,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@ -624,9 +624,8 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
// shard.allgather concatenates along a specified gather-axis.
// mpi.allgather always concatenates along the first dimension and
// there is no MPI operation that allows gathering along an arbitrary axis.
// Hence, if gather-axis!=0, we need to create a temporary buffer
// where we gather along the first dimension and then copy from that
// buffer to the final output along the specified gather-axis.
// Hence, if gather-axis != 0, we need to permute the output buffer
// accordingly.
LogicalResult
matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
@ -635,104 +634,124 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
if (failed(gridOp))
return failure();
ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
Value input = getAsMemref(adaptor.getInput(), iBuilder);
ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
Value input = getAsMemref(adaptor.getInput(), ib);
MemRefType inType = cast<MemRefType>(input.getType());
if (!memref::isStaticShapeAndContiguousRowMajor(inType))
return op.emitError(
"Expected static shaped memref in contiguous row-major layout.");
MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
if (!memref::isStaticShapeAndContiguousRowMajor(outType))
return op.emitError(
"Expected static shaped memref in contiguous row-major layout.");
auto inputShape = inType.getShape();
auto outputShape = outType.getShape();
int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
auto ctx = op->getContext();
int64_t inputDimOnAxis = inputShape[gatherAxis];
int64_t outputDimOnAxis = outputShape[gatherAxis];
// Get the right communicator
Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
Value nRanks =
mpi::CommSizeOp::create(iBuilder, iBuilder.getI32Type(), comm)
.getSize();
nRanks =
arith::IndexCastOp::create(iBuilder, iBuilder.getIndexType(), nRanks);
Value tmpOutput, gatherDimSz;
if (gatherAxis == 0) {
tmpOutput = memref::AllocOp::create(iBuilder, outType);
} else {
// MPI's allgather always concatenates along the first dimension.
// Create a memref type for the output buffer with adjusted (expanded)
// shape.
SmallVector<int64_t> gatherShape(1, ShapedType::kDynamic);
llvm::append_range(gatherShape, outType.getShape());
gatherShape[gatherAxis + 1] = ShapedType::kDynamic;
MemRefType gatherType =
MemRefType::get(gatherShape, outType.getElementType());
gatherDimSz = arith::ConstantIndexOp::create(
iBuilder, outType.getDimSize(gatherAxis));
gatherDimSz = arith::DivSIOp::create(iBuilder, iBuilder.getIndexType(),
gatherDimSz, nRanks);
// Allocate output buffer
tmpOutput =
memref::AllocOp::create(iBuilder, gatherType, {nRanks, gatherDimSz});
for (size_t i = 0; i < outputShape.size(); ++i)
if (outputShape[i] != inputShape[i] && i != (size_t)gatherAxis)
return op.emitError(
"Result and input shapes must match along non-gather axes.");
if (inputDimOnAxis == 0)
return op.emitError("Input size along the gather axis must be non-zero.");
if (inputDimOnAxis == 1) {
assert(outputDimOnAxis == inputDimOnAxis);
rewriter.replaceOp(op, adaptor.getInput());
return success();
}
if (outputDimOnAxis % inputDimOnAxis != 0)
return op.emitError("Result size along the gather axis must be an exact "
"multiple of the input size along the gather axis.");
if (!memref::isStaticShapeAndContiguousRowMajor(inType) ||
!memref::isStaticShapeAndContiguousRowMajor(outType))
return op.emitError("Input/result must be statically shaped memrefs in "
"contiguous row-major layout.");
// Get the right communicator.
Value comm = getComm(*gridOp, adaptor.getGridAxes(), ib);
Value nRanksV =
mpi::CommSizeOp::create(ib, ib.getI32Type(), comm).getSize();
nRanksV = arith::IndexCastOp::create(ib, ib.getIndexType(), nRanksV);
int64_t nRanks = outputDimOnAxis / inputDimOnAxis;
Value nRanksC = arith::ConstantIndexOp::create(ib, nRanks);
Value notError =
arith::CmpIOp::create(ib, arith::CmpIPredicate::eq, nRanksV, nRanksC);
cf::AssertOp::create(ib, notError,
"Expected number of ranks in the communicator to "
"match the output size along the gather axis divided "
"by the input size along the gather axis.");
// mpi.allgather always concatenates along the first dimension, so
// get an output buffer of shape {nRanks, dim0, ...}.
SmallVector<int64_t> gatherShape;
gatherShape.emplace_back(nRanks);
gatherShape.append(inputShape.begin(), inputShape.end());
auto gatherType = MemRefType::get(gatherShape, outType.getElementType());
Value finalOutput = memref::AllocOp::create(ib, gatherType);
// Create the MPI AllGather operation.
mpi::AllGatherOp::create(iBuilder, TypeRange(), input, tmpOutput, comm);
mpi::AllGatherOp::create(ib, TypeRange(), input, finalOutput, comm);
// If gather-axis!=0, copy from gathered buffer to output with the right
// layout.
Value finalOutput = tmpOutput;
if (gatherAxis != 0) {
int64_t nSrcDims = cast<ShapedType>(tmpOutput.getType()).getRank();
assert(nSrcDims == outType.getRank() + 1 &&
"Expected gathered type to have rank one more than output type.");
if (gatherAxis == 0) {
// If gather axis == 0, simply collapse the first 2 dims from {nRanks,
// dim0, ...} to {nRanks*dim0, ...}.
SmallVector<ReassociationIndices> reassociation;
reassociation.push_back({0, 1});
int64_t numGatherDims = gatherShape.size();
for (int64_t i = 2; i < numGatherDims; ++i)
reassociation.push_back({i});
finalOutput = memref::CollapseShapeOp::create(ib, outType, finalOutput,
reassociation);
// Create affine map for copying from gathered buffer to output.
SmallVector<AffineExpr> dims;
dims.reserve(nSrcDims);
for (unsigned i = 0; i < nSrcDims; ++i)
dims.emplace_back(getAffineDimExpr(i, ctx));
AffineExpr s = getAffineSymbolExpr(0, ctx);
SmallVector<AffineExpr> results;
results.reserve(nSrcDims);
for (unsigned i = 0; i < nSrcDims - 1; ++i) {
if (i == gatherAxis)
results.emplace_back(dims[0] * s + dims[gatherAxis + 1]);
else
results.emplace_back(dims[i + 1]);
// If the op's result is a tensor, cast it to a tensor.
if (isa<RankedTensorType>(op.getType()))
finalOutput = bufferization::ToTensorOp::create(ib, op.getType(),
finalOutput, true);
} else {
// 1. Enter tensor-land.
auto inType =
RankedTensorType::get(gatherShape, outType.getElementType());
finalOutput =
bufferization::ToTensorOp::create(ib, inType, finalOutput, true);
// 2. Permute the output buffer from {nRanks, dim0, ..., gatherAxis, ...}
// to {dim0, ..., nRanks, gatherAxis, ...}.
SmallVector<int64_t> outShapePermuted, permutation;
for (int i = 1; i <= gatherAxis; ++i) {
outShapePermuted.emplace_back(gatherShape[i]);
permutation.emplace_back(i);
}
auto affineMap = AffineMap::get(nSrcDims, /*symbols=*/1, results, ctx);
outShapePermuted.emplace_back(gatherShape[0]);
permutation.emplace_back(0);
for (size_t i = gatherAxis + 1; i < gatherShape.size(); ++i) {
outShapePermuted.emplace_back(gatherShape[i]);
permutation.emplace_back(i);
}
Value permOutput = tensor::EmptyOp::create(ib, outShapePermuted,
outType.getElementType());
finalOutput =
linalg::TransposeOp::create(ib, finalOutput, permOutput, permutation)
->getResult(0);
finalOutput = memref::AllocOp::create(iBuilder, outType);
// 3. Collapse the output buffer from {dim0, ..., nRanks, gatherAxis, ...}
// to {dim0, ..., nRanks*gatherAxis, ...}.
SmallVector<ReassociationIndices> reassociation;
for (int64_t i = 0; i < gatherAxis; ++i) {
reassociation.push_back({i});
}
reassociation.push_back({gatherAxis, gatherAxis + 1});
for (int64_t i = gatherAxis + 2; i < (int64_t)outShapePermuted.size();
++i) {
reassociation.push_back({i});
}
auto outTType =
RankedTensorType::get(outputShape, outType.getElementType());
finalOutput = tensor::CollapseShapeOp::create(ib, outTType, finalOutput,
reassociation);
// Now build a loop nest to copy from gathered buffer to finalOutput
// It would be nicer to just use a memref.transpose/collapse_shape op but
// these currently only support simpler cases.
Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
SmallVector<Value> lbs(nSrcDims, zero);
SmallVector<Value> ubs;
for (int64_t d = 0; d < nSrcDims; ++d)
ubs.emplace_back(memref::DimOp::create(iBuilder, tmpOutput, d));
SmallVector<int64_t> steps(nSrcDims, 1);
auto emitCopy = [&](OpBuilder &builder, Location loc, ValueRange ivs) {
Value v = memref::LoadOp::create(iBuilder, tmpOutput, ivs);
// set symbol value
SmallVector<Value> ivss(ivs.begin(), ivs.end());
ivss.emplace_back(gatherDimSz);
affine::AffineStoreOp::create(iBuilder, v, finalOutput, affineMap,
ivss);
};
affine::buildAffineLoopNest(iBuilder, op->getLoc(), lbs, ubs, steps,
emitCopy);
memref::DeallocOp::create(iBuilder, tmpOutput);
// 4. Cast back to memref if needed.
if (isa<MemRefType>(op.getType()))
finalOutput =
bufferization::ToBufferOp::create(ib, outType, finalOutput);
}
// If the destination is a tensor, cast it to a tensor
if (isa<RankedTensorType>(op.getType()))
finalOutput = bufferization::ToTensorOp::create(iBuilder, op.getType(),
finalOutput, true);
rewriter.replaceOp(op, finalOutput);
return success();
}

View File

@ -436,24 +436,32 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
targetSharding);
}
// Handles only resharding on a 1D shard.
// Currently the sharded tensor axes must be exactly divisible by the single
// grid axis size.
// In most cases the sharded tensor axes must be exactly divisible by the single
// grid axis size. Only halo size changes can deal with non-divisible cases.
static TypedValue<ShapedType>
reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
const Sharding &sourceSharding, const Sharding &targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
reshard(ImplicitLocOpBuilder &builder, GridOp grid,
const Sharding &sourceSharding, const Sharding &targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
// If source and destination sharding are the same, no need to do anything.
if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
isFullReplication(targetSharding))) {
return sourceShard;
}
// Tries to handle the case where the resharding is needed because the halo
// sizes are different. Supports arbitrary grid dimensionality.
if (auto tryRes = tryUpdateHaloInResharding(
builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue.getType(), sourceShard)) {
return std::get<0>(tryRes.value()); // targetShard
}
assert(sourceShard.getType() ==
shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
[[maybe_unused]] ShapedType targetShardType =
shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
assert(sourceShard.getType().getRank() == targetShardType.getRank());
assert(grid.getRank() == 1 && "Only 1D grides are currently supported.");
if (sourceSharding == targetSharding) {
return sourceShard;
}
TypedValue<ShapedType> targetShard;
Sharding actualTargetSharding;
@ -475,38 +483,13 @@ reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
std::tie(targetShard, actualTargetSharding) = tryRes.value();
}
}
assert(targetShard && "Did not find any pattern to apply.");
assert(actualTargetSharding == targetSharding);
assert(targetShard.getType() == targetShardType);
return targetShard;
}
static TypedValue<ShapedType>
reshard(ImplicitLocOpBuilder &builder, GridOp grid,
const Sharding &sourceSharding, const Sharding &targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
// If source and destination sharding are the same, no need to do anything.
if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
isFullReplication(targetSharding))) {
return sourceShard;
}
// Tries to handle the case where the resharding is needed because the halo
// sizes are different. Supports arbitrary grid dimensionality.
if (auto tryRes = tryUpdateHaloInResharding(
builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue.getType(), sourceShard)) {
return std::get<0>(tryRes.value()); // targetShard
}
// Resort to handling only 1D grids since the general case is complicated if
// it needs to be communication efficient in terms of minimizing the data
// transfered between devices.
return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue, sourceShard);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue) {

View File

@ -148,6 +148,28 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
return %0 : memref<3x4xf64>
}
// CHECK-LABEL: func @allgather_tensor_0
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
func.func @allgather_tensor_0(%arg0 : tensor<3x4xf32>) -> tensor<12x4xf32> {
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
// CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
// CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
// CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc4]] : index
// CHECK: cf.assert [[v3]]
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<4x3x4xf32>
// CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<4x3x4xf32>
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1], [2]] : memref<4x3x4xf32> into memref<12x4xf32>
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[vcollapse_shape]] restrict : memref<12x4xf32> to tensor<12x4xf32>
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 0 : tensor<3x4xf32> -> tensor<12x4xf32>
// CHECK: return [[v4]] : tensor<12x4xf32>
return %0 : tensor<12x4xf32>
}
// CHECK-LABEL: func @allgather_tensor
func.func @allgather_tensor(
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
@ -155,28 +177,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
%arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
// CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
// CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
// CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
// CHECK: [[v3:%.*]] = arith.divsi [[vc20]], [[v2]] : index
// CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x3x?xf32>
// CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
// CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
// CHECK: affine.for [[varg1:%.*]] = 0 to [[v2]] {
// CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
// CHECK: affine.for [[varg3:%.*]] = 0 to [[v3]] {
// CHECK: [[v5:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
// CHECK: affine.store [[v5]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v3]]) + [[varg3]]] : memref<3x20xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<3x20xf32> to tensor<3x20xf32>
// CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc5]] : index
// CHECK: cf.assert [[v3]]
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<5x3x4xf32>
// CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<5x3x4xf32>
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<5x3x4xf32> to tensor<5x3x4xf32>
// CHECK: [[v5:%.*]] = tensor.empty() : tensor<3x5x4xf32>
// CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v4]] : tensor<5x3x4xf32>) outs([[v5]] : tensor<3x5x4xf32>) permutation = [1, 0, 2]
// CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x20xf32>
// CHECK: return [[v4]] : tensor<3x20xf32>
// CHECK: return [[vcollapsed]] : tensor<3x20xf32>
return %0 : tensor<3x20xf32>
}
@ -185,28 +201,24 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
// CHECK-SAME: -> memref<3x20xf32>
%arg0 : memref<3x4xf32>) -> memref<3x20xf32> {
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
// CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
// CHECK: [[v1:%.*]] = arith.index_cast [[vsize]] : i32 to index
// CHECK: [[v2:%.*]] = arith.divsi [[vc20]], [[v1]] : index
// CHECK: [[valloc:%.*]] = memref.alloc([[v1]], [[v2]]) : memref<?x3x?xf32>
// CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
// CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
// CHECK: affine.for [[varg1:%.*]] = 0 to [[v1]] {
// CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
// CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
// CHECK: [[v3:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
// CHECK: affine.store [[v3]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v2]]) + [[varg3]]] : memref<3x20xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
// CHECK: [[v2:%.*]] = arith.cmpi eq, [[v1]], [[vc5]] : index
// CHECK: cf.assert [[v2]]
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<5x3x4xf32>
// CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<5x3x4xf32>
// CHECK: [[v3:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<5x3x4xf32> to tensor<5x3x4xf32>
// CHECK: [[v4:%.*]] = tensor.empty() : tensor<3x5x4xf32>
// CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v3]] : tensor<5x3x4xf32>) outs([[v4]] : tensor<3x5x4xf32>) permutation = [1, 0, 2]
// CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>
// CHECK: [[v5:%.*]] = bufferization.to_buffer [[vcollapsed]] : tensor<3x20xf32> to memref<3x20xf32>
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32>
// CHECK: return [[valloc_0]] : memref<3x20xf32>
// CHECK: return [[v5]] : memref<3x20xf32>
return %0 : memref<3x20xf32>
}
}
@ -377,9 +389,9 @@ shard.grid @grid0(shape = 2x2x4)
// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.sharding) {
%sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] : !shard.sharding
// CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK-DAG: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
@ -397,10 +409,10 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.s
// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !shard.sharding) {
%sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !shard.sharding
// CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK-DAG: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
// CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK-DAG: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
@ -417,12 +429,12 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !s
// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !shard.sharding) {
%sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !shard.sharding
// CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
// CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
// CHECK: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK-DAG: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
// CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
// CHECK-DAG: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
// CHECK-DAG: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK-DAG: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
@ -444,39 +456,38 @@ shard.grid @grid_1d_4(shape = 4)
// CHECK-LABEL: func.func @mlp_1dgrid(
// CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %arg2: tensor<256x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
// CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
%cst = arith.constant 0.000000e+00 : f32
// CHECK-DAG: [[vc0:%.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<512x512xf32> to memref<512x512xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vsize:%.*]] = mpi.comm_size
// CHECK: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32
// CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
// CHECK: [[v3:%.*]] = arith.divsi
// CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x512x?xf32>
// CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<?x512x?xf32>
// CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
// CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
// CHECK: affine.for [[varg4:%.*]] = 0 to 512 {
// CHECK: affine.for [[varg5:%.*]] = 0 to [[v3]] {
// CHECK: [[v19:%.*]] = memref.load [[valloc]][[[varg3]], [[varg4]], [[varg5]]] : memref<?x512x?xf32>
// CHECK: affine.store [[v19]], [[valloc_0]][[[varg4]], [[varg3]] * symbol([[v3]]) + [[varg5]]] : memref<512x2048xf32>
// CHECK: memref.dealloc [[valloc]] : memref<?x512x?xf32>
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
// CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc4]] : index
// CHECK: cf.assert [[v3]]
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<4x512x512xf32>
// CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<4x512x512xf32>
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<4x512x512xf32> to tensor<4x512x512xf32>
// CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x4x512xf32>
// CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v4]] : tensor<4x512x512xf32>) outs([[v5]] : tensor<512x4x512xf32>) permutation = [1, 0, 2]
// CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<512x4x512xf32> into tensor<512x2048xf32>
%all_gather = shard.all_gather %arg0 on @grid_1d_4 grid_axes = [0] gather_axis = 1 : tensor<512x512xf32> -> tensor<512x2048xf32>
// CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x256xf32>
// CHECK: [[v6:%.*]] = tensor.empty() : tensor<512x256xf32>
%0 = tensor.empty() : tensor<512x256xf32>
// CHECK: [[v6:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v5]] : tensor<512x256xf32>) -> tensor<512x256xf32>
// CHECK: [[v7:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x256xf32>) -> tensor<512x256xf32>
// CHECK: [[v7:%.*]] = linalg.matmul ins([[v4]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32>
// CHECK: [[v8:%.*]] = linalg.matmul ins([[vcollapsed]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v7]] : tensor<512x256xf32>) -> tensor<512x256xf32>
%2 = linalg.matmul ins(%all_gather, %arg1 : tensor<512x2048xf32>, tensor<2048x256xf32>) outs(%1 : tensor<512x256xf32>) -> tensor<512x256xf32>
// CHECK: [[v8:%.*]] = tosa.sigmoid [[v7]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
// CHECK: [[v9:%.*]] = tosa.sigmoid [[v8]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
%3 = tosa.sigmoid %2 : (tensor<512x256xf32>) -> tensor<512x256xf32>
%4 = tensor.empty() : tensor<512x2048xf32>
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
%proc_linear_idx = shard.process_multi_index on @grid_1d_4 axes = [0] : index
%grid_shape = shard.grid_shape @grid_1d_4 axes = [0] : index
%6 = arith.cmpi eq, %proc_linear_idx, %c0 : index
// CHECK: [[v14:%.*]] = scf.if
// CHECK: [[v15:%.*]] = scf.if
%7 = scf.if %6 -> (tensor<512x2048xf32>) {
scf.yield %5 : tensor<512x2048xf32>
} else {
@ -484,15 +495,15 @@ func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
scf.yield %10 : tensor<512x2048xf32>
}
// CHECK: [[v15:%.*]] = linalg.matmul ins([[v8]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v14]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
// CHECK: [[v16:%.*]] = linalg.matmul ins([[v9]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v15]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
%8 = linalg.matmul ins(%3, %arg2 : tensor<512x256xf32>, tensor<256x2048xf32>) outs(%7 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
// CHECK: [[v16:%.*]] = bufferization.to_buffer
// CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<512x2048xf32>
// CHECK: linalg.copy ins([[v16]] : memref<512x2048xf32>) outs([[valloc_1]] : memref<512x2048xf32>)
// CHECK: [[v17:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: mpi.allreduce([[valloc_1]], [[valloc_1]], MPI_SUM, [[v17]]) : memref<512x2048xf32>, memref<512x2048xf32>
// CHECK: [[v18:%.*]] = bufferization.to_tensor [[valloc_1]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
// CHECK: [[v17:%.*]] = bufferization.to_buffer [[v16]] : tensor<512x2048xf32> to memref<512x2048xf32>
// CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
// CHECK: linalg.copy ins([[v17]] : memref<512x2048xf32>) outs([[valloc_0]] : memref<512x2048xf32>)
// CHECK: [[v18:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: mpi.allreduce([[valloc_0]], [[valloc_0]], MPI_SUM, [[v18]]) : memref<512x2048xf32>, memref<512x2048xf32>
// CHECK: [[v19:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
%all_reduce = shard.all_reduce %8 on @grid_1d_4 grid_axes = [0] : tensor<512x2048xf32> -> tensor<512x2048xf32>
// CHECK: return [[v18]] : tensor<512x2048xf32>
// CHECK: return [[v19]] : tensor<512x2048xf32>
return %all_reduce : tensor<512x2048xf32>
}

View File

@ -4,6 +4,7 @@
shard.grid @grid_1d(shape = 2)
shard.grid @grid_1d_4(shape = 4)
shard.grid @grid_2d_16(shape = 4x4)
// CHECK-LABEL: func @return_sharding
func.func @return_sharding(
@ -318,9 +319,9 @@ func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
return %sharded_ret : tensor<6xi32>
}
// CHECK-LABEL: func.func @mlp_1dgrid
// CHECK-LABEL: func.func @mlp_1d_weight_stationary
// CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
func.func @mlp_1dgrid(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
func.func @mlp_1d_weight_stationary(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
// CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
%sharding = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding
%sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0], []] : !shard.sharding