//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation of Shard communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/MPI/IR/MPI.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
|
|
#include "mlir/Dialect/Shard/IR/ShardOps.h"
|
|
#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
|
|
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/SymbolTable.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#define DEBUG_TYPE "shard-to-mpi"
|
|
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace shard;
|
|
|
|
namespace {
|
|
/// Converts a vector of OpFoldResults (ints) into a vector of Values of the
/// provided type.
static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                           llvm::ArrayRef<int64_t> statics,
                                           ValueRange dynamics,
                                           Type type = Type()) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  if (!type)
    type = i64;
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
    }
  }
  return values;
}

/// Create operations converting a linear index to a multi-dimensional index.
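/// For example (a sanity check of the arithmetic below), with dimensions
/// [4, 3] the linear index 7 yields the multi-index [2, 1]:
/// 7 % 3 == 1, then 7 / 3 == 2.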
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                             Value linearIndex,
                                             ValueRange dimensions) {
  int n = dimensions.size();
  SmallVector<Value> multiIndex(n);

  for (int i = n - 1; i >= 0; --i) {
    multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
    if (i > 0)
      linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
  }

  return multiIndex;
}

/// Create operations converting a multi-dimensional index to a linear index.
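/// For example, the inverse of the mapping above: with dimensions [4, 3] the
/// multi-index [2, 1] yields the linear index 2 * 3 + 1 == 7.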
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                         ValueRange dimensions) {

  Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0);
  Value stride = arith::ConstantIndexOp::create(b, loc, 1);

  for (int i = multiIndex.size() - 1; i >= 0; --i) {
    Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
    linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
    stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
  }

  return linearIndex;
}

/// Replace GetShardingOp with related/dependent ShardingOp.
struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
    if (!shardOp)
      return failure();
    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
    if (!shardingOp)
      return failure();

    rewriter.replaceOp(op, shardingOp.getResult());
    return success();
  }
};

/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by the type converter.
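/// As an illustrative sketch: a sharding with split_axes = [[0], [1]] and no
/// halos or offsets becomes a 2x1 i16 tensor [[0], [1]] for the split axes
/// plus two empty (0x0) i64 tensors for the halo sizes and the sharded-dims
/// offsets.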
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // To hold the split axes, create an empty 2D tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
    resSplitAxes =
        linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
            .getResult(0);

    // Explicitly write values into the tensor row by row.
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
      resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resSplitAxes, empty, empty,
                                                   empty, offs, sizes, strides);
    }

    // To hold halo sizes, create a 2D tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
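    // E.g., two split dimensions yield a 2x2 tensor
    // [[down0, up0], [down1, up1]] (one down/up pair per split dimension).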
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    Value resHaloSizes =
        haloSizes.empty()
            ? tensor::EmptyOp::create(rewriter, loc,
                                      std::array<int64_t, 2>{0, 0}, i64)
                  .getResult()
            : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
                  .getResult();

    // To hold sharded dims offsets, create a tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but fill trailing
    // elements for smaller split-groups with a sentinel value
    // (ShapedType::kDynamic). Computing the max size of the split groups
    // requires collectiveProcessGroupSize (which needs the GridOp).
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      resOffsets = tensor::EmptyOp::create(rewriter, loc,
                                           std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto gridOp = getGrid(op, symbolTableCollection);
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = tensor::EmptyOp::create(
          rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      Value sentinel = arith::ConstantOp::create(
          rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          linalg::FillOp::create(rewriter, loc, sentinel, resOffsets)
              .getResult(0);
      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resOffsets, empty, empty,
                                                   empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // Return a tuple of tensors as defined by the type converter.
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    resSplitAxes =
        tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
    resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};

struct ConvertProcessMultiIndexOp
    : public OpConversionPattern<ProcessMultiIndexOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Currently converts its linear index to a multi-dimensional index.

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();
    auto gridOp = getGrid(op, symbolTableCollection);
    // For now we only support static grid shapes.
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return failure();

    SmallVector<Value> dims;
    llvm::transform(
        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
        });
    Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);

    // Optionally extract a subset of grid axes.
    auto axes = adaptor.getAxes();
    if (!axes.empty()) {
      SmallVector<Value> subIndex;
      for (auto axis : axes) {
        subIndex.emplace_back(mIdx[axis]);
      }
      mIdx = std::move(subIndex);
    }

    rewriter.replaceOp(op, mIdx);
    return success();
  }
};

class ConvertProcessLinearIndexOp
    : public OpConversionPattern<ProcessLinearIndexOp> {

public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Create mpi::CommRankOp
    Location loc = op.getLoc();
    auto ctx = op.getContext();
    Value commWorld =
        mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
    auto rank = mpi::CommRankOp::create(
                    rewriter, loc,
                    TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
                    commWorld)
                    .getRank();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                    rank);
    return success();
  }
};

struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Computes the neighbors' indices along a split axis by simply
    // adding/subtracting 1 to/from the current index in that dimension.
    // Assigns -1 if the neighbor is out of bounds.
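    // E.g., on a 1D grid of size 4, device 0 gets neighbors (-1, 1) and
    // device 3 gets neighbors (2, -1).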

    auto axes = adaptor.getSplitAxes();
    // For now only single axis sharding is supported.
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(op, symbolTableCollection);
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    SmallVector<Value> dims;
    llvm::transform(
        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1);
    Value atBorder =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
                              arith::ConstantIndexOp::create(rewriter, loc, 0));
    auto down = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
                  .getResult();
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    atBorder = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
        arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
    auto up = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};

struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      return op->emitError()
             << "Expected ShardingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
    // computed by dividing the dimension size by the number of shards in that
    // dimension (which is given by the size of the grid axes provided in
    // split-axes). Leftover elements get distributed to the trailing shards.
    // If shardedDimsOffsets is provided, the shard shape is computed by
    // subtracting the offset of the current shard from the offset of the next
    // shard.
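    // A quick worked example of the division rule: a dimension of size 10
    // split across 4 shards gives 10 / 4 == 2 with remainder 2, so the shard
    // sizes are [2, 2, 3, 3] (the last `dim % numShards` shards get one
    // extra element).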

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValueRange>. For dims and
    // device we have a 1:1 conversion.
    // For simpler access fill a vector with the dynamic dims.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      // Type conversion should be 1:1 for ints.
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    // Same for device.
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    // To keep the code simple, convert dims/device to values when they are
    // attributes. Count on canonicalization to fold static values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    // Get the GridOp; the grid shape is needed to compute the sharded shape.
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(sharding, symbolTableCollection);
    // For now we only support static grid shapes.
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();
    // shardedDimsOffsets are optional and might be Values (not attributes).
    // Also, the shardId might be dynamic, which means the position in the
    // shardedDimsOffsets is not statically known. Create a tensor of the
    // shardedDimsOffsets and later extract the offsets for computing the
    // local shard-size.
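    // E.g., a dimension of size 10 split across 4 shards might carry the
    // offsets [0, 2, 4, 7, 10] (the trailing entry is the total size);
    // shard 2 then has size 7 - 4 == 3.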
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = tensor::FromElementsOp::create(
            rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
            tmp);
    }

    // With a static grid shape the sizes of the split axes are known.
    // Hence the start/pos for each split axis in shardedDimsOffsets can be
    // computed statically.
    int64_t pos = 0;
    SmallVector<Value> shardShape;
    Value zero =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
    Value one =
        arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape, get their split axes,
    // and compute the sharded shape.
    for (auto [i, dim] : llvm::enumerate(shape)) {
      // Trailing dimensions might not be annotated.
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        // The current dimension might not be sharded.
        // Create a value from the static position in shardedDimsOffsets.
        Value posVal = arith::ConstantOp::create(rewriter, loc,
                                                 rewriter.getIndexAttr(pos));
        // Get the index of the local shard in the grid axis.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        if (shardedDimsOffs) {
          // If sharded dims offsets are provided, use them to compute the
          // sharded shape.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
          Value off =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          idx = arith::AddIOp::create(rewriter, loc, idx, one);
          Value nextOff =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = arith::ConstantOp::create(
              rewriter, loc, rewriter.getIndexAttr(numShards));
          // Compute the shard dim size by distributing leftover elements to
          // trailing shards:
          //   sz = dim / numShards
          //        + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
          Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
          Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
          sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
          auto cond = arith::CmpIOp::create(
              rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
          sz = arith::AddIOp::create(rewriter, loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        pos += numShards + 1; // add one for the total size.
      } // else no sharding if split axis is empty or no split axis
      // If no size was added -> no sharding in this dimension.
      if (shardShape.size() <= i)
        shardShape.emplace_back(dim);
    }
    assert(shardShape.size() == shape.size());
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};

static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
  auto ctx = kind.getContext();
  auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
    return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
  };

  switch (kind.getValue()) {
  case ReductionKind::Sum:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
  case ReductionKind::Product:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
  case ReductionKind::Min:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
  case ReductionKind::Max:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
  case ReductionKind::BitwiseAnd:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
  case ReductionKind::BitwiseOr:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
  case ReductionKind::BitwiseXor:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
  default:
    llvm_unreachable("Unknown/unsupported reduction kind");
  }
}

struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SymbolTableCollection symbolTableCollection;
    auto grid = adaptor.getGrid();
    mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
    if (!gridOp)
      return op->emitError() << "No grid found for AllReduceOp";
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return op->emitError()
             << "Dynamic grid shape not supported in AllReduceOp";

    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
    Value input = adaptor.getInput();
    auto inputShape = cast<ShapedType>(input.getType()).getShape();

    // If the source is a tensor, cast it to a memref (buffer).
    if (isa<RankedTensorType>(input.getType())) {
      auto memrefType = MemRefType::get(
          inputShape, cast<ShapedType>(input.getType()).getElementType());
      input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
    }
    MemRefType inType = cast<MemRefType>(input.getType());

    // Get the actual shape to allocate the buffer.
    SmallVector<OpFoldResult> shape(inType.getRank());
    for (auto i = 0; i < inType.getRank(); ++i) {
      auto s = inputShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = memref::DimOp::create(iBuilder, input, s).getResult();
      else
        shape[i] = iBuilder.getIndexAttr(s);
    }

    // Allocate buffer and copy input to buffer.
    Value buffer = memref::AllocOp::create(
        iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
    linalg::CopyOp::create(iBuilder, input, buffer);

    // Get an MPI_Comm_split for the AllReduce operation.
    // The color is the linear index of the process in the grid along the
    // non-reduced axes. The key is the linear index of the process in the grid
    // along the reduced axes.
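    // E.g. (illustrative): on a 2x3 grid reducing along axis 1, all processes
    // in the same row share a color and are keyed by their column, so each
    // row becomes its own communicator for the reduction.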
    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
                                       iBuilder.getIndexType());
    SmallVector<Value> myMultiIndex =
        ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
            .getResult();
    Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
    SmallVector<Value> multiKey(myMultiIndex.size(), zero);

    auto redAxes = adaptor.getGridAxes();
    for (auto axis : redAxes) {
      multiKey[axis] = myMultiIndex[axis];
      myMultiIndex[axis] = zero;
    }

    Value color =
        createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
    color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
    Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
    key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);

    // Finally, split the communicator.
    auto commType = mpi::CommType::get(op->getContext());
    Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
    auto comm =
        mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
            .getNewcomm();

    Value buffer1d = buffer;
    // Collapse shape to 1d if needed.
    if (inType.getRank() > 1) {
      ReassociationIndices reassociation(inType.getRank());
      std::iota(reassociation.begin(), reassociation.end(), 0);
      buffer1d = memref::CollapseShapeOp::create(
          iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
    }

    // Create the MPI AllReduce operation.
    mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
                             getMPIReductionOp(adaptor.getReductionAttr()),
                             comm);

    // If the result is a tensor, cast the buffer back to a tensor.
    if (isa<RankedTensorType>(op.getType()))
      buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
                                                 true);

    rewriter.replaceOp(op, buffer);
    return success();
  }
};

struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // The input/output memref is assumed to be in C memory order.
    // Halos are exchanged as 2 blocks per dimension (one for each side: down
    // and up). For each haloed dimension `d`, the exchanged blocks are
    // expressed as multi-dimensional subviews. The subviews include potential
    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
    // `dl < d` and for dimension `d` the currently exchanged halo only.
    // By iterating from higher to lower dimensions this also updates the halos
    // in the 'corners'.
    // memref.subview is used to read and write the halo data from and to the
    // local data. Because subviews and halos can have mixed dynamic and static
    // shapes, OpFoldResults are used whenever possible.
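    // E.g., for a 2D shard with halos in both dimensions, dimension 1 is
    // exchanged first; the following exchange along dimension 0 sends rows
    // that already contain the updated dimension-1 halos, which fills the
    // corner regions.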

    auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
                                    adaptor.getHaloSizes(), rewriter);
    if (haloSizes.empty()) {
      // No halos -> nothing to do.
      rewriter.replaceOp(op, adaptor.getDestination());
      return success();
    }

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();

    // Convert an OpFoldResult into a Value.
    auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
      if (auto value = dyn_cast<Value>(v))
        return value;
      return arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getIndexAttr(
              cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
    };

    auto dest = adaptor.getDestination();
    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
    Value array = dest;
    if (isa<RankedTensorType>(array.getType())) {
      // If the destination is a tensor, cast it to a memref (buffer).
      auto memrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          bufferization::ToBufferOp::create(rewriter, loc, memrefType, array);
    }
    auto rank = cast<ShapedType>(array.getType()).getRank();
    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
    auto grid = adaptor.getGrid();
    auto gridOp = getGrid(op, symbolTableCollection);
    // Subviews need index values.
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz))
        sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
                                        value)
                 .getResult();
    }

    // Most of the offset/size/stride data is the same for all dims.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
    // We need the actual shape to compute offsets and sizes.
    for (auto i = 0; i < rank; ++i) {
      auto s = dstShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult();
      else
        shape[i] = rewriter.getIndexAttr(s);

      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
        ++currHaloDim;
        // The offsets for lower dims start after their down halo.
        offsets[i] = haloSizes[currHaloDim * 2];

        // Prepare shape and offsets of highest dim's halo exchange.
        Value _haloSz = arith::AddIOp::create(
            rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
            toValue(haloSizes[currHaloDim * 2 + 1]));
        // The halo shape of lower dims excludes the halos.
        dimSizes[i] =
            arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
                .getResult();
      } else {
        dimSizes[i] = shape[i];
      }
    }

    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
    auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
    auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);

    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
                                       rewriter.getIndexType());
    auto myMultiIndex =
        ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
            .getResult();
    // Traverse all split axes from high to low dim.
    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
      auto splitAxes = opSplitAxes[dim];
      if (splitAxes.empty())
        continue;
      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
      // Get the linearized ids of the neighbors (down and up) for the
      // given split.
      auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
                                                  myMultiIndex, splitAxes)
                     .getResults();
      // MPI operates on i32...
      Value neighbourIDs[2] = {
          arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
                                     tmp[0]),
          arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
                                     tmp[1])};

      auto lowerRecvOffset = rewriter.getIndexAttr(0);
      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
      auto upperRecvOffset =
          arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
                                toValue(haloSizes[currHaloDim * 2 + 1]));
      auto upperSendOffset = arith::SubIOp::create(
          rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));

      Value commWorld = mpi::CommWorldOp::create(
          rewriter, loc, mpi::CommType::get(op->getContext()));

      // Make sure we send/recv in a way that does not lead to a dead-lock.
      // The current approach is by far not optimal; it should at least use a
      // red-black pattern or MPI_sendrecv.
      // Also, buffers should be re-used.
      // Still using temporary contiguous buffers for MPI communication...
      // Still yielding a "serialized" communication pattern...
      auto genSendRecv = [&](bool upperHalo) {
        auto orgOffset = offsets[dim];
        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                  : haloSizes[currHaloDim * 2];
        // Check if we need to send and/or receive.
        // Processes on the grid borders have only one neighbor.
        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
        auto hasFrom = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::sge, from, zero);
        auto hasTo = arith::CmpIOp::create(rewriter, loc,
                                           arith::CmpIPredicate::sge, to, zero);
        auto buffer = memref::AllocOp::create(
            rewriter, loc, dimSizes,
            cast<ShapedType>(array.getType()).getElementType());
        // If there is a neighbor: copy halo data from array to buffer and
        // send.
        scf::IfOp::create(
            rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
                                       : OpFoldResult(upperSendOffset);
              auto subview = memref::SubViewOp::create(
                  builder, loc, array, offsets, dimSizes, strides);
              memref::CopyOp::create(builder, loc, subview, buffer);
              mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
                                  commWorld);
              scf::YieldOp::create(builder, loc);
            });
        // If there is a neighbor: receive halo data into buffer and copy to
        // array.
        scf::IfOp::create(
            rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                       : OpFoldResult(lowerRecvOffset);
              mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
                                  commWorld);
              auto subview = memref::SubViewOp::create(
                  builder, loc, array, offsets, dimSizes, strides);
              memref::CopyOp::create(builder, loc, buffer, subview);
              scf::YieldOp::create(builder, loc);
            });
        memref::DeallocOp::create(rewriter, loc, buffer);
        offsets[dim] = orgOffset;
      };

      auto doSendRecv = [&](int upOrDown) {
        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
        Value haloSz = dyn_cast<Value>(v);
        if (!haloSz)
          haloSz = arith::ConstantOp::create(
              rewriter, loc,
              rewriter.getI32IntegerAttr(
                  cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
        auto hasSize = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
        scf::IfOp::create(rewriter, loc, hasSize,
                          [&](OpBuilder &builder, Location loc) {
                            genSendRecv(upOrDown > 0);
                            scf::YieldOp::create(builder, loc);
                          });
      };

      doSendRecv(0);
      doSendRecv(1);

      // The shape for lower dims includes higher dims' halos...
      dimSizes[dim] = shape[dim];
      // ...so the offset for higher dims is always 0.
      offsets[dim] = rewriter.getIndexAttr(0);
      // On to the next halo.
      --currHaloDim;
    }

    if (isa<MemRefType>(op.getResult().getType())) {
      rewriter.replaceOp(op, array);
    } else {
      assert(isa<RankedTensorType>(op.getResult().getType()));
      rewriter.replaceOp(op, bufferization::ToTensorOp::create(
                                 rewriter, loc, op.getResult().getType(), array,
                                 /*restrict=*/true, /*writable=*/true));
    }
    return success();
  }
};

struct ConvertShardToMPIPass
    : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
  using Base::Base;

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    auto *ctxt = &getContext();
    RewritePatternSet patterns(ctxt);
    ConversionTarget target(getContext());

    // Define a type converter to convert shard::ShardingType,
    // mostly for use in return operations.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });

    // Convert shard::ShardingType to a tuple of RankedTensorTypes.
    typeConverter.addConversion(
        [](ShardingType type,
           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          auto i16 = IntegerType::get(type.getContext(), 16);
          auto i64 = IntegerType::get(type.getContext(), 64);
          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
                                        ShapedType::kDynamic};
          results.emplace_back(RankedTensorType::get(shp, i16));
          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
          results.emplace_back(RankedTensorType::get(shp, i64));
          return success();
        });

    // To 'extract' components, an UnrealizedConversionCastOp is expected
    // to define the input.
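    // I.e. (a sketch of the expected IR): the tuple value comes from
    //   %t = builtin.unrealized_conversion_cast %a, %b, %c : ... to tuple<...>
    // and the materialization simply forwards the cast's operands %a, %b, %c.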
    typeConverter.addTargetMaterialization(
        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
            Location loc) {
          // Expecting a single input.
          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
            return SmallVector<Value>();
          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
          // Expecting an UnrealizedConversionCastOp.
          if (!castOp)
            return SmallVector<Value>();
          // Fill a vector with elements of the tuple/castOp.
          SmallVector<Value> results;
          for (auto oprnd : castOp.getInputs()) {
            if (!isa<RankedTensorType>(oprnd.getType()))
              return SmallVector<Value>();
            results.emplace_back(oprnd);
          }
          return results;
        });

    // No shard dialect ops should be left after conversion...
    target.addIllegalDialect<shard::ShardDialect>();
    // ...except the global GridOp and GridShapeOp, which will get folded
    // later.
    target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
    // Allow all the stuff that our patterns will convert to.
    target.addLegalDialect<
        BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
        tensor::TensorDialect, bufferization::BufferizationDialect,
        linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
    // Make sure the function signature, calls, etc. are legal.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
                 ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
                 ConvertProcessLinearIndexOp>(typeConverter, ctxt);

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    populateCallOpTypeConversionPattern(patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);

    (void)applyPartialConversion(getOperation(), target, std::move(patterns));

    // Folding patterns cannot be mixed with conversion patterns -> extra pass.
    patterns.clear();
    SymbolTableCollection symbolTableCollection;
    mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace