//===- ShardingInterfaceImpl.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>
#include <optional>

namespace mlir::linalg {

using GridAxis = shard::GridAxis;
using ReductionKind = shard::ReductionKind;
using Sharding = shard::Sharding;
using ShardingArray = shard::ShardingArray;
using GridOp = shard::GridOp;

// Returns the corresponding grid reduction kind for the given arith op.
static ReductionKind getReductionKind(Operation *op) {
  return llvm::TypeSwitch<Operation *, ReductionKind>(op)
      // Floating-point operations.
      .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
      .Case([](arith::MulFOp op) { return ReductionKind::Product; })
      // TODO: handle maxnumf and minnumf.
      .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
      .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
      // Integer operations.
      .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
      .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
      .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
      .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
      // TODO: handle signless, signed and unsigned types properly.
      // It is assumed that the element type of the collective operands and
      // result drives the meaning of the reduction kind, i.e. whether it is
      // signed or unsigned.
      // The reduction op inside the linalg op may have a result type that
      // differs from the element type of the linalg op's result.
      // Also, signed and unsigned Arith dialect ops may accept signed,
      // unsigned or signless operands.
      // Maybe expand the reduction kinds.
      .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
      .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
      .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
      .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
      .Case([](arith::MulIOp op) { return ReductionKind::Product; })
      .Default([](Operation *op) { return ReductionKind::Generic; });
}
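
// Returns the sole combiner op of the reduction inside the linalg op's
// region, if there is exactly one (e.g. the arith.addf op of a sum
// reduction); std::nullopt otherwise.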
static std::optional<Operation *> getCombinerOp(LinalgOp op) {
  SmallVector<Operation *> combinerOps;
  Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
  if (!reducedValue || combinerOps.size() != 1) {
    return std::nullopt;
  }
  return combinerOps[0];
}

// Returns the reduction kind implied by the combiner op of the given linalg
// op, or ReductionKind::Generic if no single combiner can be identified.
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
  std::optional<Operation *> reductionOp = getCombinerOp(op);
  if (!reductionOp) {
    return ReductionKind::Generic;
  }
  [[maybe_unused]] Type resultElementType =
      llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
  // TODO: handle the case when the result type of the reduction op does not
  // match the element type of the result tensor.
  // Would it make sense at all?
  assert(resultElementType == reductionOp.value()->getResult(0).getType());
  return getReductionKind(reductionOp.value());
}

// Returns the grid op referenced by the first non-null operand or result
// sharding. All shardings of an op are expected to reference the same grid.
static GridOp getGrid(Operation *op, ArrayRef<Sharding> operandShardings,
                      ArrayRef<Sharding> resultShardings,
                      SymbolTableCollection &symbolTable) {
  for (const Sharding &sharding : operandShardings) {
    if (sharding) {
      return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
    }
  }

  for (const Sharding &sharding : resultShardings) {
    if (sharding) {
      return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
    }
  }

  assert(false && "expected at least one non-null sharding");
  return nullptr;
}

// Choose the init operand based on the process's linear index along the
// reduction grid axes.
// The initial value must enter the reduction exactly once, so in each
// process group only the lead process (linear index 0) uses the original
// operand; all other processes use a tensor filled with the neutral element
// of the reduction operation.
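// Schematically, for a sum reduction the generated IR resembles the
// following (illustrative, not verbatim output):
//   %i       = <linear process index along the reduction grid axes>
//   %is_lead = arith.cmpi eq, %i, %c0 : index
//   %init    = scf.if %is_lead -> (tensor<?x?xf32>) {
//     scf.yield %partitioned_operand : tensor<?x?xf32>
//   } else {
//     %empty   = tensor.empty(...) : tensor<?x?xf32>
//     %zero    = arith.constant 0.0 : f32  // neutral element of arith.addf
//     %neutral = linalg.fill ins(%zero : f32) outs(%empty : tensor<?x?xf32>)
//     scf.yield %neutral : tensor<?x?xf32>
//   }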
static Value createDestinationPassingStyleInitOperand(
    LinalgOp op, int operandNumber, Value partitionedOperand,
    ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
    ImplicitLocOpBuilder &builder) {
  Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
      gridOp.getSymName(), reductionGridAxes, builder);
  Value zero = arith::ConstantIndexOp::create(builder, 0);
  Value isLeadProcess = arith::CmpIOp::create(
      builder, builder.getI1Type(), arith::CmpIPredicate::eq,
      processLinearIndexInReductionGroup, zero);
  scf::IfOp ifOp = scf::IfOp::create(builder, partitionedOperand.getType(),
                                     isLeadProcess, true, true);
  // Then block.
  {
    OpBuilder::InsertionGuard insertionGuard(builder);
    builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
    scf::YieldOp::create(builder, partitionedOperand);
  }

  // Else block.
  {
    OpBuilder::InsertionGuard insertionGuard(builder);
    builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
    SmallVector<OpFoldResult> shape =
        tensor::getMixedSizes(builder, builder.getLoc(), partitionedOperand);
    SmallVector<Operation *> combinerOps;
    matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
    assert(combinerOps.size() == 1);
    std::optional<TypedAttr> neutralEl =
        arith::getNeutralElement(combinerOps[0]);
    Value init = tensor::EmptyOp::create(builder, op.getLoc(), shape,
                                         neutralEl.value().getType());
    Value constant =
        arith::ConstantOp::create(builder, op.getLoc(), neutralEl.value());
    Value fill = linalg::FillOp::create(builder, op.getLoc(), constant, init)
                     .getResult(0);
    scf::YieldOp::create(builder, fill);
  }

  return ifOp.getResult(0);
}

// Create the DPS init operands for the partitioned Linalg op.
// Return all the new partitioned operands.
static SmallVector<Value> createDestinationPassingStyleInitOperands(
    LinalgOp op, GridOp gridOp, ArrayRef<Value> partitionedOperands,
    ArrayRef<GridAxis> reductionGridAxes, IRMapping &partitionMap,
    ImplicitLocOpBuilder &builder) {
  // TODO: add support for multiple destination passing style initial value
  // operands.
  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
  SmallVector<Value> newOperands = llvm::to_vector(partitionedOperands);
  auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
  Value partitionedInitOperand =
      partitionMap.lookup(op->getOperands()[operandIdx]);
  newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
      op, 0, partitionedInitOperand, reductionGridAxes, gridOp, builder);
  return newOperands;
}
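
// For each result of the unsharded op, create an all-reduce of the
// partitioned (partial) result over the reduction grid axes and remap the
// result to the reduced value. Schematically (illustrative syntax, assuming
// a sum reduction over grid axis 0):
//   %reduced = shard.all_reduce %partial on @grid grid_axes = [0]
//       reduction = sum : tensor<2x2xf32> -> tensor<2x2xf32>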
static void createAllReduceForResultsWithoutPartialShardings(
    LinalgOp unshardedOp, ArrayRef<GridAxis> opReductionGridAxes,
    ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
    ImplicitLocOpBuilder &builder) {
  ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
  for (auto [unshardedLinalgOpResult, resultSharding] :
       llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
    Value partitionedLinalgOpResult =
        partitionMap.lookup(unshardedLinalgOpResult);
    Value reducedValue = shard::AllReduceOp::create(
        builder, partitionedLinalgOpResult, resultSharding.getGrid(),
        opReductionGridAxes, reductionKind);
    partitionMap.map(unshardedLinalgOpResult, reducedValue);
  }
}
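
// Partition a linalg op that has at least one sharded reduction loop
// iterator. In outline:
//   1. Rewrite the DPS init operand so that only the lead process of each
//      reduction group contributes the original initial value; all others
//      contribute the combiner's neutral element.
//   2. Partition the op itself onto the local shards.
//   3. All-reduce the partial results over the reduction grid axes.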
static void partitionLinalgOpWithShardedReduction(
    LinalgOp op, ArrayRef<Value> partitionedOperands,
    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators,
    IRMapping &partitionMap, SymbolTableCollection &symbolTable,
    ImplicitLocOpBuilder &builder) {
  GridOp grid = getGrid(op, operandShardings, resultShardings, symbolTable);
  SmallVector<GridAxis> reductionGridAxes = shard::getReductionGridAxes(
      loopIteratorTypes, gridAxisAssignmentForLoopIterators);
  SmallVector<Value> partitionedLinalgOpOperands =
      createDestinationPassingStyleInitOperands(op, grid, partitionedOperands,
                                                reductionGridAxes, partitionMap,
                                                builder);
  // We must not change the operand mappings of the original partitionMap, as
  // they are the mappings for the whole partitioned program and may be used
  // by others.
  IRMapping internalPartitionMap;
  for (auto [unshardedOperand, partitionedOperand] :
       llvm::zip_equal(op->getOperands(), partitionedLinalgOpOperands)) {
    internalPartitionMap.map(unshardedOperand, partitionedOperand);
  }
  partitionTriviallyShardableOperation(
      *op, partitionedLinalgOpOperands, operandShardings, resultShardings,
      internalPartitionMap, symbolTable, builder);
  for (Value result : op->getResults()) {
    partitionMap.map(result, internalPartitionMap.lookup(result));
  }

  // Handle the partial results: all-reduce them over the reduction grid axes.
  createAllReduceForResultsWithoutPartialShardings(
      op, reductionGridAxes, resultShardings, partitionMap, builder);
}

namespace {

// ShardingInterface for ops that implement LinalgStructuredInterface.
// Only ops whose indexing maps are projected permutations are supported.
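// For example, all of linalg.matmul's indexing maps are projected
// permutations, whereas a map such as (d0, d1) -> (d0 + d1) is not.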
template <typename Op>
struct StructuredOpShardingInterface
    : public shard::ShardingInterface::ExternalModel<
          StructuredOpShardingInterface<Op>, Op> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
    SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
    // Results must have the same indexing as destination passing style initial
    // operands.
    for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
      res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
    }
    return res;
  }
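
  // For instance, linalg.matmul starts out with the three maps
  //   (d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d2, d1),
  //   (d0, d1, d2) -> (d0, d1),
  // and the method above appends the init operand's map
  // (d0, d1, d2) -> (d0, d1) once more for the result.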

  SmallVector<ReductionKind>
  getReductionLoopIteratorKinds(Operation *op) const {
    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes =
        linalgOp.getIteratorTypesArray();
    unsigned reductionItersCount = std::accumulate(
        iteratorTypes.begin(), iteratorTypes.end(), 0,
        [](unsigned count, utils::IteratorType iter) {
          return count + (iter == utils::IteratorType::reduction);
        });
    shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
    return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
  }
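
  // For linalg.matmul, whose iterator types are [parallel, parallel,
  // reduction] and whose combiner is arith.addf, the method above returns a
  // single ReductionKind::Sum entry.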

  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
                          ArrayRef<Sharding> operandShardings,
                          ArrayRef<Sharding> resultShardings,
                          IRMapping &partitionMap,
                          SymbolTableCollection &symbolTable,
                          OpBuilder &builder) const {
    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);

    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
    bool allIndexingMapsAreProjectedPermutation =
        llvm::all_of(indexingMaps, [](AffineMap map) {
          return map.isProjectedPermutation();
        });
    if (!allIndexingMapsAreProjectedPermutation) {
      // TODO: handle non-projected permutations.
      return op->emitOpError()
             << "only supports indexing maps that are projected permutations";
    }

    SmallVector<utils::IteratorType> loopIteratorTypes =
        linalgOp.getIteratorTypesArray();
    ShardingArray gridAxisAssignmentForLoopIterators =
        getGridAxisAssignmentForLoopIterators(operandShardings,
                                              resultShardings,
                                              loopIteratorTypes, indexingMaps);
    if (shard::isAtLeastOneReductionIteratorSharded(
            loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
      ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
      partitionLinalgOpWithShardedReduction(
          linalgOp, partitionedOperands, operandShardings, resultShardings,
          loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap,
          symbolTable, implicitLocBuilder);
    } else {
      partitionTriviallyShardableOperation(*op, partitionedOperands,
                                           operandShardings, resultShardings,
                                           partitionMap, symbolTable, builder);
    }

    return success();
  }
};
} // namespace

/// Registers the StructuredOpShardingInterface model for the given op type.
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
}

/// Variadic helper that registers the model for each listed op type.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}
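
// Typical client usage (illustrative):
//   DialectRegistry registry;
//   linalg::registerShardingInterfaceExternalModels(registry);
//   MLIRContext context(registry);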
void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
    // The partitioned code may use ops from these dialects, so make sure
    // they are registered and loaded.
    DialectRegistry registry;
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    scf::SCFDialect, tensor::TensorDialect>();
    ctx->appendDialectRegistry(registry);
    for (StringRef name : registry.getDialectNames())
      ctx->getOrLoadDialect(name);

    registerOne<linalg::GenericOp>(ctx);
    registerAll<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}
} // namespace mlir::linalg