//===- ShardingInterface.cpp -------------------------------------*- C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" #include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/IRMapping.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "sharding-interface" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; using namespace mlir::shard; #include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc" //===----------------------------------------------------------------------===// // common util functions //===----------------------------------------------------------------------===// static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl &seenIds) { switch (expr.getKind()) { case AffineExprKind::Add: { auto binOpExpr = cast(expr); AffineExpr lhs = binOpExpr.getLHS(); AffineExpr rhs = binOpExpr.getRHS(); if (failed(checkOperandAffineExprRecursively(lhs, seenIds))) return failure(); if (failed(checkOperandAffineExprRecursively(rhs, seenIds))) return failure(); return success(); } case AffineExprKind::Mul: { auto binOpExpr = cast(expr); AffineExpr lhs = binOpExpr.getLHS(); AffineExpr rhs = binOpExpr.getRHS(); AffineExpr dimExpr; if (lhs.getKind() == AffineExprKind::DimId && rhs.getKind() == AffineExprKind::Constant) { dimExpr = lhs; } else if (rhs.getKind() == AffineExprKind::DimId && lhs.getKind() == AffineExprKind::Constant) { dimExpr = rhs; } else { return failure(); } unsigned position = cast(dimExpr).getPosition(); if ((size_t)position >= seenIds.size() || seenIds[position]) return failure(); seenIds[position] = true; return success(); } case AffineExprKind::DimId: { unsigned position = cast(expr).getPosition(); if ((size_t)position >= seenIds.size() || seenIds[position]) return failure(); seenIds[position] = true; return success(); } default: return failure(); } } static FailureOr> checkOperandAffineExpr(AffineExpr expr, unsigned numDims) { SmallVector seenIds(numDims, false); if (failed(checkOperandAffineExprRecursively(expr, seenIds))) return failure(); llvm::SmallSet positions; for (auto it : llvm::enumerate(seenIds)) { if (it.value()) positions.insert((unsigned)it.index()); } return positions; } template SmallVector fromArrayOfVector(MLIRContext *ctxt, const SmallVector> &vec) { SmallVector res; for (const auto &v : vec) { res.emplace_back(GridAxesAttr::get(ctxt, v)); } return res; } //===----------------------------------------------------------------------===// // shard::getSharding //===----------------------------------------------------------------------===// FailureOr> shard::getSharding(OpResult result) { Value val = cast(result); bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) { auto shardOp = llvm::dyn_cast(user); if (!shardOp) return false; return !shardOp.getAnnotateForUsers(); }); if (anyShardedForDef) { // expected to have exact one use if it has a use of `shard.shard` without // unit attr annotate_for_users if (!val.hasOneUse()) return failure(); auto shardOp = llvm::cast(*val.getUsers().begin()); return std::make_pair(false, Sharding(shardOp.getSharding())); } bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) { auto shardOp = llvm::dyn_cast(user); if (!shardOp) return false; return shardOp.getAnnotateForUsers(); }); if (anyShardedForUsers) { SmallVector shardOps; for (Operation *user : val.getUsers()) { ShardOp shardOp = llvm::dyn_cast(user); if (shardOp) shardOps.push_back(shardOp); } Sharding shardForDef = shardOps[0].getSharding(); for (size_t i = 1; i < shardOps.size(); ++i) { // TODO: Deduce a reasonable grid sharding attr for def when they are // different assert(shardForDef == shardOps[i].getSharding() && "only support all shard ops have the same grid sharding attr"); } return std::make_pair(true, shardForDef); } return failure(); } FailureOr> shard::getSharding(OpOperand &opOperand) { Value val = opOperand.get(); if (ShardOp shardOp = val.getDefiningOp()) return std::make_pair(shardOp.getAnnotateForUsers(), Sharding(shardOp.getSharding())); return failure(); } //===----------------------------------------------------------------------===// // ShardingInterface::verifyShardingInterfaceImpl //===----------------------------------------------------------------------===// LogicalResult shard::ShardingInterface::verifyShardingInterfaceImpl() { Operation *op = getOperation(); // check operands and results type for (Type type : op->getOperandTypes()) if (!llvm::isa(type) && !type.isIntOrIndexOrFloat()) return failure(); for (Type type : op->getResultTypes()) if (!llvm::isa(type) && !type.isIntOrIndexOrFloat()) return failure(); // check maps SmallVector maps = getIndexingMaps(); if (maps.empty()) return failure(); unsigned numOperands = op->getNumOperands(); unsigned numResults = op->getNumResults(); if (numOperands + numResults != maps.size()) return failure(); for (OpResult result : op->getResults()) { auto resultType = dyn_cast(result.getType()); if (!resultType) return failure(); AffineMap map = maps[numOperands + result.getResultNumber()]; if (!map.isProjectedPermutation()) { return failure(); } } return success(); } //===----------------------------------------------------------------------===// // ShardingInterface::printLoopTypesAndIndexingMaps //===----------------------------------------------------------------------===// void shard::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { os << "print loop types and indexing maps for: \n"; getOperation()->print(os); os << "\n"; os << "loop types: ["; for (utils::IteratorType type : getLoopIteratorTypes()) { os << stringifyEnum(type) << " "; } os << "]\n"; os << "indexing maps: \n"; for (AffineMap map : getIndexingMaps()) os << map << "\n"; os << "\n"; } //===----------------------------------------------------------------------===// // detail::defaultGetShardingOption //===----------------------------------------------------------------------===// namespace { // Update the given `shardingOption` according to `gridAxes` and `loopIdx` static LogicalResult fillShardingOption(Operation *op, ShardingOption &shardingOption, FlatSymbolRefAttr grid, ArrayRef gridAxes, unsigned loopIdx) { if ((shardingOption.grid && grid && shardingOption.grid != grid) || (!shardingOption.shardingArray[loopIdx].empty() && shardingOption.shardingArray[loopIdx] != gridAxes)) { LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator " << loopIdx << "\n"); return failure(); } for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) { if (i == loopIdx) continue; for (GridAxis axis : gridAxes) { if (llvm::is_contained(shardingOption.shardingArray[i], axis)) { LLVM_DEBUG(DBGS() << "sharding option conflicts because grid axes " << axis << " duplicate"); return failure(); } } } if (grid) shardingOption.grid = grid; if (shardingOption.shardingArray[loopIdx].empty()) shardingOption.shardingArray[loopIdx].append(gridAxes.begin(), gridAxes.end()); return success(); } } // namespace FailureOr shard::detail::defaultGetShardingOption(Operation *op, ArrayRef operandShardings, ArrayRef resultShardings) { ShardingInterface shardingOp = llvm::cast(op); ShardingOption shardingOption; if (failed(shardingOp.verifyShardingInterfaceImpl())) return op->emitOpError() << "invalid sharding interface implementation"; SmallVector loopTypes = shardingOp.getLoopIteratorTypes(); SmallVector maps = shardingOp.getIndexingMaps(); unsigned numOperands = op->getNumOperands(); shardingOption.shardingArray.resize(loopTypes.size()); llvm::SmallSet visitedLoopIndices; bool anyShardingInResultsOrOperands = false; // 1. Fill sharding option based on op results for (auto shardingIt : llvm::enumerate(resultShardings)) { Sharding shardAttr = shardingIt.value(); if (!shardAttr) continue; AffineMap map = maps[numOperands + shardingIt.index()]; anyShardingInResultsOrOperands = true; if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) { shardingOption.grid = shardAttr.getGridAttr(); } else { // Handle the split axes: calculate the corresponding loop index for each // split axes sub-array, and then store the sub-array to // shardingOption[index] for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { AffineExpr expr = std::get<0>(it); ArrayRef axes = std::get<1>(it).asArrayRef(); auto dim = cast(expr); unsigned index = dim.getPosition(); visitedLoopIndices.insert(index); if (failed(fillShardingOption(op, shardingOption, shardAttr.getGridAttr(), axes, index))) return failure(); } } } // 2. Fill sharding option based on operands for (auto shardingIt : llvm::enumerate(operandShardings)) { Sharding shardAttr = shardingIt.value(); if (!shardAttr) continue; anyShardingInResultsOrOperands = !shardAttr.getSplitAxes().empty(); AffineMap map = maps[shardingIt.index()]; unsigned numDims = map.getNumDims(); // Handle the split axes. // // TODO: Change to process the operands with single loop index first and // then the operands with multiple loop indices. for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { AffineExpr expr = std::get<0>(it); ArrayRef axes = std::get<1>(it).asArrayRef(); FailureOr> loopIndices = checkOperandAffineExpr(expr, numDims); if (failed(loopIndices)) return op->emitOpError() << "operand's affine expression is restricted to const_i * " "dim_i + const_j + dim_j + ..."; if (loopIndices->empty()) continue; if (loopIndices->size() == 1) { unsigned loopIdx = *loopIndices->begin(); visitedLoopIndices.insert(loopIdx); if (failed(fillShardingOption(op, shardingOption, shardAttr.getGridAttr(), axes, loopIdx))) return failure(); } // If multiple loop indices correspond to a dimension of an operand, it is // difficult to infer which loop indices are responsible for sharding. // Therefore, the exact loop index must be specified by others. if (loopIndices->size() > 1) { bool seenLoopIndices = false; for (unsigned loopIdx : *loopIndices) { if (visitedLoopIndices.contains(loopIdx)) { seenLoopIndices = true; break; } } if (!seenLoopIndices) return op->emitOpError() << "the operand " << shardingIt.index() << " has multiple loop indices in a dimension, but none of " "them could be found in the exactly specified annotation " "of op results or operands."; } } } // 3. Finalize sharding option removeTrailingEmptySubArray(shardingOption.shardingArray); if (!anyShardingInResultsOrOperands) shardingOption.empty = true; return shardingOption; } // Get the sharding attributed for the given result and sharding option. static Sharding getSharding(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef loopTypes) { auto resultType = cast(result.getType()); SmallVector> splitAxes(resultType.getRank()); // process the split axes for (auto it : llvm::enumerate(map.getResults())) { AffineExpr expr = it.value(); // `expr` must be an `AffineDimExpr` because `map` is verified by // isProjectedPermutation auto dim = cast(expr); unsigned loopIdx = dim.getPosition(); if (loopIdx < shardingOption.shardingArray.size()) splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]); } removeTrailingEmptySubArray(splitAxes); return Sharding::get(shardingOption.grid, fromArrayOfVector(result.getContext(), splitAxes)); } static FailureOr getSharding(OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map) { Value operandValue = opOperand.get(); auto operandType = dyn_cast(operandValue.getType()); if (!operandType) { if (operandValue.getType().isIntOrIndexOrFloat()) return Sharding(); return failure(); } // 0d tensors cannot be sharded and must get replicated if (operandType.getRank() == 0) { return Sharding(shardingOption.grid); } SmallVector> splitAxes(operandType.getRank()); unsigned numDims = map.getNumDims(); for (auto it : llvm::enumerate(map.getResults())) { int64_t idx = it.index(); AffineExpr expr = it.value(); FailureOr> loopIndices = checkOperandAffineExpr(expr, numDims); if (failed(loopIndices)) return failure(); SmallVector shardedLoopIndices; for (unsigned loopIdx : *loopIndices) { if ((size_t)loopIdx < shardingOption.shardingArray.size() && !shardingOption.shardingArray[loopIdx].empty()) shardedLoopIndices.push_back(loopIdx); } // mostly one sharded loop index is accepted if (shardedLoopIndices.size() > 1) return failure(); if (shardedLoopIndices.size() == 1) { splitAxes[idx].append( shardingOption.shardingArray[shardedLoopIndices[0]]); } } removeTrailingEmptySubArray(splitAxes); return Sharding::get( shardingOption.grid, fromArrayOfVector(opOperand.get().getContext(), splitAxes)); } FailureOr> shard::detail::defaultGetShardingAnnotations( Operation *op, const ShardingOption &shardingOption) { std::vector res; ShardingInterface shardingOp = llvm::cast(op); SmallVector loopTypes = shardingOp.getLoopIteratorTypes(); SmallVector maps = shardingOp.getIndexingMaps(); unsigned numOperands = op->getNumOperands(); for (OpOperand &opOperand : op->getOpOperands()) { FailureOr shardingAttr = ::getSharding( opOperand, shardingOption, maps[opOperand.getOperandNumber()]); if (failed(shardingAttr)) return failure(); res.push_back(*shardingAttr); } for (OpResult result : op->getResults()) { res.push_back(::getSharding(result, shardingOption, maps[numOperands + result.getResultNumber()], loopTypes)); } return res; } //===----------------------------------------------------------------------===// // detail::defaultAddShardingAnnotations //===----------------------------------------------------------------------===// // To add a `shard.shard` op for the given result, based on the details provided // in `shardingOption`, `map`, and `loopTypes`. static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef loopTypes) { Sharding sharding = getSharding(result, shardingOption, map, loopTypes); maybeInsertTargetShardingAnnotation(sharding, result, b); return success(); } // To add a `shard.shard` op for the given operand, based on the details // provided in `shardingOption`, `map`, and `loopTypes`. static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map) { FailureOr sharding = getSharding(opOperand, shardingOption, map); if (failed(sharding)) { return failure(); } OpBuilder::InsertionGuard guard(b); maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b); return success(); } LogicalResult shard::detail::defaultAddShardingAnnotations( Operation *op, OpBuilder &b, const ShardingOption &shardingOption) { assert(!shardingOption.empty && shardingOption.grid); ShardingInterface shardingOp = llvm::cast(op); SmallVector loopTypes = shardingOp.getLoopIteratorTypes(); SmallVector maps = shardingOp.getIndexingMaps(); unsigned numOperands = op->getNumOperands(); // 1. add shard.shard ops for all op results for (OpResult result : op->getResults()) { if (failed(addShardOp(b, result, shardingOption, maps[numOperands + result.getResultNumber()], loopTypes))) return failure(); } // 2. add shard.shard ops for all operands for (OpOperand &opOperand : op->getOpOperands()) { if (failed(addShardOp(b, opOperand, shardingOption, maps[opOperand.getOperandNumber()]))) return failure(); } return success(); } #ifndef NDEBUG static bool isValueCompatibleWithFullReplicationSharding(Value value, Sharding sharding) { if (isa(value.getType())) { return isFullReplication(sharding); } return !sharding; } template static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, ShardingRage &&shardings) { if (std::size(values) != std::size(shardings)) { return false; } return llvm::all_of(llvm::zip_equal(std::forward(values), std::forward(shardings)), [](auto valueAndSharding) { return isValueCompatibleWithFullReplicationSharding( std::get<0>(valueAndSharding), std::get<1>(valueAndSharding)); }); } #endif // NDEBUG void shard::partitionFullyReplicatedOperation( Operation &op, ArrayRef partitionedOperands, ArrayRef operandShardings, ArrayRef resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) { assert(partitionedOperands.size() == operandShardings.size()); assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(), operandShardings)); assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(), resultShardings)); // `clone` will populate the mapping of old to new results. builder.clone(op, partitionMap); } static void updateGridAxisAssignmentForLoopIterators( ArrayRef gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector>> &gridAxesAssignmentForLoopIterators) { AffineDimExpr affineDimExpr = cast(indexingExpr); unsigned loopIteratorIdx = affineDimExpr.getPosition(); if (gridAxesAssignmentForLoopIterators[loopIteratorIdx]) { assert(llvm::equal(gridAxesAssignmentForTensorAxis, *gridAxesAssignmentForLoopIterators[loopIteratorIdx])); } else { gridAxesAssignmentForLoopIterators[loopIteratorIdx] = llvm::to_vector(gridAxesAssignmentForTensorAxis); } } ShardingArray shard::getGridAxisAssignmentForLoopIterators( ArrayRef operandShardings, ArrayRef resultShardings, ArrayRef loopIteratorTypes, ArrayRef indexingMaps) { SmallVector>> gridAxisAssignmentForLoopIterators(loopIteratorTypes.size()); std::vector operatorAndResultShardings; operatorAndResultShardings.reserve(operandShardings.size() + resultShardings.size()); llvm::append_range(operatorAndResultShardings, operandShardings); for (auto [sharding, affineMap] : llvm::zip_equal(operatorAndResultShardings, indexingMaps)) { if (!sharding) { continue; } for (auto [gridAxesAssignmentForTensorAxis, indexingExpr] : llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) { updateGridAxisAssignmentForLoopIterators( gridAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr, gridAxisAssignmentForLoopIterators); } // Missing trailing split axes means replication on those tensor dimensions. for (unsigned i = sharding.getSplitAxes().size(); i < affineMap.getNumResults(); ++i) { updateGridAxisAssignmentForLoopIterators( {}, affineMap.getResults()[i], gridAxisAssignmentForLoopIterators); } } ShardingArray res; llvm::transform(gridAxisAssignmentForLoopIterators, std::back_inserter(res), [](std::optional> &axes) { if (!axes) { return SmallVector(); }; return std::move(*axes); }); return res; } bool shard::isAtLeastOneReductionIteratorSharded( ArrayRef loopIteratorTypes, ArrayRef> gridAxisAssignmentForLoopIterators) { for (auto [loopIteratorType, gridAxisAssignment] : llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) { if (loopIteratorType == utils::IteratorType::reduction && !gridAxisAssignment.empty()) { return true; } } return false; } SmallVector shard::getReductionGridAxes( ArrayRef loopIteratorTypes, ArrayRef> gridAxisAssignmentForLoopIterators) { SmallVector gridAxes; for (auto [loopIteratorType, gridAxisAssignment] : llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) { if (loopIteratorType == utils::IteratorType::reduction) { llvm::append_range(gridAxes, gridAxisAssignment); } } return gridAxes; } void shard::partitionTriviallyShardableOperation( Operation &op, ArrayRef partitionedOperands, ArrayRef operandShardings, ArrayRef resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) { // `clone` will populate the mapping of old to new results. Operation *newOp = builder.clone(op, partitionMap); // Set the result types to the sharded counterparts. for (auto [oldResult, newResult, sharding] : llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) { newResult.setType(shardType( newResult.getType(), getGridOrNull(&op, sharding.getGridAttr(), symbolTable), sharding)); } }