
This PR uses `val.getDefiningOp<OpTy>()` to replace `dyn_cast<OpTy>(val.getDefiningOp())`, `dyn_cast_or_null<OpTy>(val.getDefiningOp())`, and `dyn_cast_if_present<OpTy>(val.getDefiningOp())`.
412 lines
15 KiB
C++
412 lines
15 KiB
C++
//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
|
|
|
|
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
|
|
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
|
|
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
|
|
#include "mlir/IR/Verifier.h"
|
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/iterator_range.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include <algorithm>
|
|
#include <vector>
|
|
|
|
namespace mlir {
|
|
namespace mesh {
|
|
#define GEN_PASS_DEF_SHARDINGPROPAGATION
|
|
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
|
|
} // namespace mesh
|
|
} // namespace mlir
|
|
|
|
#define DEBUG_TYPE "sharding-propagation"
|
|
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::mesh;
|
|
|
|
// Classifies how much resharding a candidate sharding would require with
// respect to the annotations already present around an operation.
// Enumerators are listed from most to least desirable; candidates are compared
// by underlying value, so a smaller value is preferred.
enum class ReshardingRquirementKind {
  // Every existing annotation is compatible with the candidate.
  NO_RESHARDING = 0,
  // Some annotation must be resharded, but none that targets the operation
  // explicitly.
  NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 1,
  // An annotation that explicitly targets the operation must be resharded.
  RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 2
};
|
|
|
|
#ifdef LLVM_DEBUG
|
|
|
|
template <typename T>
|
|
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
|
|
const SmallVector<T> &vec);
|
|
template <typename... Ts>
|
|
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
|
|
const std::tuple<Ts...> &t);
|
|
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
|
|
ReshardingRquirementKind v);
|
|
|
|
// Writes the elements of `range` to `stream` as "[a, b, c]".
//
// Elements are separated by ", " with no separator after the last one; an
// empty range prints as "[]". Each element must be insertable into `stream`
// with operator<<.
//
// Returns `stream` itself, so the helper also works with stream types whose
// operator<< returns a base-class reference.
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  bool isFirst = true;
  for (const auto &element : range) {
    // Emit the separator before every element except the first, so that no
    // dangling ", " trails the last element.
    if (!isFirst)
      stream << ", ";
    isFirst = false;
    stream << element;
  }
  stream << "]";
  return stream;
}
|
|
|
|
template <typename T>
|
|
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
|
|
const SmallVector<T> &vec) {
|
|
return printRange(stream, vec);
|
|
}
|
|
|
|
// Debug inserter for ShardingOption. Prints the `empty` flag, the `mesh` and
// the `shardingArray` members as "{empty = ..., mesh = ..., shardingArray =
// ...}". Every field uses the same "name = value" layout.
[[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                                      const ShardingOption &v) {
  stream << "{empty = " << v.empty;
  stream << ", mesh = " << v.mesh;
  stream << ", shardingArray = " << v.shardingArray << "}";
  return stream;
}
|
|
|
|
// Writes the elements of `tuple` to `stream` as "{a, b, c}".
//
// `Is...` must be the index sequence 0..sizeof...(Ts)-1 (typically produced by
// std::index_sequence_for<Ts...>); it is used to visit the elements in order.
// Empty tuples are rejected at compile time.
//
// The tuple is taken by const reference so that printing never copies the
// elements. Returns `stream` itself.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");

  stream << "{";
  // Prefix every element but the first with the separator, so that no
  // dangling ", " trails the last element.
  ((stream << (Is == 0 ? "" : ", ") << std::get<Is>(tuple)), ...);
  stream << "}";
  return stream;
}
|
|
|
|
template <typename... Ts>
|
|
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
|
|
const std::tuple<Ts...> &t) {
|
|
return printTuple(stream, t, std::index_sequence_for<Ts...>{});
|
|
}
|
|
|
|
// Debug inserter for ReshardingRquirementKind: prints the enumerator's
// numeric value.
[[maybe_unused]] static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
  const int numericValue = static_cast<int>(v);
  stream << numericValue;
  return stream;
}
|
|
|
|
#endif // LLVM_DEBUG
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utilities
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// This method retrieves all potential sharding attributes, prioritizing
|
|
// specific shardings. For example, mustShardings = [shard0, None] and
|
|
// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
|
|
// [shard0, None]]
|
|
static SmallVector<std::vector<MeshSharding>>
|
|
getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
|
|
ArrayRef<MeshSharding> optionalShardings) {
|
|
SmallVector<std::vector<MeshSharding>> allShardingAttrs;
|
|
std::vector<MeshSharding> curShardingAttrs;
|
|
|
|
std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
|
|
if (i == mustShardings.size()) {
|
|
allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs));
|
|
return;
|
|
}
|
|
|
|
if (mustShardings[i]) {
|
|
curShardingAttrs.push_back(mustShardings[i]);
|
|
dfsCreateShardingAttrs(i + 1);
|
|
curShardingAttrs.pop_back();
|
|
return;
|
|
}
|
|
|
|
if (optionalShardings[i]) {
|
|
curShardingAttrs.push_back(optionalShardings[i]);
|
|
dfsCreateShardingAttrs(i + 1);
|
|
curShardingAttrs.pop_back();
|
|
curShardingAttrs.push_back({});
|
|
dfsCreateShardingAttrs(i + 1);
|
|
curShardingAttrs.pop_back();
|
|
return;
|
|
}
|
|
|
|
curShardingAttrs.push_back({});
|
|
dfsCreateShardingAttrs(i + 1);
|
|
curShardingAttrs.pop_back();
|
|
};
|
|
|
|
dfsCreateShardingAttrs(0);
|
|
return allShardingAttrs;
|
|
}
|
|
|
|
// The order of preference is form highest to lowest:
|
|
// 1. No resharding is required (all existing annotations are compatible).
|
|
// 2. No resharding for operands/results that have annotation specifically
|
|
// targeting this operation. This means
|
|
// * operands that are the result of `mesh.shard` ops marked with
|
|
// `annotate_for_users`.
|
|
// * results that are annotated with `mesh.shard` ops without
|
|
// `annotate_for_users`.
|
|
// 3. All other cases. Resharding is required for operands/results with
|
|
// annotation targeting explicitly this operation.
|
|
ReshardingRquirementKind getReshardingRquirementKind(
|
|
Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
|
|
ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
|
|
|
|
size_t operandsCount = op->getOperands().size();
|
|
auto operandShardings =
|
|
llvm::make_range(operandAndResultShardings.begin(),
|
|
operandAndResultShardings.begin() + operandsCount);
|
|
auto resultShardings =
|
|
llvm::make_range(operandAndResultShardings.begin() + operandsCount,
|
|
operandAndResultShardings.end());
|
|
|
|
for (auto [operand, sharding] :
|
|
llvm::zip_equal(op->getOperands(), operandShardings)) {
|
|
ShardOp shardOp = operand.getDefiningOp<ShardOp>();
|
|
if (!shardOp) {
|
|
continue;
|
|
}
|
|
bool needsResharding = sharding != shardOp.getSharding();
|
|
bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
|
|
if (needsResharding) {
|
|
if (isExplicitAnnotationForThisOp) {
|
|
// This is the worst case. No need to continue.
|
|
return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
|
|
}
|
|
res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
|
|
}
|
|
}
|
|
|
|
for (auto [result, sharding] :
|
|
llvm::zip_equal(op->getResults(), resultShardings)) {
|
|
for (auto user : result.getUsers()) {
|
|
ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
|
|
if (!shardOp) {
|
|
continue;
|
|
}
|
|
bool needsResharding = sharding != shardOp.getSharding();
|
|
bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
|
|
if (needsResharding) {
|
|
if (isExplicitAnnotationForThisOp) {
|
|
// This is the worst case. No need to continue.
|
|
return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
|
|
}
|
|
res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
|
|
}
|
|
}
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
// From all the operand and result sharding combinations,
|
|
// return the one that is most desirable.
|
|
// The order of preference is:
|
|
// 1. No resharding with respect to existing sharding annotations.
|
|
// 2. Resharding for values that have already annotations that do not target
|
|
// this op.
|
|
// 3. Resharding of existing explicit sharding annotations for this op.
|
|
static FailureOr<ShardingOption> selectShardingOption(
|
|
ShardingInterface shardingOp,
|
|
ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
|
|
ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
|
|
SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
|
|
shardingOptionsAndReshardingRequirements;
|
|
|
|
for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
|
|
for (ArrayRef<MeshSharding> operandShardings :
|
|
possibleOperandShardingAttrs) {
|
|
FailureOr<ShardingOption> shardingOption =
|
|
shardingOp.getShardingOption(operandShardings, resultShardings);
|
|
if (failed(shardingOption) || shardingOption->empty) {
|
|
continue;
|
|
}
|
|
// These shardings may not be the same as those in operandShardings and
|
|
// resultShardings.
|
|
// They may be missing some annotations.
|
|
// Whatever is returned by getShardingAnnotations is exactly what the op
|
|
// needs.
|
|
FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
|
|
shardingOp.getShardingAnnotations(*shardingOption);
|
|
if (failed(operandAndResultShardings)) {
|
|
return failure();
|
|
}
|
|
|
|
// LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
|
|
// << *operandAndResultShardings << "\n";);
|
|
|
|
ReshardingRquirementKind reshardingRquirement =
|
|
getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
|
|
if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
|
|
// This is the best case. No need to go on.
|
|
return *shardingOption;
|
|
}
|
|
|
|
shardingOptionsAndReshardingRequirements.emplace_back(
|
|
std::move(*shardingOption), reshardingRquirement);
|
|
}
|
|
}
|
|
|
|
if (shardingOptionsAndReshardingRequirements.empty()) {
|
|
return ShardingOption::makeEmpty();
|
|
}
|
|
|
|
std::partial_sort(
|
|
shardingOptionsAndReshardingRequirements.begin(),
|
|
shardingOptionsAndReshardingRequirements.begin() + 1,
|
|
shardingOptionsAndReshardingRequirements.end(),
|
|
[](const std::tuple<ShardingOption, ReshardingRquirementKind> &a,
|
|
const std::tuple<ShardingOption, ReshardingRquirementKind> &b) {
|
|
return std::get<ReshardingRquirementKind>(a) <
|
|
std::get<ReshardingRquirementKind>(b);
|
|
});
|
|
|
|
LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
|
|
<< shardingOptionsAndReshardingRequirements << "\n";);
|
|
|
|
return std::get<ShardingOption>(
|
|
shardingOptionsAndReshardingRequirements.front());
|
|
}
|
|
|
|
// For each operation that implements the ShardingInterface, infer the sharding
|
|
// option of the operation from its operands and/or results using the
|
|
// `getShardingOption` method. If the inferred sharding option is not empty, add
|
|
// a `mesh.shard` operation for all remaining operands and results that do not
|
|
// have sharding annotations.
|
|
static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
|
|
ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
|
|
if (op->hasTrait<OpTrait::IsTerminator>() ||
|
|
(op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
|
|
llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
|
|
return success();
|
|
|
|
if (!shardingOp) {
|
|
op->emitOpError() << "sharding interface is not implemented.";
|
|
return failure();
|
|
}
|
|
|
|
// collect MeshSharding from results
|
|
std::vector<MeshSharding> allowConflictsResultShardings;
|
|
allowConflictsResultShardings.resize(op->getNumResults());
|
|
std::vector<MeshSharding> resultMustShardings;
|
|
resultMustShardings.resize(op->getNumResults());
|
|
for (OpResult result : op->getResults()) {
|
|
FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
|
|
getMeshSharding(result);
|
|
if (failed(maybeShardAttr))
|
|
continue;
|
|
if (!maybeShardAttr->first)
|
|
resultMustShardings[result.getResultNumber()] = maybeShardAttr->second;
|
|
else
|
|
allowConflictsResultShardings[result.getResultNumber()] =
|
|
maybeShardAttr->second;
|
|
}
|
|
|
|
// collect MeshSharding from operands
|
|
std::vector<MeshSharding> allowConflictsOperandShardings;
|
|
allowConflictsOperandShardings.resize(op->getNumOperands());
|
|
std::vector<MeshSharding> operandMustShardings;
|
|
operandMustShardings.resize(op->getNumOperands());
|
|
for (OpOperand &opOperand : op->getOpOperands()) {
|
|
FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
|
|
getMeshSharding(opOperand);
|
|
if (failed(maybeShardAttr))
|
|
continue;
|
|
|
|
if (maybeShardAttr->first)
|
|
operandMustShardings[opOperand.getOperandNumber()] =
|
|
maybeShardAttr->second;
|
|
else
|
|
allowConflictsOperandShardings[opOperand.getOperandNumber()] =
|
|
maybeShardAttr->second;
|
|
}
|
|
|
|
// try to get the sharding option
|
|
SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs =
|
|
getOrderedPossibleShardingAttrs(operandMustShardings,
|
|
allowConflictsOperandShardings);
|
|
SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs =
|
|
getOrderedPossibleShardingAttrs(resultMustShardings,
|
|
allowConflictsResultShardings);
|
|
FailureOr<ShardingOption> shardingOption = selectShardingOption(
|
|
shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
|
|
|
|
if (failed(shardingOption)) {
|
|
op->emitOpError() << "fail to get sharding option.";
|
|
return failure();
|
|
}
|
|
|
|
LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
|
|
|
|
// sharding info is empty, return immediately
|
|
if (shardingOption->empty)
|
|
return success();
|
|
|
|
if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
|
|
op->emitOpError() << "fail to set sharding annotations.";
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ShardingPropagation
|
|
//===----------------------------------------------------------------------===//
|
|
struct ShardingPropagation
|
|
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
|
|
|
|
using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
|
|
|
|
void runOnOperation() override {
|
|
FunctionOpInterface funcOp = getOperation();
|
|
MLIRContext *ctx = funcOp.getContext();
|
|
Region ®ion = funcOp.getFunctionBody();
|
|
OpBuilder builder(ctx);
|
|
if (!region.hasOneBlock()) {
|
|
funcOp.emitOpError() << "only one block is supported!";
|
|
return signalPassFailure();
|
|
}
|
|
Block &block = region.front();
|
|
|
|
LLVM_DEBUG(
|
|
DBGS() << "print all the ops' iterator types and indexing maps in the "
|
|
"block.\n";
|
|
for (Operation &op : block.getOperations()) {
|
|
if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
|
|
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
|
|
});
|
|
|
|
auto traverse = [&](auto &&range, OpBuilder &builder,
|
|
const char *order) -> bool {
|
|
for (Operation &op : range) {
|
|
if (failed(visitOp(&op, builder))) {
|
|
signalPassFailure();
|
|
return true;
|
|
}
|
|
}
|
|
LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
|
|
<< funcOp << "\n");
|
|
LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
|
|
return false;
|
|
};
|
|
|
|
// 1. Propagate in reversed order.
|
|
if (traversal == TraversalOrder::Backward ||
|
|
traversal == TraversalOrder::BackwardForward)
|
|
traverse(llvm::reverse(block), builder, "backward");
|
|
|
|
// 2. Propagate in original order.
|
|
if (traversal != TraversalOrder::Backward)
|
|
traverse(block, builder, "forward");
|
|
|
|
// 3. Propagate in backward order if needed.
|
|
if (traversal == TraversalOrder::ForwardBackward)
|
|
traverse(llvm::reverse(block), builder, "backward");
|
|
}
|
|
};
|