[mlir][shard,mpi] Allowing 2d-grids and simplifying lowering shard.all_gather (#180243)
- fixing incorrect assertion and related function name - MPI_comm_split is not pure - simplifying/standardizing permutation in all_gather --------- Co-authored-by: Rolf Morel <rolfmorel@gmail.com>
This commit is contained in:
parent
5e0e389360
commit
a6929f7937
@ -103,7 +103,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
|
||||
// CommSplitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
|
||||
def MPI_CommSplitOp : MPI_Op<"comm_split"> {
|
||||
let summary = "Partition the group associated with the given communicator into "
|
||||
"disjoint subgroups";
|
||||
let description = [{
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
@ -624,9 +624,8 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
|
||||
// shard.allgather concatenates along a specified gather-axis.
|
||||
// mpi.allgather always concatenates along the first dimension and
|
||||
// there is no MPI operation that allows gathering along an arbitrary axis.
|
||||
// Hence, if gather-axis!=0, we need to create a temporary buffer
|
||||
// where we gather along the first dimension and then copy from that
|
||||
// buffer to the final output along the specified gather-axis.
|
||||
// Hence, if gather-axis != 0, we need to permute the output buffer
|
||||
// accordingly.
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
|
||||
@ -635,104 +634,124 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
|
||||
FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
|
||||
if (failed(gridOp))
|
||||
return failure();
|
||||
ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
|
||||
Value input = getAsMemref(adaptor.getInput(), iBuilder);
|
||||
|
||||
ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
|
||||
Value input = getAsMemref(adaptor.getInput(), ib);
|
||||
MemRefType inType = cast<MemRefType>(input.getType());
|
||||
if (!memref::isStaticShapeAndContiguousRowMajor(inType))
|
||||
return op.emitError(
|
||||
"Expected static shaped memref in contiguous row-major layout.");
|
||||
MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
|
||||
if (!memref::isStaticShapeAndContiguousRowMajor(outType))
|
||||
return op.emitError(
|
||||
"Expected static shaped memref in contiguous row-major layout.");
|
||||
auto inputShape = inType.getShape();
|
||||
auto outputShape = outType.getShape();
|
||||
int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
|
||||
auto ctx = op->getContext();
|
||||
int64_t inputDimOnAxis = inputShape[gatherAxis];
|
||||
int64_t outputDimOnAxis = outputShape[gatherAxis];
|
||||
|
||||
// Get the right communicator
|
||||
Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
|
||||
|
||||
Value nRanks =
|
||||
mpi::CommSizeOp::create(iBuilder, iBuilder.getI32Type(), comm)
|
||||
.getSize();
|
||||
nRanks =
|
||||
arith::IndexCastOp::create(iBuilder, iBuilder.getIndexType(), nRanks);
|
||||
|
||||
Value tmpOutput, gatherDimSz;
|
||||
if (gatherAxis == 0) {
|
||||
tmpOutput = memref::AllocOp::create(iBuilder, outType);
|
||||
} else {
|
||||
// MPI's allgather always concatenates along the first dimension.
|
||||
// Create a memref type for the output buffer with adjusted (expanded)
|
||||
// shape.
|
||||
SmallVector<int64_t> gatherShape(1, ShapedType::kDynamic);
|
||||
llvm::append_range(gatherShape, outType.getShape());
|
||||
gatherShape[gatherAxis + 1] = ShapedType::kDynamic;
|
||||
MemRefType gatherType =
|
||||
MemRefType::get(gatherShape, outType.getElementType());
|
||||
gatherDimSz = arith::ConstantIndexOp::create(
|
||||
iBuilder, outType.getDimSize(gatherAxis));
|
||||
gatherDimSz = arith::DivSIOp::create(iBuilder, iBuilder.getIndexType(),
|
||||
gatherDimSz, nRanks);
|
||||
// Allocate output buffer
|
||||
tmpOutput =
|
||||
memref::AllocOp::create(iBuilder, gatherType, {nRanks, gatherDimSz});
|
||||
for (size_t i = 0; i < outputShape.size(); ++i)
|
||||
if (outputShape[i] != inputShape[i] && i != (size_t)gatherAxis)
|
||||
return op.emitError(
|
||||
"Result and input shapes must match along non-gather axes.");
|
||||
if (inputDimOnAxis == 0)
|
||||
return op.emitError("Input size along the gather axis must be non-zero.");
|
||||
if (inputDimOnAxis == 1) {
|
||||
assert(outputDimOnAxis == inputDimOnAxis);
|
||||
rewriter.replaceOp(op, adaptor.getInput());
|
||||
return success();
|
||||
}
|
||||
if (outputDimOnAxis % inputDimOnAxis != 0)
|
||||
return op.emitError("Result size along the gather axis must be an exact "
|
||||
"multiple of the input size along the gather axis.");
|
||||
|
||||
if (!memref::isStaticShapeAndContiguousRowMajor(inType) ||
|
||||
!memref::isStaticShapeAndContiguousRowMajor(outType))
|
||||
return op.emitError("Input/result must be statically shaped memrefs in "
|
||||
"contiguous row-major layout.");
|
||||
|
||||
// Get the right communicator.
|
||||
Value comm = getComm(*gridOp, adaptor.getGridAxes(), ib);
|
||||
Value nRanksV =
|
||||
mpi::CommSizeOp::create(ib, ib.getI32Type(), comm).getSize();
|
||||
nRanksV = arith::IndexCastOp::create(ib, ib.getIndexType(), nRanksV);
|
||||
int64_t nRanks = outputDimOnAxis / inputDimOnAxis;
|
||||
Value nRanksC = arith::ConstantIndexOp::create(ib, nRanks);
|
||||
Value notError =
|
||||
arith::CmpIOp::create(ib, arith::CmpIPredicate::eq, nRanksV, nRanksC);
|
||||
cf::AssertOp::create(ib, notError,
|
||||
"Expected number of ranks in the communicator to "
|
||||
"match the output size along the gather axis divided "
|
||||
"by the input size along the gather axis.");
|
||||
|
||||
// mpi.allgather always concatenates along the first dimension, so
|
||||
// get a output buffer of shape {nRanks, dim0, ...}.
|
||||
SmallVector<int64_t> gatherShape;
|
||||
gatherShape.emplace_back(nRanks);
|
||||
gatherShape.append(inputShape.begin(), inputShape.end());
|
||||
auto gatherType = MemRefType::get(gatherShape, outType.getElementType());
|
||||
Value finalOutput = memref::AllocOp::create(ib, gatherType);
|
||||
// Create the MPI AllGather operation.
|
||||
mpi::AllGatherOp::create(iBuilder, TypeRange(), input, tmpOutput, comm);
|
||||
mpi::AllGatherOp::create(ib, TypeRange(), input, finalOutput, comm);
|
||||
|
||||
// If gather-axis!=0, copy from gathered buffer to output with the right
|
||||
// layout.
|
||||
Value finalOutput = tmpOutput;
|
||||
if (gatherAxis != 0) {
|
||||
int64_t nSrcDims = cast<ShapedType>(tmpOutput.getType()).getRank();
|
||||
assert(nSrcDims == outType.getRank() + 1 &&
|
||||
"Expected gathered type to have rank one more than output type.");
|
||||
if (gatherAxis == 0) {
|
||||
// If gather axis == 0, simply collapse the first 2 dims from {nRanks,
|
||||
// dim0, ...} to {nRanks*dim0, ...}.
|
||||
SmallVector<ReassociationIndices> reassociation;
|
||||
reassociation.push_back({0, 1});
|
||||
int64_t numGatherDims = gatherShape.size();
|
||||
for (int64_t i = 2; i < numGatherDims; ++i)
|
||||
reassociation.push_back({i});
|
||||
finalOutput = memref::CollapseShapeOp::create(ib, outType, finalOutput,
|
||||
reassociation);
|
||||
|
||||
// Create affine map for copying from gathered buffer to output.
|
||||
SmallVector<AffineExpr> dims;
|
||||
dims.reserve(nSrcDims);
|
||||
for (unsigned i = 0; i < nSrcDims; ++i)
|
||||
dims.emplace_back(getAffineDimExpr(i, ctx));
|
||||
AffineExpr s = getAffineSymbolExpr(0, ctx);
|
||||
SmallVector<AffineExpr> results;
|
||||
results.reserve(nSrcDims);
|
||||
for (unsigned i = 0; i < nSrcDims - 1; ++i) {
|
||||
if (i == gatherAxis)
|
||||
results.emplace_back(dims[0] * s + dims[gatherAxis + 1]);
|
||||
else
|
||||
results.emplace_back(dims[i + 1]);
|
||||
// If the op's result is a tensor, cast it to a tensor.
|
||||
if (isa<RankedTensorType>(op.getType()))
|
||||
finalOutput = bufferization::ToTensorOp::create(ib, op.getType(),
|
||||
finalOutput, true);
|
||||
} else {
|
||||
// 1. Enter tensor-land.
|
||||
auto inType =
|
||||
RankedTensorType::get(gatherShape, outType.getElementType());
|
||||
finalOutput =
|
||||
bufferization::ToTensorOp::create(ib, inType, finalOutput, true);
|
||||
|
||||
// 2. Permute the output buffer from {nRanks, dim0, ..., gatherAxis, ...}
|
||||
// to {dim0, ..., nRanks, dim1,...}.
|
||||
SmallVector<int64_t> outShapePermuted, permutation;
|
||||
for (int i = 1; i <= gatherAxis; ++i) {
|
||||
outShapePermuted.emplace_back(gatherShape[i]);
|
||||
permutation.emplace_back(i);
|
||||
}
|
||||
auto affineMap = AffineMap::get(nSrcDims, /*symbols=*/1, results, ctx);
|
||||
outShapePermuted.emplace_back(gatherShape[0]);
|
||||
permutation.emplace_back(0);
|
||||
for (size_t i = gatherAxis + 1; i < gatherShape.size(); ++i) {
|
||||
outShapePermuted.emplace_back(gatherShape[i]);
|
||||
permutation.emplace_back(i);
|
||||
}
|
||||
Value permOutput = tensor::EmptyOp::create(ib, outShapePermuted,
|
||||
outType.getElementType());
|
||||
finalOutput =
|
||||
linalg::TransposeOp::create(ib, finalOutput, permOutput, permutation)
|
||||
->getResult(0);
|
||||
|
||||
finalOutput = memref::AllocOp::create(iBuilder, outType);
|
||||
// 3. Collapse the output buffer from {dim0, ..., nRanks, gatherAxis, ...}
|
||||
// to {dim0, ..., nRanks*gatherAxis, ...}.
|
||||
SmallVector<ReassociationIndices> reassociation;
|
||||
for (int64_t i = 0; i < gatherAxis; ++i) {
|
||||
reassociation.push_back({i});
|
||||
}
|
||||
reassociation.push_back({gatherAxis, gatherAxis + 1});
|
||||
for (int64_t i = gatherAxis + 2; i < (int64_t)outShapePermuted.size();
|
||||
++i) {
|
||||
reassociation.push_back({i});
|
||||
}
|
||||
auto outTType =
|
||||
RankedTensorType::get(outputShape, outType.getElementType());
|
||||
finalOutput = tensor::CollapseShapeOp::create(ib, outTType, finalOutput,
|
||||
reassociation);
|
||||
|
||||
// Now build a loop nest to copy from gathered buffer to finalOutput
|
||||
// It would be nicer to just use a memref.transpose/collapse_shape op but
|
||||
// these currently only support simpler cases.
|
||||
Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
|
||||
SmallVector<Value> lbs(nSrcDims, zero);
|
||||
SmallVector<Value> ubs;
|
||||
for (int64_t d = 0; d < nSrcDims; ++d)
|
||||
ubs.emplace_back(memref::DimOp::create(iBuilder, tmpOutput, d));
|
||||
SmallVector<int64_t> steps(nSrcDims, 1);
|
||||
auto emitCopy = [&](OpBuilder &builder, Location loc, ValueRange ivs) {
|
||||
Value v = memref::LoadOp::create(iBuilder, tmpOutput, ivs);
|
||||
// set symbol value
|
||||
SmallVector<Value> ivss(ivs.begin(), ivs.end());
|
||||
ivss.emplace_back(gatherDimSz);
|
||||
affine::AffineStoreOp::create(iBuilder, v, finalOutput, affineMap,
|
||||
ivss);
|
||||
};
|
||||
affine::buildAffineLoopNest(iBuilder, op->getLoc(), lbs, ubs, steps,
|
||||
emitCopy);
|
||||
|
||||
memref::DeallocOp::create(iBuilder, tmpOutput);
|
||||
// 4. Cast back to memref if needed.
|
||||
if (isa<MemRefType>(op.getType()))
|
||||
finalOutput =
|
||||
bufferization::ToBufferOp::create(ib, outType, finalOutput);
|
||||
}
|
||||
|
||||
// If the destination is a tensor, cast it to a tensor
|
||||
if (isa<RankedTensorType>(op.getType()))
|
||||
finalOutput = bufferization::ToTensorOp::create(iBuilder, op.getType(),
|
||||
finalOutput, true);
|
||||
rewriter.replaceOp(op, finalOutput);
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -436,24 +436,32 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
|
||||
targetSharding);
|
||||
}
|
||||
|
||||
// Handles only resharding on a 1D shard.
|
||||
// Currently the sharded tensor axes must be exactly divisible by the single
|
||||
// grid axis size.
|
||||
// In most cases the sharded tensor axes must be exactly divisible by the single
|
||||
// grid axis size. Only halo size changes can deal with non-divisible cases.
|
||||
static TypedValue<ShapedType>
|
||||
reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
|
||||
const Sharding &sourceSharding, const Sharding &targetSharding,
|
||||
TypedValue<ShapedType> sourceUnshardedValue,
|
||||
TypedValue<ShapedType> sourceShard) {
|
||||
reshard(ImplicitLocOpBuilder &builder, GridOp grid,
|
||||
const Sharding &sourceSharding, const Sharding &targetSharding,
|
||||
TypedValue<ShapedType> sourceUnshardedValue,
|
||||
TypedValue<ShapedType> sourceShard) {
|
||||
// If source and destination sharding are the same, no need to do anything.
|
||||
if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
|
||||
isFullReplication(targetSharding))) {
|
||||
return sourceShard;
|
||||
}
|
||||
|
||||
// Tries to handle the case where the resharding is needed because the halo
|
||||
// sizes are different. Supports arbitrary grid dimensionality.
|
||||
if (auto tryRes = tryUpdateHaloInResharding(
|
||||
builder, grid, sourceSharding, targetSharding,
|
||||
sourceUnshardedValue.getType(), sourceShard)) {
|
||||
return std::get<0>(tryRes.value()); // targetShard
|
||||
}
|
||||
|
||||
assert(sourceShard.getType() ==
|
||||
shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
|
||||
[[maybe_unused]] ShapedType targetShardType =
|
||||
shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
|
||||
assert(sourceShard.getType().getRank() == targetShardType.getRank());
|
||||
assert(grid.getRank() == 1 && "Only 1D grides are currently supported.");
|
||||
|
||||
if (sourceSharding == targetSharding) {
|
||||
return sourceShard;
|
||||
}
|
||||
|
||||
TypedValue<ShapedType> targetShard;
|
||||
Sharding actualTargetSharding;
|
||||
@ -475,38 +483,13 @@ reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
|
||||
std::tie(targetShard, actualTargetSharding) = tryRes.value();
|
||||
}
|
||||
}
|
||||
|
||||
assert(targetShard && "Did not find any pattern to apply.");
|
||||
assert(actualTargetSharding == targetSharding);
|
||||
assert(targetShard.getType() == targetShardType);
|
||||
return targetShard;
|
||||
}
|
||||
|
||||
static TypedValue<ShapedType>
|
||||
reshard(ImplicitLocOpBuilder &builder, GridOp grid,
|
||||
const Sharding &sourceSharding, const Sharding &targetSharding,
|
||||
TypedValue<ShapedType> sourceUnshardedValue,
|
||||
TypedValue<ShapedType> sourceShard) {
|
||||
// If source and destination sharding are the same, no need to do anything.
|
||||
if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
|
||||
isFullReplication(targetSharding))) {
|
||||
return sourceShard;
|
||||
}
|
||||
|
||||
// Tries to handle the case where the resharding is needed because the halo
|
||||
// sizes are different. Supports arbitrary grid dimensionality.
|
||||
if (auto tryRes = tryUpdateHaloInResharding(
|
||||
builder, grid, sourceSharding, targetSharding,
|
||||
sourceUnshardedValue.getType(), sourceShard)) {
|
||||
return std::get<0>(tryRes.value()); // targetShard
|
||||
}
|
||||
|
||||
// Resort to handling only 1D grids since the general case is complicated if
|
||||
// it needs to be communication efficient in terms of minimizing the data
|
||||
// transfered between devices.
|
||||
return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
|
||||
sourceUnshardedValue, sourceShard);
|
||||
}
|
||||
|
||||
TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
|
||||
ShardOp target,
|
||||
TypedValue<ShapedType> sourceShardValue) {
|
||||
|
||||
@ -148,6 +148,28 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
|
||||
return %0 : memref<3x4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @allgather_tensor_0
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
|
||||
func.func @allgather_tensor_0(%arg0 : tensor<3x4xf32>) -> tensor<12x4xf32> {
|
||||
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
|
||||
// CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
|
||||
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
|
||||
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
|
||||
// CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
|
||||
// CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
|
||||
// CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc4]] : index
|
||||
// CHECK: cf.assert [[v3]]
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<4x3x4xf32>
|
||||
// CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<4x3x4xf32>
|
||||
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1], [2]] : memref<4x3x4xf32> into memref<12x4xf32>
|
||||
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[vcollapse_shape]] restrict : memref<12x4xf32> to tensor<12x4xf32>
|
||||
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 0 : tensor<3x4xf32> -> tensor<12x4xf32>
|
||||
// CHECK: return [[v4]] : tensor<12x4xf32>
|
||||
return %0 : tensor<12x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @allgather_tensor
|
||||
func.func @allgather_tensor(
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
|
||||
@ -155,28 +177,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
|
||||
%arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
|
||||
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
|
||||
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
|
||||
// CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index
|
||||
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
|
||||
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
|
||||
// CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
|
||||
// CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
|
||||
// CHECK: [[v3:%.*]] = arith.divsi [[vc20]], [[v2]] : index
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x3x?xf32>
|
||||
// CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
|
||||
// CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
|
||||
// CHECK: affine.for [[varg1:%.*]] = 0 to [[v2]] {
|
||||
// CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
|
||||
// CHECK: affine.for [[varg3:%.*]] = 0 to [[v3]] {
|
||||
// CHECK: [[v5:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
|
||||
// CHECK: affine.store [[v5]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v3]]) + [[varg3]]] : memref<3x20xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
|
||||
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<3x20xf32> to tensor<3x20xf32>
|
||||
// CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc5]] : index
|
||||
// CHECK: cf.assert [[v3]]
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<5x3x4xf32>
|
||||
// CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<5x3x4xf32>
|
||||
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<5x3x4xf32> to tensor<5x3x4xf32>
|
||||
// CHECK: [[v5:%.*]] = tensor.empty() : tensor<3x5x4xf32>
|
||||
// CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v4]] : tensor<5x3x4xf32>) outs([[v5]] : tensor<3x5x4xf32>) permutation = [1, 0, 2]
|
||||
// CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>
|
||||
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x20xf32>
|
||||
// CHECK: return [[v4]] : tensor<3x20xf32>
|
||||
// CHECK: return [[vcollapsed]] : tensor<3x20xf32>
|
||||
return %0 : tensor<3x20xf32>
|
||||
}
|
||||
|
||||
@ -185,28 +201,24 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
|
||||
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
|
||||
// CHECK-SAME: -> memref<3x20xf32>
|
||||
%arg0 : memref<3x4xf32>) -> memref<3x20xf32> {
|
||||
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
|
||||
// CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
|
||||
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
|
||||
// CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index
|
||||
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
|
||||
// CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
|
||||
// CHECK: [[v1:%.*]] = arith.index_cast [[vsize]] : i32 to index
|
||||
// CHECK: [[v2:%.*]] = arith.divsi [[vc20]], [[v1]] : index
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc([[v1]], [[v2]]) : memref<?x3x?xf32>
|
||||
// CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
|
||||
// CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
|
||||
// CHECK: affine.for [[varg1:%.*]] = 0 to [[v1]] {
|
||||
// CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
|
||||
// CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
|
||||
// CHECK: [[v3:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
|
||||
// CHECK: affine.store [[v3]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v2]]) + [[varg3]]] : memref<3x20xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
|
||||
// CHECK: [[v2:%.*]] = arith.cmpi eq, [[v1]], [[vc5]] : index
|
||||
// CHECK: cf.assert [[v2]]
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<5x3x4xf32>
|
||||
// CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<5x3x4xf32>
|
||||
// CHECK: [[v3:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<5x3x4xf32> to tensor<5x3x4xf32>
|
||||
// CHECK: [[v4:%.*]] = tensor.empty() : tensor<3x5x4xf32>
|
||||
// CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v3]] : tensor<5x3x4xf32>) outs([[v4]] : tensor<3x5x4xf32>) permutation = [1, 0, 2]
|
||||
// CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>
|
||||
// CHECK: [[v5:%.*]] = bufferization.to_buffer [[vcollapsed]] : tensor<3x20xf32> to memref<3x20xf32>
|
||||
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32>
|
||||
// CHECK: return [[valloc_0]] : memref<3x20xf32>
|
||||
// CHECK: return [[v5]] : memref<3x20xf32>
|
||||
return %0 : memref<3x20xf32>
|
||||
}
|
||||
}
|
||||
@ -377,9 +389,9 @@ shard.grid @grid0(shape = 2x2x4)
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
|
||||
func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.sharding) {
|
||||
%sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] : !shard.sharding
|
||||
// CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
|
||||
// CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
|
||||
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
|
||||
// CHECK-DAG: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
|
||||
// CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
|
||||
// CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
|
||||
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
|
||||
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
|
||||
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
|
||||
@ -397,10 +409,10 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.s
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
|
||||
func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !shard.sharding) {
|
||||
%sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !shard.sharding
|
||||
// CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
|
||||
// CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
|
||||
// CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
|
||||
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
|
||||
// CHECK-DAG: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
|
||||
// CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
|
||||
// CHECK-DAG: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
|
||||
// CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
|
||||
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
|
||||
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
|
||||
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
|
||||
@ -417,12 +429,12 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !s
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
|
||||
func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !shard.sharding) {
|
||||
%sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !shard.sharding
|
||||
// CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
|
||||
// CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
|
||||
// CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
|
||||
// CHECK: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
|
||||
// CHECK: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
|
||||
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
|
||||
// CHECK-DAG: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
|
||||
// CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
|
||||
// CHECK-DAG: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
|
||||
// CHECK-DAG: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
|
||||
// CHECK-DAG: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
|
||||
// CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
|
||||
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
|
||||
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
|
||||
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
|
||||
@ -444,39 +456,38 @@ shard.grid @grid_1d_4(shape = 4)
|
||||
// CHECK-LABEL: func.func @mlp_1dgrid(
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
|
||||
func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %arg2: tensor<256x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
|
||||
// CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: [[vc0:%.*]] = arith.constant 0 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
// CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
|
||||
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<512x512xf32> to memref<512x512xf32>
|
||||
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: [[vsize:%.*]] = mpi.comm_size
|
||||
// CHECK: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32
|
||||
// CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
|
||||
// CHECK: [[v3:%.*]] = arith.divsi
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x512x?xf32>
|
||||
// CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<?x512x?xf32>
|
||||
// CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
|
||||
// CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
|
||||
// CHECK: affine.for [[varg4:%.*]] = 0 to 512 {
|
||||
// CHECK: affine.for [[varg5:%.*]] = 0 to [[v3]] {
|
||||
// CHECK: [[v19:%.*]] = memref.load [[valloc]][[[varg3]], [[varg4]], [[varg5]]] : memref<?x512x?xf32>
|
||||
// CHECK: affine.store [[v19]], [[valloc_0]][[[varg4]], [[varg3]] * symbol([[v3]]) + [[varg5]]] : memref<512x2048xf32>
|
||||
// CHECK: memref.dealloc [[valloc]] : memref<?x512x?xf32>
|
||||
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
|
||||
// CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc4]] : index
|
||||
// CHECK: cf.assert [[v3]]
|
||||
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<4x512x512xf32>
|
||||
// CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<4x512x512xf32>
|
||||
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<4x512x512xf32> to tensor<4x512x512xf32>
|
||||
// CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x4x512xf32>
|
||||
// CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v4]] : tensor<4x512x512xf32>) outs([[v5]] : tensor<512x4x512xf32>) permutation = [1, 0, 2]
|
||||
// CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<512x4x512xf32> into tensor<512x2048xf32>
|
||||
%all_gather = shard.all_gather %arg0 on @grid_1d_4 grid_axes = [0] gather_axis = 1 : tensor<512x512xf32> -> tensor<512x2048xf32>
|
||||
// CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x256xf32>
|
||||
// CHECK: [[v6:%.*]] = tensor.empty() : tensor<512x256xf32>
|
||||
%0 = tensor.empty() : tensor<512x256xf32>
|
||||
// CHECK: [[v6:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v5]] : tensor<512x256xf32>) -> tensor<512x256xf32>
|
||||
// CHECK: [[v7:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32>
|
||||
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x256xf32>) -> tensor<512x256xf32>
|
||||
// CHECK: [[v7:%.*]] = linalg.matmul ins([[v4]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32>
|
||||
// CHECK: [[v8:%.*]] = linalg.matmul ins([[vcollapsed]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v7]] : tensor<512x256xf32>) -> tensor<512x256xf32>
|
||||
%2 = linalg.matmul ins(%all_gather, %arg1 : tensor<512x2048xf32>, tensor<2048x256xf32>) outs(%1 : tensor<512x256xf32>) -> tensor<512x256xf32>
|
||||
// CHECK: [[v8:%.*]] = tosa.sigmoid [[v7]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
|
||||
// CHECK: [[v9:%.*]] = tosa.sigmoid [[v8]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
|
||||
%3 = tosa.sigmoid %2 : (tensor<512x256xf32>) -> tensor<512x256xf32>
|
||||
%4 = tensor.empty() : tensor<512x2048xf32>
|
||||
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
|
||||
%proc_linear_idx = shard.process_multi_index on @grid_1d_4 axes = [0] : index
|
||||
%grid_shape = shard.grid_shape @grid_1d_4 axes = [0] : index
|
||||
%6 = arith.cmpi eq, %proc_linear_idx, %c0 : index
|
||||
// CHECK: [[v14:%.*]] = scf.if
|
||||
// CHECK: [[v15:%.*]] = scf.if
|
||||
%7 = scf.if %6 -> (tensor<512x2048xf32>) {
|
||||
scf.yield %5 : tensor<512x2048xf32>
|
||||
} else {
|
||||
@ -484,15 +495,15 @@ func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %
|
||||
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
|
||||
scf.yield %10 : tensor<512x2048xf32>
|
||||
}
|
||||
// CHECK: [[v15:%.*]] = linalg.matmul ins([[v8]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v14]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
|
||||
// CHECK: [[v16:%.*]] = linalg.matmul ins([[v9]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v15]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
|
||||
%8 = linalg.matmul ins(%3, %arg2 : tensor<512x256xf32>, tensor<256x2048xf32>) outs(%7 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
|
||||
// CHECK: [[v16:%.*]] = bufferization.to_buffer
|
||||
// CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<512x2048xf32>
|
||||
// CHECK: linalg.copy ins([[v16]] : memref<512x2048xf32>) outs([[valloc_1]] : memref<512x2048xf32>)
|
||||
// CHECK: [[v17:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: mpi.allreduce([[valloc_1]], [[valloc_1]], MPI_SUM, [[v17]]) : memref<512x2048xf32>, memref<512x2048xf32>
|
||||
// CHECK: [[v18:%.*]] = bufferization.to_tensor [[valloc_1]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
|
||||
// CHECK: [[v17:%.*]] = bufferization.to_buffer [[v16]] : tensor<512x2048xf32> to memref<512x2048xf32>
|
||||
// CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
|
||||
// CHECK: linalg.copy ins([[v17]] : memref<512x2048xf32>) outs([[valloc_0]] : memref<512x2048xf32>)
|
||||
// CHECK: [[v18:%.*]] = mpi.comm_world : !mpi.comm
|
||||
// CHECK: mpi.allreduce([[valloc_0]], [[valloc_0]], MPI_SUM, [[v18]]) : memref<512x2048xf32>, memref<512x2048xf32>
|
||||
// CHECK: [[v19:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
|
||||
%all_reduce = shard.all_reduce %8 on @grid_1d_4 grid_axes = [0] : tensor<512x2048xf32> -> tensor<512x2048xf32>
|
||||
// CHECK: return [[v18]] : tensor<512x2048xf32>
|
||||
// CHECK: return [[v19]] : tensor<512x2048xf32>
|
||||
return %all_reduce : tensor<512x2048xf32>
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
|
||||
shard.grid @grid_1d(shape = 2)
|
||||
shard.grid @grid_1d_4(shape = 4)
|
||||
shard.grid @grid_2d_16(shape = 4x4)
|
||||
|
||||
// CHECK-LABEL: func @return_sharding
|
||||
func.func @return_sharding(
|
||||
@ -318,9 +319,9 @@ func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
|
||||
return %sharded_ret : tensor<6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @mlp_1dgrid
|
||||
// CHECK-LABEL: func.func @mlp_1d_weight_stationary
|
||||
// CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
|
||||
func.func @mlp_1dgrid(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
|
||||
func.func @mlp_1d_weight_stationary(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
|
||||
// CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
|
||||
%sharding = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding
|
||||
%sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0], []] : !shard.sharding
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user