llvm-project/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
Longsheng Mou f047b735e9
[mlir][NFC] Use getDefiningOp<OpTy>() instead of dyn_cast<OpTy>(getDefiningOp()) (#150428)
This PR uses `val.getDefiningOp<OpTy>()` to replace `dyn_cast<OpTy>(val.getDefiningOp())`, `dyn_cast_or_null<OpTy>(val.getDefiningOp())` and `dyn_cast_if_present<OpTy>(val.getDefiningOp())`.
2025-07-25 10:35:51 +08:00

412 lines
15 KiB
C++

//===- ShardingPropagation.cpp ------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <vector>
namespace mlir {
namespace mesh {
#define GEN_PASS_DEF_SHARDINGPROPAGATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
} // namespace mesh
} // namespace mlir
#define DEBUG_TYPE "sharding-propagation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
using namespace mlir::mesh;
// Classifies how much resharding a candidate sharding option would force on
// the `mesh.shard` annotations already surrounding an operation.
// The enumerators are declared from most to least desirable;
// selectShardingOption compares them with `<`, so their order is significant.
enum class ReshardingRquirementKind {
// Every existing annotation already agrees with the candidate.
NO_RESHARDING = 0,
// Some annotation disagrees, but none that explicitly targets the operation.
NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS,
// At least one annotation that explicitly targets the operation disagrees.
RESHARDING_FOR_EXPLICIT_ANNOTATIONS
};
#ifdef LLVM_DEBUG
// Forward declarations of the debug stream operators defined further down in
// this file. NOTE(review): presumably they are declared ahead of
// printRange/printTuple so that nested values (e.g. a SmallVector of tuples,
// as streamed by selectShardingOption) resolve regardless of definition order
// -- confirm before removing.
template <typename T>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
const SmallVector<T> &vec);
template <typename... Ts>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
const std::tuple<Ts...> &t);
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
ReshardingRquirementKind v);
// Writes the elements of `range` to `stream` as "[e0, e1, ..., ]" and returns
// the stream. Every element, including the last one, is followed by ", "; an
// empty range prints as "[]".
template <typename Stream, typename Range>
static Stream &printRange(Stream &stream, Range &&range) {
  stream << "[";
  for (auto &&element : range)
    stream << element << ", ";
  return stream << "]";
}
// Streams a SmallVector in the bracketed list format produced by printRange.
template <typename T>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                     const SmallVector<T> &vec) {
  llvm::raw_ostream &result = printRange(stream, vec);
  return result;
}
// Streams a ShardingOption for debug output as
// "{empty = <empty>, mesh = <mesh>, shardingArray = <shardingArray>}".
// Every field uses the same "name = value" form so the output stays readable.
[[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                                      const ShardingOption &v) {
  return stream << "{empty = " << v.empty << ", mesh = " << v.mesh
                << ", shardingArray = " << v.shardingArray << "}";
}
// Writes the elements of `tuple` to `stream` as "{e0, e1, ..., }" and returns
// the stream. Every element, including the last one, is followed by ", ".
//
// `Is` must be the index sequence 0..sizeof...(Ts)-1; empty tuples are
// rejected at compile time. The tuple is taken by const reference so that
// printing never copies its elements and also works for element types that
// cannot be copied.
template <typename Stream, typename... Ts, size_t... Is>
static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
                          std::index_sequence<Is...>) {
  static_assert(sizeof...(Is) == sizeof...(Ts),
                "Indices must have same number of elements as tuple types!");
  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");
  stream << "{";
  // Fold over the indices to stream the elements in declaration order.
  ((stream << std::get<Is>(tuple) << ", "), ...);
  return stream << "}";
}
// Streams a std::tuple in the braced format produced by printTuple.
template <typename... Ts>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                     const std::tuple<Ts...> &t) {
  using ElementIndices = std::index_sequence_for<Ts...>;
  return printTuple(stream, t, ElementIndices{});
}
// Streams a ReshardingRquirementKind as its underlying integer value.
[[maybe_unused]] static llvm::raw_ostream &
operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
  const int numericValue = static_cast<int>(v);
  return stream << numericValue;
}
#endif // LLVM_DEBUG
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
// Enumerates every combination of sharding annotations worth trying, with the
// most specific combinations first. Position i of a combination holds
//   * mustShardings[i] when that entry is set;
//   * otherwise, when optionalShardings[i] is set, first optionalShardings[i]
//     and then an unset sharding;
//   * otherwise an unset sharding.
// For example, mustShardings = [shard0, None] and optionalShardings = [None,
// shard1] yield [[shard0, shard1], [shard0, None]].
static SmallVector<std::vector<MeshSharding>>
getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
                                ArrayRef<MeshSharding> optionalShardings) {
  SmallVector<std::vector<MeshSharding>> combinations;
  std::vector<MeshSharding> partial;

  // Depth-first enumeration: `partial` holds the choices made for the
  // positions before `position`.
  std::function<void(size_t)> extend = [&](size_t position) {
    // All positions are decided: record a copy of the finished combination.
    if (position == mustShardings.size()) {
      combinations.push_back(partial);
      return;
    }

    // Appends `choice`, enumerates the remaining positions, then backtracks.
    auto descendWith = [&](const MeshSharding &choice) {
      partial.push_back(choice);
      extend(position + 1);
      partial.pop_back();
    };

    if (mustShardings[position]) {
      descendWith(mustShardings[position]);
      return;
    }
    // The optional sharding is tried before the unset one so that the more
    // specific combination comes first in the result.
    if (optionalShardings[position])
      descendWith(optionalShardings[position]);
    descendWith(MeshSharding());
  };

  extend(0);
  return combinations;
}
// Classifies the resharding that `operandAndResultShardings` would require on
// the `mesh.shard` annotations already attached to `op`. The vector holds one
// sharding per operand followed by one sharding per result.
//
// The order of preference is from highest to lowest:
// 1. No resharding is required (all existing annotations are compatible).
// 2. No resharding for operands/results that have annotation specifically
//    targeting this operation. This means
//    * operands that are the result of `mesh.shard` ops marked with
//      `annotate_for_users`.
//    * results that are annotated with `mesh.shard` ops without
//      `annotate_for_users`.
// 3. All other cases. Resharding is required for operands/results with
//    annotation targeting explicitly this operation.
ReshardingRquirementKind getReshardingRquirementKind(
    Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
  using Kind = ReshardingRquirementKind;
  Kind worstSeen = Kind::NO_RESHARDING;

  // Records one mismatching annotation. Returns true when the annotation
  // explicitly targets `op`, i.e. the worst category has been reached and the
  // scan can stop.
  auto recordMismatch = [&worstSeen](bool targetsThisOp) {
    worstSeen = targetsThisOp ? Kind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS
                              : Kind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
    return targetsThisOp;
  };

  // Operand shardings come first, result shardings follow.
  const auto resultShardingsBegin =
      operandAndResultShardings.begin() + op->getOperands().size();
  auto operandShardings = llvm::make_range(operandAndResultShardings.begin(),
                                           resultShardingsBegin);
  auto resultShardings =
      llvm::make_range(resultShardingsBegin, operandAndResultShardings.end());

  // Operands: an annotation targets this op when the defining `mesh.shard` is
  // marked `annotate_for_users`.
  for (auto [operand, wanted] :
       llvm::zip_equal(op->getOperands(), operandShardings)) {
    ShardOp annotation = operand.getDefiningOp<ShardOp>();
    if (annotation && wanted != annotation.getSharding() &&
        recordMismatch(annotation.getAnnotateForUsers()))
      return worstSeen;
  }

  // Results: an annotation targets this op when the using `mesh.shard` is NOT
  // marked `annotate_for_users`.
  for (auto [result, wanted] :
       llvm::zip_equal(op->getResults(), resultShardings)) {
    for (Operation *user : result.getUsers()) {
      ShardOp annotation = llvm::dyn_cast<ShardOp>(user);
      if (annotation && wanted != annotation.getSharding() &&
          recordMismatch(!annotation.getAnnotateForUsers()))
        return worstSeen;
    }
  }

  return worstSeen;
}
// From all the operand and result sharding combinations,
// return the one that is most desirable.
// The order of preference is:
// 1. No resharding with respect to existing sharding annotations.
// 2. Resharding for values that have already annotations that do not target
// this op.
// 3. Resharding of existing explicit sharding annotations for this op.
//
// Combinations are tried result-major, in the order they are given. Among
// candidates of equal rank the first one tried is returned. Returns an empty
// ShardingOption when no combination yields a usable option, and failure when
// the op cannot produce the annotations for an option it proposed.
static FailureOr<ShardingOption> selectShardingOption(
    ShardingInterface shardingOp,
    ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
    ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
  using Candidate = std::tuple<ShardingOption, ReshardingRquirementKind>;
  SmallVector<Candidate> shardingOptionsAndReshardingRequirements;

  for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
    for (ArrayRef<MeshSharding> operandShardings :
         possibleOperandShardingAttrs) {
      FailureOr<ShardingOption> shardingOption =
          shardingOp.getShardingOption(operandShardings, resultShardings);
      if (failed(shardingOption) || shardingOption->empty) {
        continue;
      }
      // These shardings may not be the same as those in operandShardings and
      // resultShardings.
      // They may be missing some annotations.
      // Whatever is returned by getShardingAnnotations is exactly what the op
      // needs.
      FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
          shardingOp.getShardingAnnotations(*shardingOption);
      if (failed(operandAndResultShardings)) {
        return failure();
      }
      // LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
      //                   << *operandAndResultShardings << "\n";);
      ReshardingRquirementKind reshardingRquirement =
          getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
      if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
        // This is the best case. No need to go on.
        return *shardingOption;
      }
      shardingOptionsAndReshardingRequirements.emplace_back(
          std::move(*shardingOption), reshardingRquirement);
    }
  }

  if (shardingOptionsAndReshardingRequirements.empty()) {
    return ShardingOption::makeEmpty();
  }

  LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
                    << shardingOptionsAndReshardingRequirements << "\n";);

  // Only the best candidate is needed. std::min_element returns the first of
  // several equally ranked candidates, so ties are resolved in favour of the
  // combination that was tried first; a partial sort gives no such guarantee.
  const auto best = std::min_element(
      shardingOptionsAndReshardingRequirements.begin(),
      shardingOptionsAndReshardingRequirements.end(),
      [](const Candidate &a, const Candidate &b) {
        return std::get<ReshardingRquirementKind>(a) <
               std::get<ReshardingRquirementKind>(b);
      });
  return std::get<ShardingOption>(*best);
}
// Propagates sharding through a single operation.
//
// Terminators, constant-like ops without a ShardingInterface, and the mesh
// ops `mesh.shard`, `mesh.sharding` and `mesh.get_sharding` are skipped. Any
// other op must implement ShardingInterface, otherwise an error is emitted.
// The shardings already attached to the op's operands and results are
// collected, the preferred sharding option is selected from them and, unless
// that option is empty, the op is asked to add the matching sharding
// annotations.
static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
  auto shardingOp = llvm::dyn_cast<ShardingInterface>(op);

  const bool isConstantWithoutInterface =
      op->hasTrait<OpTrait::ConstantLike>() && !shardingOp;
  const bool isMeshAnnotationOp =
      llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op);
  if (op->hasTrait<OpTrait::IsTerminator>() || isConstantWithoutInterface ||
      isMeshAnnotationOp)
    return success();

  if (!shardingOp)
    return op->emitOpError() << "sharding interface is not implemented.";

  // Existing shardings of the results, indexed by result number. Entries
  // without an annotation stay unset. NOTE(review): the boolean returned by
  // getMeshSharding presumably mirrors `annotate_for_users` -- confirm.
  std::vector<MeshSharding> resultMustShardings(op->getNumResults());
  std::vector<MeshSharding> allowConflictsResultShardings(op->getNumResults());
  for (OpResult result : op->getResults()) {
    FailureOr<std::pair<bool, MeshSharding>> found = getMeshSharding(result);
    if (failed(found))
      continue;
    // For results a cleared flag makes the sharding mandatory.
    std::vector<MeshSharding> &bucket =
        found->first ? allowConflictsResultShardings : resultMustShardings;
    bucket[result.getResultNumber()] = found->second;
  }

  // Existing shardings of the operands, indexed by operand number.
  std::vector<MeshSharding> operandMustShardings(op->getNumOperands());
  std::vector<MeshSharding> allowConflictsOperandShardings(
      op->getNumOperands());
  for (OpOperand &opOperand : op->getOpOperands()) {
    FailureOr<std::pair<bool, MeshSharding>> found =
        getMeshSharding(opOperand);
    if (failed(found))
      continue;
    // For operands a set flag makes the sharding mandatory.
    std::vector<MeshSharding> &bucket =
        found->first ? operandMustShardings : allowConflictsOperandShardings;
    bucket[opOperand.getOperandNumber()] = found->second;
  }

  // Enumerate the candidate combinations and pick the preferred option.
  SmallVector<std::vector<MeshSharding>> operandCandidates =
      getOrderedPossibleShardingAttrs(operandMustShardings,
                                      allowConflictsOperandShardings);
  SmallVector<std::vector<MeshSharding>> resultCandidates =
      getOrderedPossibleShardingAttrs(resultMustShardings,
                                      allowConflictsResultShardings);
  FailureOr<ShardingOption> shardingOption =
      selectShardingOption(shardingOp, operandCandidates, resultCandidates);
  if (failed(shardingOption))
    return op->emitOpError() << "fail to get sharding option.";

  LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");

  // Nothing to annotate when no sharding could be inferred.
  if (shardingOption->empty)
    return success();

  if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption)))
    return op->emitOpError() << "fail to set sharding annotations.";
  return success();
}
//===----------------------------------------------------------------------===//
// ShardingPropagation
//===----------------------------------------------------------------------===//
// Pass that propagates sharding annotations through the single block of a
// function by visiting its operations in the order(s) selected by the
// `traversal` option.
struct ShardingPropagation
    : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;

  // Runs the propagation on the current function. The pass fails if the
  // function body does not consist of exactly one block or if visiting any
  // operation fails; in the latter case the remaining traversals are skipped.
  void runOnOperation() override {
    FunctionOpInterface funcOp = getOperation();
    MLIRContext *ctx = funcOp.getContext();
    Region &region = funcOp.getFunctionBody();
    OpBuilder builder(ctx);
    if (!region.hasOneBlock()) {
      funcOp.emitOpError() << "only one block is supported!";
      return signalPassFailure();
    }
    Block &block = region.front();

    LLVM_DEBUG(
        DBGS() << "print all the ops' iterator types and indexing maps in the "
                  "block.\n";
        for (Operation &op : block.getOperations()) {
          if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
            shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
        });

    // Visits every operation of `range` in order. Returns true, after
    // signalling pass failure, as soon as one operation cannot be processed;
    // returns false when the whole range was visited successfully.
    auto traverse = [&](auto &&range, OpBuilder &builder,
                        const char *order) -> bool {
      for (Operation &op : range) {
        if (failed(visitOp(&op, builder))) {
          signalPassFailure();
          return true;
        }
      }
      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
                        << funcOp << "\n");
      LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
      return false;
    };

    const bool startsBackward =
        traversal == TraversalOrder::Backward ||
        traversal == TraversalOrder::BackwardForward;

    // Once a traversal has failed the pass is already marked as failed, so
    // the remaining traversals are skipped instead of running on a function
    // whose propagation was left incomplete.

    // 1. Propagate in reversed order.
    if (startsBackward && traverse(llvm::reverse(block), builder, "backward"))
      return;

    // 2. Propagate in original order.
    if (traversal != TraversalOrder::Backward &&
        traverse(block, builder, "forward"))
      return;

    // 3. Propagate in backward order if needed.
    if (traversal == TraversalOrder::ForwardBackward)
      (void)traverse(llvm::reverse(block), builder, "backward");
  }
};