//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Analysis/DataFlow/Utils.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Transforms/Passes.h" #include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" namespace mlir { namespace xegpu { #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" } // namespace xegpu } // namespace mlir #define DEBUG_TYPE "xegpu-propagate-layout" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") using namespace mlir; using namespace mlir::dataflow; namespace { //===----------------------------------------------------------------------===// // LayoutInfo //===----------------------------------------------------------------------===// /// Helper class for tracking the analysis state of an mlir value. For layout /// propagation, the analysis state is simply the distribution layout of /// each value. The distribution layout information is encapsulated using /// xegpu::DistributeLayoutAttr class which can hold information about any type /// of distribution layout that XeGPU dialect supports. Purpose of this analysis /// to propagate some unique distribution layout for each value in the program /// starting from a set of anchor operations (like DPAS, StoreNd, etc.). Note /// that analysis will reach a fixed point when all values are reached some /// layout and, analysis does not try to modify any already assigned layouts. /// /// Given this, LayoutInfo satisifies the following properties: /// 1) A LayoutInfo value can be in one of two states - `assigned` or `not /// assigned`. /// 2) Two LayoutInfo values are equal if they are both assigned or /// both not assigned. The concrete value of assigned state does not matter. /// 3) The meet operator works as follows: /// - If current state is assigned, return the current state. (already /// a unique layout is assigned. don't change it) /// - Otherwise, return the other state. struct LayoutInfo { private: xegpu::DistributeLayoutAttr storage = nullptr; public: LayoutInfo() = default; LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {} // Two lattice values are equal if they have `some` layout. The actual // content of the layout does not matter. bool operator==(const LayoutInfo &other) const { return this->isAssigned() == other.isAssigned(); } static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs); static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs); void print(raw_ostream &os) const; bool isAssigned() const { return storage != nullptr; } LayoutInfo transpose(ArrayRef permutation) const; SmallVector getLaneLayout() const; SmallVector getLaneData() const; SmallVector getInstData() const; SmallVector getSgLayout() const; SmallVector getSgData() const; SmallVector getOrder() const; bool isSliceLayout() const { if (!isAssigned()) return false; return isa(storage); } int64_t getRank() const { if (!isAssigned()) return -1; return storage.getRank(); } Attribute get() { return storage; } void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; } }; SmallVector LayoutInfo::getLaneLayout() const { if (!isAssigned()) return {}; return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(), [](int64_t val) { return static_cast(val); }); } SmallVector LayoutInfo::getLaneData() const { if (!isAssigned()) return {}; return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(), [](int64_t val) { return static_cast(val); }); } SmallVector LayoutInfo::getInstData() const { if (!isAssigned()) return {}; return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(), [](int64_t val) { return static_cast(val); }); } SmallVector LayoutInfo::getSgLayout() const { if (!isAssigned()) return {}; return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(), [](int64_t val) { return static_cast(val); }); } SmallVector LayoutInfo::getSgData() const { if (!isAssigned()) return {}; return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(), [](int64_t val) { return static_cast(val); }); } SmallVector LayoutInfo::getOrder() const { if (!isAssigned() || !storage.getOrder()) return {}; return llvm::map_to_vector(storage.getOrder().asArrayRef(), [](int64_t val) { return static_cast(val); }); } void LayoutInfo::print(raw_ostream &os) const { if (isAssigned()) { os << storage; } else { os << "Not assigned."; } } LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) { if (!lhs.isAssigned()) return rhs; return lhs; } /// Since this is a backward analysis, join method is not used. LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) { llvm_unreachable("Join should not be triggered by layout propagation."); } /// Construct a new layout with the transposed inst_data or lane_layout, /// lane_data. LayoutInfo LayoutInfo::transpose(ArrayRef permutation) const { if (!isAssigned()) return {}; // Check if the permutation is valid. llvm::SmallSet seen(permutation.begin(), permutation.end()); bool hasDuplicates = seen.size() != permutation.size(); bool withinRange = llvm::all_of(permutation, [&](int64_t idx) { return idx >= 0 && idx < static_cast(permutation.size()); }); if (!withinRange || hasDuplicates) { assert(false && "Invalid permutation for transpose."); return {}; } SmallVector laneLayout; SmallVector laneData; SmallVector instData; SmallVector sgLayout; SmallVector sgData; SmallVector order; for (int64_t idx : permutation) { if (getLaneLayout().size()) { laneLayout.push_back(static_cast(getLaneLayout()[idx])); laneData.push_back(static_cast(getLaneData()[idx])); } if (getInstData().size()) instData.push_back(static_cast(getInstData()[idx])); if (getSgData().size()) { sgLayout.push_back(static_cast(getSgLayout()[idx])); sgData.push_back(static_cast(getSgData()[idx])); } if (getOrder().size()) { order.push_back(static_cast(getOrder()[idx])); } } auto orderAttr = order.size() ? DenseI32ArrayAttr::get(storage.getContext(), order) : nullptr; xegpu::LayoutAttr layoutAttr; if (getLaneLayout().size()) layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData); if (getInstData().size()) layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData); if (getSgData().size()) layoutAttr = xegpu::LayoutAttr::get( storage.getContext(), DenseI32ArrayAttr::get(storage.getContext(), sgLayout), DenseI32ArrayAttr::get(storage.getContext(), sgData), /*inst_data =*/nullptr, /*lane_layout =*/nullptr, /*lane_data =*/nullptr, orderAttr); return LayoutInfo(layoutAttr); } //===----------------------------------------------------------------------===// // LayoutInfoLattice //===----------------------------------------------------------------------===// /// Lattice holding the LayoutInfo for each value. struct LayoutInfoLattice : public Lattice { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice) using Lattice::Lattice; }; /// Helper Functions to get default layouts. A `default layout` is a layout that /// is assigned to a value when the layout is not fixed by some anchor operation /// (like DPAS). /// Helper Function to get the default layout for uniform values like constants. /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1]. /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1]. static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, unsigned rank, const xegpu::uArch::uArch *uArch) { assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); if (rank == 1) { return LayoutInfo( xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1})); } return LayoutInfo( xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1})); } static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, unsigned rank, int subgroupSize) { assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); if (rank == 1) { return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1})); } return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1})); } /// Helper to get the default layout for 2D block operations. template static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty, const xegpu::uArch::uArch *uArch, unsigned packingSize) { // Expecting a 1D or 2D vector. assert((ty.getRank() == 1 || ty.getRank() == 2) && "Expected 1D or 2D vector."); // Expecting int or float element type. assert(ty.getElementType().isIntOrFloat() && "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. if (ty.getRank() == 1) return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch); // Packing factor is determined by the element type bitwidth. unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth(); int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; return LayoutInfo(xegpu::LayoutAttr::get( ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor})); } //===----------------------------------------------------------------------===// // LayoutInfoPropagation //===----------------------------------------------------------------------===// /// Backward data flow analysis to propagate the lane_layout and lane_data of /// each value in the program. Currently, the layouts for operands DPAS, /// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of /// this analysis is to propagate those known layouts to all their producers and /// (other) consumers. class LayoutInfoPropagation : public SparseBackwardDataFlowAnalysis { private: xegpu::LayoutKind layoutKind; unsigned indexBitWidth; void visitDpasOp(xegpu::DpasOp dpas, ArrayRef operands, ArrayRef results); void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef operands, ArrayRef results); void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter, ArrayRef operands, ArrayRef results); void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef operands, ArrayRef results); void visitLoadGatherOp(xegpu::LoadGatherOp load, ArrayRef operands, ArrayRef results); void visitTransposeOp(vector::TransposeOp transpose, ArrayRef operands, ArrayRef results); void visitVectorBitcastOp(vector::BitCastOp bitcast, ArrayRef operands, ArrayRef results); void visitCreateDescOp(xegpu::CreateDescOp createDesc, ArrayRef operands, ArrayRef results); void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset, ArrayRef operands, ArrayRef results); void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch, ArrayRef operands, ArrayRef results); void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction, ArrayRef operands, ArrayRef results); void visitVectorReductionOp(vector::ReductionOp reduction, ArrayRef operands, ArrayRef results); void visitVectorBroadCastOp(vector::BroadcastOp broadcast, ArrayRef operands, ArrayRef results); void visitShapeCastOp(vector::ShapeCastOp shapeCast, ArrayRef operands, ArrayRef results); void visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice, ArrayRef operands, ArrayRef results); void visitLoadMatrixOp(xegpu::LoadMatrixOp load, ArrayRef operands, ArrayRef results); void visitStoreMatrixOp(xegpu::StoreMatrixOp store, ArrayRef operands, ArrayRef results); void visitLoadGatherOp(xegpu::LoadMatrixOp load, ArrayRef operands, ArrayRef results); void visitStoreScatterOp(xegpu::StoreMatrixOp store, ArrayRef operands, ArrayRef results); bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout); public: LayoutInfoPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable, xegpu::LayoutKind layoutKind, unsigned indexBitWidth) : SparseBackwardDataFlowAnalysis(solver, symbolTable), layoutKind(layoutKind), indexBitWidth(indexBitWidth) {} using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; LogicalResult visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override; void visitBranchOperand(OpOperand &operand) override {}; void visitCallOperand(OpOperand &operand) override {}; void visitNonControlFlowArguments(RegionSuccessor &successor, ArrayRef arguments) override {}; void visitExternalCall(CallOpInterface call, ArrayRef operands, ArrayRef results) override { }; void setToExitState(LayoutInfoLattice *lattice) override { (void)lattice->meet(LayoutInfo()); } }; } // namespace LogicalResult LayoutInfoPropagation::visitOperation( Operation *op, ArrayRef operands, ArrayRef results) { TypeSwitch(op) .Case( [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); }) .Case([&](xegpu::StoreNdOp storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); }) .Case([&](xegpu::StoreScatterOp storeScatterOp) { visitStoreScatterOp(storeScatterOp, operands, results); }) .Case([&](xegpu::LoadNdOp loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); }) .Case([&](xegpu::LoadGatherOp loadGatherOp) { visitLoadGatherOp(loadGatherOp, operands, results); }) .Case([&](xegpu::CreateDescOp createDescOp) { visitCreateDescOp(createDescOp, operands, results); }) .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) { visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results); }) .Case([&](xegpu::PrefetchNdOp prefetchNdOp) { visitPrefetchNdOp(prefetchNdOp, operands, results); }) .Case([&](vector::TransposeOp transposeOp) { visitTransposeOp(transposeOp, operands, results); }) .Case([&](vector::BitCastOp bitcastOp) { visitVectorBitcastOp(bitcastOp, operands, results); }) .Case([&](vector::MultiDimReductionOp reductionOp) { visitVectorMultiReductionOp(reductionOp, operands, results); }) .Case([&](vector::ReductionOp reductionOp) { visitVectorReductionOp(reductionOp, operands, results); }) .Case([&](vector::BroadcastOp broadcastOp) { visitVectorBroadCastOp(broadcastOp, operands, results); }) .Case([&](vector::ShapeCastOp shapeCastOp) { visitShapeCastOp(shapeCastOp, operands, results); }) .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) { visitInsertStridedSliceOp(insertStridedSliceOp, operands, results); }) .Case([&](xegpu::LoadMatrixOp loadMatrixOp) { visitLoadMatrixOp(loadMatrixOp, operands, results); }) .Case([&](xegpu::StoreMatrixOp storeMatrixOp) { visitStoreMatrixOp(storeMatrixOp, operands, results); }) // All other ops. .Default([&](Operation *op) { for (const LayoutInfoLattice *resultInfo : results) { if (!resultInfo->getValue().isAssigned()) continue; for (auto [operandInfo, operand] : llvm::zip(operands, op->getOpOperands())) { // If the operand type is not a vector or tensor descriptor, skip // it. if (!isa( operand.get().getType())) continue; // Propagate the result layout to the operand. meet(operandInfo, *resultInfo); } } }); return success(); } bool LayoutInfoPropagation::hasParamsOfLayoutKind( xegpu::DistributeLayoutAttr anchorLayout) { if (anchorLayout == nullptr) { return false; } if (layoutKind == xegpu::LayoutKind::InstData) { return !(anchorLayout.getEffectiveInstDataAsInt().empty()); } if (layoutKind == xegpu::LayoutKind::Lane) { return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() || anchorLayout.getEffectiveLaneDataAsInt().empty()); } if (layoutKind == xegpu::LayoutKind::Subgroup) { return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() || anchorLayout.getEffectiveSgDataAsInt().empty()); } return false; } // This function returns all layouts for the given sgCount, whose sgData: // 1. Evenly divides the wgShape. // 2. Is a multiple of instData. // Example: // wgShape = [128, 64], instData = [8, 16], sgCount = 32 // Returns layouts: // [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32]. SmallVector> getValidLayouts(ArrayRef wgShape, ArrayRef instData, int64_t sgCount) { SmallVector> candidates; for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) { if (sgCount % sgLayout0) continue; int sgLayout1 = sgCount / sgLayout0; int sgData0 = wgShape[0] / sgLayout0; int sgData1 = wgShape[1] / sgLayout1; if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) || (sgData0 % instData[0] || sgData1 % instData[1])) continue; candidates.emplace_back(sgLayout0, sgLayout1); } // Sort primarily by how balanced they are // (i.e., minimize the absolute difference between the two dimensions), and // secondarily by the first dimension in ascending order. llvm::sort(candidates, [](const std::pair &lhs, const std::pair &rhs) { int diffLhs = std::abs(lhs.first - lhs.second); int diffRhs = std::abs(rhs.first - rhs.second); if (diffLhs != diffRhs) return diffLhs < diffRhs; return lhs.first < rhs.first; }); return candidates; } FailureOr getNumSg(Operation *op, const int sgSize) { // Oblivious to workitem layout, the total count matters. auto gpuFunc = op->getParentOfType(); if (!gpuFunc) return failure(); auto knownBlockSize = gpuFunc.getKnownBlockSize(); if (!knownBlockSize.has_value()) return failure(); const int flatBlockSize = llvm::product_of(knownBlockSize.value()); return flatBlockSize / sgSize; } void LayoutInfoPropagation::visitPrefetchNdOp( xegpu::PrefetchNdOp prefetch, ArrayRef operands, ArrayRef results) { LayoutInfo prefetchLayout; xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr(); if (hasParamsOfLayoutKind(anchorLayout)) { prefetchLayout = LayoutInfo(anchorLayout); } else { // Here we assign the default layout to the tensor descriptor operand of // prefetch. auto tdescTy = prefetch.getTensorDescType(); const uArch *uArch = getUArch(getChipStr(prefetch).value_or("")); if (!uArch) return; const auto *uArchInstruction = dyn_cast( uArch->getInstruction( xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch)); auto blockWHC = uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType()); if (!blockWHC) prefetch.emitWarning("No known block params found for the element type."); auto [bWidth, bHeight, bCount] = blockWHC.value(); SmallVector instData; int instWidth = xegpu::getLargestDivisor( static_cast(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth); if (instWidth == -1) prefetch.emitWarning( "No suitable instruction multiple found for the given shape."); if (tdescTy.getRank() == 1) instData = {instWidth}; else { int instHeight = xegpu::getLargestDivisor( static_cast(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight); if (instHeight == -1) prefetch.emitWarning( "No suitable instruction multiple found for the given shape."); instData = {instHeight, instWidth}; } if (layoutKind == xegpu::LayoutKind::InstData) prefetchLayout = LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData)); else prefetchLayout = getSIMTLayoutInfoBlockIO( tdescTy, uArch, uArchInstruction->getPackedFormatBitSize()); prefetch.setLayoutAttr( dyn_cast(prefetchLayout.get())); } // Propagate the layout to the source tensor descriptor. propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout)); } void LayoutInfoPropagation::visitVectorMultiReductionOp( vector::MultiDimReductionOp reduction, ArrayRef operands, ArrayRef results) { Type resultTy = reduction.getDestType(); // The layout of the result must be present. LayoutInfo resLayoutInfo = results[0]->getValue(); xegpu::DistributeLayoutAttr consumerLayoutAttr; if (!resultTy.isIntOrFloat()) { if (!resLayoutInfo.isAssigned()) return; consumerLayoutAttr = dyn_cast(resLayoutInfo.get()); } VectorType sourceTy = reduction.getSourceVectorType(); SmallVector reductionDims(reduction.getReductionDims()); const uArch *uArch = getUArch(xegpu::getChipStr(reduction).value_or("")); if (!uArch) return; int numSg = 0; if (layoutKind == xegpu::LayoutKind::Subgroup) { auto numSgOrErr = getNumSg(reduction, uArch->getSubgroupSize()); if (succeeded(numSgOrErr)) numSg = numSgOrErr.value(); } // The result layout represents the layout requirements of the operation. // it is recorded to anchor layout or temporary layout. // it must be honored for current op and may conflict with the layout // propagated from consumer op, the conflict is resolved in later phase by // converting the required result layout to the consumer layout auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout( layoutKind, sourceTy, consumerLayoutAttr, reductionDims, numSg, uArch); xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr); // derive the source layout from the dominant layout and reduction dims auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout( requiredResLayoutAttr, reductionDims); propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr))); // Accumulator should have the same layout as the result. propagateIfChanged(operands[1], operands[1]->meet(LayoutInfo(requiredResLayoutAttr))); } void LayoutInfoPropagation::visitVectorReductionOp( vector::ReductionOp reduction, ArrayRef operands, ArrayRef results) { VectorType sourceTy = reduction.getSourceVectorType(); const uArch *uArch = getUArch(xegpu::getChipStr(reduction).value_or("")); if (!uArch) return; auto requiredResLayoutAttr = xegpu::setupReductionResultLayout(layoutKind, sourceTy, uArch); xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr); auto srcLayoutAttr = xegpu::inferReductionSourceLayout(requiredResLayoutAttr); propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr))); if (reduction.getAcc()) propagateIfChanged(operands[1], operands[1]->meet(LayoutInfo(requiredResLayoutAttr))); } void LayoutInfoPropagation::visitVectorBroadCastOp( vector::BroadcastOp broadcast, ArrayRef operands, ArrayRef results) { // The layout of the result must be present. LayoutInfo resLayoutInfo = results[0]->getValue(); if (!resLayoutInfo.isAssigned()) return; // Only consider vector to vector broadcasts for now. VectorType resultTy = broadcast.getResultVectorType(); VectorType sourceTy = dyn_cast(broadcast.getSourceType()); // skip layout propagation for non-vector source operand. if (!sourceTy) return; auto srcShape = sourceTy.getShape(); auto resShape = resultTy.getShape(); auto resultLayoutAttr = dyn_cast(resLayoutInfo.get()); xegpu::DistributeLayoutAttr srcLayoutAttr = xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape); propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr))); } void LayoutInfoPropagation::visitShapeCastOp( vector::ShapeCastOp shapeCast, ArrayRef operands, ArrayRef results) { // The layout of the result must be present. LayoutInfo resLayoutInfo = results[0]->getValue(); if (!resLayoutInfo.isAssigned()) return; ArrayRef resShape = shapeCast.getResultVectorType().getShape(); ArrayRef srcShape = shapeCast.getSourceVectorType().getShape(); auto resultLayoutAttr = dyn_cast(resLayoutInfo.get()); xegpu::DistributeLayoutAttr srcLayoutAttr = xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape); propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr))); } /// Propagate the layout of the result tensor to the source tensor descriptor /// in UpdateNdOffsetOp. void LayoutInfoPropagation::visitUpdateNdOffsetOp( xegpu::UpdateNdOffsetOp updateNdOffset, ArrayRef operands, ArrayRef results) { // The layout of the result must be present. LayoutInfo resultLayout = results[0]->getValue(); if (!resultLayout.isAssigned()) return; // Propagate the layout to the source operand. propagateIfChanged(operands[0], operands[0]->meet(resultLayout)); } /// Set the layouts for DPAS A, B, and C operands. void LayoutInfoPropagation::visitDpasOp( xegpu::DpasOp dpas, ArrayRef operands, ArrayRef results) { LayoutInfo dpasALayout; LayoutInfo dpasBLayout; LayoutInfo dpasCDLayout; xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr(); if (hasParamsOfLayoutKind(anchorLayoutCD)) { xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr(); xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr(); assert(hasParamsOfLayoutKind(anchorLayoutA) && "Expected anchor layout for DPAS A operand."); assert(hasParamsOfLayoutKind(anchorLayoutB) && "Expected anchor layout for DPAS B operand."); dpasALayout = LayoutInfo(anchorLayoutA); dpasBLayout = LayoutInfo(anchorLayoutB); dpasCDLayout = LayoutInfo(anchorLayoutCD); } else { const uArch *uArch = getUArch(getChipStr(dpas).value_or("")); if (!uArch) return; VectorType aTy = dpas.getLhsType(); VectorType bTy = dpas.getRhsType(); VectorType cdTy = dpas.getResultType(); xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr; xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout, requiredBLayout; int numSg = 0; if (layoutKind == xegpu::LayoutKind::Subgroup) { LayoutInfo consumerLayout = results[0]->getValue(); if (!consumerLayout.isAssigned()) return; consumerLayoutAttr = dyn_cast(consumerLayout.get()); auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize()); if (failed(numSgOrErr)) { dpas.emitWarning( "Unable to determine the number of subgroups for the operation."); return; } numSg = numSgOrErr.value(); } auto layouts = xegpu::setupDpasLayout(layoutKind, aTy, bTy, cdTy, consumerLayoutAttr, numSg, uArch); if (!layouts.has_value()) { dpas.emitWarning( "Failed to determine required layouts for DPAS operands."); return; } std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts; dpas.setLayoutAAttr(requiredALayout); dpas.setLayoutBAttr(requiredBLayout); dpas.setLayoutCdAttr(requiredCDLayoutAttr); dpasALayout = LayoutInfo(requiredALayout); dpasBLayout = LayoutInfo(requiredBLayout); dpasCDLayout = LayoutInfo(requiredCDLayoutAttr); } propagateIfChanged(operands[0], operands[0]->meet(dpasALayout)); propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout)); if (operands.size() > 2) propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout)); } /// Set the layout for the value and tensor descriptor operands in StoreNdOp. void LayoutInfoPropagation::visitStoreNdOp( xegpu::StoreNdOp store, ArrayRef operands, ArrayRef results) { LayoutInfo storeLayout; xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr(); if (hasParamsOfLayoutKind(anchorLayout)) { storeLayout = LayoutInfo(anchorLayout); } else { const uArch *uArch = getUArch(getChipStr(store).value_or("")); if (!uArch) return; const auto *uArchInstruction = dyn_cast( uArch->getInstruction( xegpu::uArch::InstructionKind::Subgroup2DBlockStore)); VectorType dataTy = store.getValueType(); auto blockWHC = uArchInstruction->getBlockWidthHeightCount( store.getValueType().getElementType()); if (!blockWHC) store.emitWarning("No known block params found for the element type."); auto [bWidth, bHeight, bCount] = blockWHC.value(); SmallVector instData; int instWidth = xegpu::getLargestDivisor( static_cast(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth); if (instWidth == -1) store.emitWarning( "No suitable instruction multiple found for the given shape."); if (dataTy.getRank() == 1) instData = {instWidth}; else { int instHeight = xegpu::getLargestDivisor( static_cast(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight); if (instHeight == -1) store.emitWarning( "No suitable instruction multiple found for the given shape."); instData = {instHeight, instWidth}; } if (layoutKind == xegpu::LayoutKind::InstData) storeLayout = LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData)); else if (layoutKind == xegpu::LayoutKind::Lane) storeLayout = getSIMTLayoutInfoBlockIO(store.getValueType(), uArch, uArchInstruction->getPackedFormatBitSize()); else { // xegpu::LayoutKind::Subgroup auto sgSize = uArch->getSubgroupSize(); auto numSgOrErr = getNumSg(store, sgSize); if (failed(numSgOrErr)) { store.emitWarning( "Unable to determine the number of subgroups for the operation."); return; } auto sgLayouts = getValidLayouts(store.getValueType().getShape(), instData, numSgOrErr.value()); if (sgLayouts.empty()) { store.emitWarning( "Unable to determine suitable subgroup layout for store value."); return; } SmallVector sgLayout = {sgLayouts[0].first, sgLayouts[0].second}; SmallVector sgData = { static_cast(dataTy.getShape()[0]) / sgLayout[0], static_cast(dataTy.getShape()[1]) / sgLayout[1]}; storeLayout = LayoutInfo(xegpu::LayoutAttr::get( dataTy.getContext(), DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout), DenseI32ArrayAttr::get(dataTy.getContext(), sgData), /*inst_data =*/nullptr, /*lane_layout =*/nullptr, /*lane_data =*/nullptr, /*order =*/nullptr)); } store.setLayoutAttr( dyn_cast(storeLayout.get())); } // Propagate the layout to the value operand. // Both operands should have the same layout for (LayoutInfoLattice *operand : operands) propagateIfChanged(operand, operand->meet(storeLayout)); } /// Propagate the layout of the value to the tensor descriptor operand in /// LoadNdOp. void LayoutInfoPropagation::visitLoadNdOp( xegpu::LoadNdOp load, ArrayRef operands, ArrayRef results) { LayoutInfo loadLayout; xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr(); if (hasParamsOfLayoutKind(anchorLayout)) { loadLayout = LayoutInfo(anchorLayout); } else { LayoutInfo valueLayout = results[0]->getValue(); // Need the layout of the value to propagate to the tensor descriptor. if (!valueLayout.isAssigned()) return; loadLayout = valueLayout; // LoadNdOp has the transpose effect. However, at the stage of this analysis // this effect is not expected and should be abstracted away. Emit a // warning. if (auto transpose = load.getTranspose()) { load.emitWarning("Transpose effect is not expected for LoadNdOp at " "LayoutInfoPropagation stage."); loadLayout = valueLayout.transpose(transpose.value()); } load.setLayoutAttr(dyn_cast(loadLayout.get())); } // Propagate the new layout to the tensor descriptor operand. propagateIfChanged(operands[0], operands[0]->meet(loadLayout)); } /// For vector::TransposeOp, the layout of the result is transposed and /// propagated to the operand. void LayoutInfoPropagation::visitTransposeOp( vector::TransposeOp transpose, ArrayRef operands, ArrayRef results) { // Need the layout of transpose result to propagate to the operands. LayoutInfo resultLayout = results[0]->getValue(); if (!resultLayout.isAssigned()) return; auto consumerLayoutAttr = dyn_cast(resultLayout.get()); auto srcLayoutAttr = xegpu::inferTransposeSourceLayout( consumerLayoutAttr, transpose.getPermutation()); // Propagate the new layout to the vector operand. propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr))); } /// For vector::BitCastOp, the lane_data of the source layout is changed based /// on the bit width of the source and result types. void LayoutInfoPropagation::visitVectorBitcastOp( vector::BitCastOp bitcast, ArrayRef operands, ArrayRef results) { // Need the layout of bitcast result to propagate to the operands. LayoutInfo resLayoutInfo = results[0]->getValue(); if (!resLayoutInfo.isAssigned()) return; auto srcVecType = bitcast.getSourceVectorType(); auto resVecType = bitcast.getResultVectorType(); auto consumerLayoutAttr = dyn_cast(resLayoutInfo.get()); const uArch *uArch = getUArch(xegpu::getChipStr(bitcast).value_or("")); if (!uArch) return; auto requiredResLayoutAttr = setupBitCastResultLayout( layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch); xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr); int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth(); int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth(); // derive the source layout from the dominant layout and reduction dims auto srcLayoutAttr = xegpu::inferBitCastSourceLayout( requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth); propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr))); } void LayoutInfoPropagation::visitInsertStridedSliceOp( vector::InsertStridedSliceOp insertStridedSlice, ArrayRef operands, ArrayRef results) { // The layout of the result must be present. LayoutInfo resLayoutInfo = results[0]->getValue(); if (!resLayoutInfo.isAssigned()) return; auto srcVecType = insertStridedSlice.getSourceVectorType(); auto resVecType = insertStridedSlice.getDestVectorType(); auto consumerLayoutAttr = dyn_cast(resLayoutInfo.get()); const uArch *uArch = getUArch(xegpu::getChipStr(insertStridedSlice).value_or("")); if (!uArch) return; auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout( layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch); xegpu::setTemporaryLayout(insertStridedSlice->getResult(0), requiredResLayoutAttr); auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout( requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape()); propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr))); propagateIfChanged(operands[1], operands[1]->meet(LayoutInfo(requiredResLayoutAttr))); } /// Propagate the layout of the result to the tensor descriptor, mask and offset /// operands in LoadGatherOp. void LayoutInfoPropagation::visitLoadGatherOp( xegpu::LoadGatherOp load, ArrayRef operands, ArrayRef results) { xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr; xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr(); const uArch *uArch = getUArch(getChipStr(load).value_or("")); if (!uArch) return; auto subgroupSize = uArch->getSubgroupSize(); VectorType resVecTy = load.getValueType(); int chunkSize = load.getChunkSize().value_or(1); LayoutInfo resLayoutInfo = results[0]->getValue(); if (!resLayoutInfo.isAssigned()) return; auto consumerLayoutAttr = dyn_cast(resLayoutInfo.get()); if (hasParamsOfLayoutKind(anchorLayoutAttr)) { requiredAnchorLayoutAttr = anchorLayoutAttr; } else { if (!resVecTy) { load.emitWarning("Not propagating, non-vector payload supplied."); return; } requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout( layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch); load.setLayoutAttr(requiredAnchorLayoutAttr); } auto maskLayoutAttr = requiredAnchorLayoutAttr; // Special handling mask layout for chunked ops: Enforce the default xegpu 1D // layout for mask. if (chunkSize > 1) { if (layoutKind == xegpu::LayoutKind::InstData) maskLayoutAttr = xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}); else if (layoutKind == xegpu::LayoutKind::Lane) maskLayoutAttr = xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1}); else assert(false && "chunked StoreScatterOp should not be used at workgroup level"); } LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr); auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr); // Propagate the new layout to the tensor descriptor operand. if (isa(load.getSourceType())) propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo)); // Propagate the new layout to the mask and optional offset operand. propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo)); if (load.getOffsets()) propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo)); } /// Propagate the layout of the descriptor to the vector offset operand in /// CreateDescOp. void LayoutInfoPropagation::visitCreateDescOp( xegpu::CreateDescOp createDesc, ArrayRef operands, ArrayRef results) { LayoutInfo descLayout = results[0]->getValue(); // Need the layout of the descriptor to propagate to the operands. if (!descLayout.isAssigned()) return; const uArch *uArch = getUArch(getChipStr(createDesc).value_or("")); if (!uArch) return; // For offset operand propagate 1D default layout. LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1, uArch->getSubgroupSize()); propagateIfChanged(operands[1], operands[1]->meet(layout)); } /// Set the layout for the value, tensor descriptor, offset and mask operands in /// the StoreScatterOp. void LayoutInfoPropagation::visitStoreScatterOp( xegpu::StoreScatterOp storeScatter, ArrayRef operands, ArrayRef results) { xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr; xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr(); const uArch *uArch = getUArch(getChipStr(storeScatter).value_or("")); if (!uArch) return; auto subgroupSize = uArch->getSubgroupSize(); VectorType srcVecTy = storeScatter.getValueType(); int chunkSize = storeScatter.getChunkSize().value_or(1); if (hasParamsOfLayoutKind(anchorLayoutAttr)) { requiredAnchorLayoutAttr = anchorLayoutAttr; } else { if (!srcVecTy) { storeScatter.emitWarning("Not propagating, non-vector payload supplied."); return; } requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout( layoutKind, srcVecTy, chunkSize, uArch); storeScatter.setLayoutAttr(requiredAnchorLayoutAttr); } LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr); auto maskLayoutAttr = requiredAnchorLayoutAttr; // Special handling mask layout for chunked ops: Enforce the default xegpu 1D // layout for mask. if (chunkSize > 1) { if (layoutKind == xegpu::LayoutKind::InstData) maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}); else if (layoutKind == xegpu::LayoutKind::Lane) maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}, {1}); else assert(false && "chunked StoreScatterOp should not be used at workgroup level"); } LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr); // Propagate the payload operand layout propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo)); // Propagate the destination (if tdesc) operand layout if (isa(storeScatter.getDestType())) propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo)); // Propagate the new layout to the mask and optional offset operand. propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo)); if (storeScatter.getOffsets()) propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo)); } void LayoutInfoPropagation::visitLoadMatrixOp( xegpu::LoadMatrixOp loadMatrixOp, ArrayRef operands, ArrayRef results) { LayoutInfo resLayoutInfo = results[0]->getValue(); if (!resLayoutInfo.isAssigned()) return; auto consumerLayoutAttr = dyn_cast(resLayoutInfo.get()); xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr(); // only need to set anchor layout, no need to porpagate to memdesc and // offset if (!hasParamsOfLayoutKind(anchorLayout)) { VectorType resVecTy = llvm::cast(loadMatrixOp.getRes().getType()); const uArch *uArch = getUArch(getChipStr(loadMatrixOp).value_or("")); if (!uArch) return; auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout( layoutKind, resVecTy, consumerLayoutAttr, uArch); loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr); } } // Store matrix is a flavor of scattered store for 2D shapes. void LayoutInfoPropagation::visitStoreMatrixOp( xegpu::StoreMatrixOp storeMatrix, ArrayRef operands, ArrayRef results) { xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr(); LayoutInfo layout; if (hasParamsOfLayoutKind(anchorLayout)) { layout = LayoutInfo(anchorLayout); } else { VectorType srcVecTy = llvm::cast(storeMatrix.getData().getType()); const uArch *uArch = getUArch(getChipStr(storeMatrix).value_or("")); if (!uArch) return; auto requiredAnchorLayoutAttr = xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch); storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr); layout = LayoutInfo(requiredAnchorLayoutAttr); } propagateIfChanged(operands[0], operands[0]->meet(layout)); } namespace { //===----------------------------------------------------------------------===// // RunLayoutInfoPropagation //===----------------------------------------------------------------------===// /// Driver class for running the LayoutInfoPropagation analysis. class RunLayoutInfoPropagation { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation) RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind, unsigned indexBitWidth) : target(op) { SymbolTableCollection symbolTable; loadBaselineAnalyses(solver); solver.load(symbolTable, layoutKind, indexBitWidth); (void)solver.initializeAndRun(op); } LayoutInfo getLayoutInfo(Value val); void printAnalysisResult(llvm::raw_ostream &os); private: DataFlowSolver solver; const Operation *target; }; } // namespace LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) { auto *state = solver.lookupState(val); if (!state) return {}; return state->getValue(); } // Print the analysis result for debugging purposes. void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) { auto printFunctionResult = [&](FunctionOpInterface funcOp) { os << "function: " << funcOp.getName() << ":\n"; // Function arguments for (BlockArgument arg : funcOp.getArguments()) { LayoutInfo layout = getLayoutInfo(arg); os << "argument: " << arg << "\n"; os << "layout : "; layout.print(os); os << "\n"; } // Function ops funcOp.walk([&](Operation *op) { // Skip ops that do not have results if (op->getResults().empty()) return; os << "op : "; // For control-flow ops, print the op name only. if (isa(op) || isa(op)) os << op->getName(); else op->print(os); os << "\n"; // Print the layout for each result. for (auto [i, r] : llvm::enumerate(op->getResults())) { LayoutInfo layout = getLayoutInfo(r); os << "layout for result #" << i << ": "; layout.print(os); os << "\n"; } }); }; SmallVector funcOps; if (auto modOp = dyn_cast(target)) { for (auto funcOp : modOp.getOps()) funcOps.push_back(funcOp); // Collect all GpuFuncOps in the module. for (auto gpuModOp : modOp.getOps()) { for (auto gpuFuncOp : gpuModOp.getOps()) funcOps.push_back(gpuFuncOp); } } // Print the analysis result for each function. for (FunctionOpInterface funcOp : funcOps) printFunctionResult(funcOp); } namespace { //===----------------------------------------------------------------------===// // ResolveLayoutConflicts //===----------------------------------------------------------------------===// /// Helper to get the defining CreateNdDescOp of a tensor descriptor value. This /// function tries to find the defining CreateNdDescOp recursively accross /// control-flow boundaries. static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) { // Try to get the defining CreateNdDescOp of the tensor descriptor. auto definingOp = tdescValue.getDefiningOp(); if (definingOp) return definingOp; // If tdescValue is an argument, try to get the tied init value from the // parent loop-like op. if (auto arg = dyn_cast(tdescValue)) { auto *parentOp = arg.getOwner()->getParentOp(); if (auto loop = dyn_cast(parentOp)) { OpOperand *tiedInit = loop.getTiedLoopInit(arg); if (tiedInit) return getDefiningCreateNdDescOp(tiedInit->get()); } } // If not found, return null. return nullptr; } struct ResolveLayoutConflicts { ResolveLayoutConflicts(Operation *parentOp) : parentOp(parentOp), builder(parentOp->getContext()) {} LogicalResult run(); private: Operation *parentOp; OpBuilder builder; LogicalResult resolveTensorDescConsumer(OpOperand &operand); LogicalResult resolveVectorConsumer(OpOperand &operand); LogicalResult assignResultLayout(OpResult &result); }; } // namespace LogicalResult ResolveLayoutConflicts::run() { // Scan all operations in the parent op and resolve layout conflicts at // tensor descriptor and vector use points. auto r = parentOp->walk([&](Operation *op) -> WalkResult { // if the operation inputs vector and output scalar, like multi-reduction we // need to check if the result has layout and add a convert_layout to serve // as anchor op for the reduction op's layout. if (isa(op) || isa(op)) { for (OpResult result : op->getResults()) { if (result.getType().isIntOrFloat()) { auto res = assignResultLayout(result); if (failed(res)) { DBGS() << "Failed to resolve vector consumer for multi-reduction " << *op << "\n"; return WalkResult::interrupt(); } } } } for (OpOperand &operand : op->getOpOperands()) { // Handle conflicts in tensor descriptor operands. Type operandType = operand.get().getType(); if (isa(op) && isa(operandType)) { auto res = resolveTensorDescConsumer(operand); if (failed(res)) { DBGS() << "Failed to resolve tensor descriptor consumer: " << *op << "\n"; return WalkResult::interrupt(); } } // Handle conflicts in vector operands. if (isa(operandType)) { auto res = resolveVectorConsumer(operand); if (failed(res)) { DBGS() << "Failed to resolve vector consumer: " << *op << "\n"; return WalkResult::interrupt(); } } } return WalkResult::advance(); }); return r.wasInterrupted() ? failure() : success(); } LogicalResult ResolveLayoutConflicts::assignResultLayout(OpResult &result) { Operation *producerOp = result.getDefiningOp(); auto producerLayout = xegpu::getDistributeLayoutAttr(result); // Insert a convert_layout op to assign the layout. builder.setInsertionPointAfterValue(result); auto convertOp = xegpu::ConvertLayoutOp::create( builder, producerOp->getLoc(), result.getType(), result, producerLayout, producerLayout); result.replaceAllUsesExcept(convertOp.getResult(), convertOp); return success(); } LogicalResult ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) { Value vectorValue = operand.get(); Operation *consumerOp = operand.getOwner(); // Get the current layout of the vector value. auto producerLayout = xegpu::getDistributeLayoutAttr(vectorValue); if (!producerLayout) { if (auto vectorTy = dyn_cast(vectorValue.getType()); vectorTy && vectorTy.getRank() > 1) consumerOp->emitWarning("Expected layout for non-1D vectors."); return success(); // uniform non-tensor-data vector does not require layout } // Get the consumer expected layout at this operand. auto consumerLayout = xegpu::getConsumerLayoutAt(operand); if (!consumerLayout) return consumerOp->emitError( "No consumer layout found for vector operand."); // If layouts are same, no conflict exists, return success. if (consumerLayout.isEqualTo(producerLayout)) return success(); // Insert a convert_layout op to resolve the conflict. builder.setInsertionPointAfterValue(vectorValue); auto convertOp = xegpu::ConvertLayoutOp::create( builder, consumerOp->getLoc(), vectorValue.getType(), vectorValue, producerLayout, consumerLayout); // Update the operand to use the converted value. operand.set(convertOp.getResult()); return success(); } LogicalResult ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) { Operation *consumerOp = operand.getOwner(); Value tdescValue = operand.get(); auto anchorOp = dyn_cast(consumerOp); auto currTDescType = dyn_cast(tdescValue.getType()); assert(anchorOp && currTDescType && "Expected anchor layout op and tensor descriptor consumer."); // TODO: Scattered tensor desc is not supported for now. if (currTDescType.isScattered()) { DBGS() << "Scattered tensor descriptor not supported: " << tdescValue << "\n"; return failure(); } Attribute currLayout = currTDescType.getLayout(); Attribute expectedLayout = anchorOp.getAnchorLayout(); // A conflict exists in tensor descriptor operand if tensor descriptor's // layout is different from the anchor layout expected by the consumer. if (expectedLayout && currLayout && expectedLayout != currLayout) { // Try to get the defining CreateNdDescOp of the tensor descriptor. auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue); if (!conflictingCreateNdOp) { DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: " << tdescValue << "\n"; return failure(); } // Duplicate the CreateNdDescOp with the expected layout. builder.setInsertionPointAfter(conflictingCreateNdOp); auto newTensorDescType = xegpu::TensorDescType::get( conflictingCreateNdOp.getContext(), currTDescType.getShape(), currTDescType.getElementType(), currTDescType.getEncoding(), expectedLayout); xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create( builder, consumerOp->getLoc(), newTensorDescType, conflictingCreateNdOp->getOperands(), conflictingCreateNdOp->getAttrs()); // Replace the tensor descriptor operand in the consumer op with the new // tensor descriptor. consumerOp->replaceUsesOfWith(tdescValue, newOp.getResult()); } return success(); } using GetLayoutFnTy = function_ref; /// Update an operation with the layout of its results. If the result type is /// a vector type, a temporary layout attribute is added to the operation. If /// the result type is a tensor descriptor type, the type is updated with the /// layout attribute. The users of the result are also updated with the layout /// attribute. static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue) { // Region ops (like scf.for) are already handled by the // updateControlFlowOps. if (mlir::isa(op)) return success(); // Iterate over all the results. for (OpResult result : op->getResults()) { Type resultType = result.getType(); // Layouts are needed only for vector and tensor descriptor types. if (!isa(resultType)) continue; // If the result has no layout but has users, emit a warning and continue. xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result); if (!layout && result.getNumUses() > 0) { op->emitWarning("op has users but no layout assigned for its result"); continue; } // If the result is a tensor descriptor type, update the tensor desc type // with layout. if (auto tensorDescTy = dyn_cast(resultType)) { auto typeWithLayout = xegpu::TensorDescType::get( tensorDescTy.getContext(), tensorDescTy.getShape(), tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout); result.setType(typeWithLayout); continue; } // If the result is a vector type, add a temporary layout attribute to the // op. xegpu::setDistributeLayoutAttr(result, layout); } return success(); } /// Region ops like scf.for need special handling because they have blocks /// inside. If the blocks have tensor descriptor type as block arguments, /// thier types must be updated. Also region op can have results that may not /// have any users (e.g. A and B tiles). They are not assigned a layout by /// layout analysis because they have no users. However inside the region op /// corresponding block arguments for these results do have layouts. /// Therefore, in this case we still need to update the result types with the /// layout attribute. This function function updates the internal block /// arguments and the result types of the region op with the assigned layouts. /// clang-format off /// Example: scf.for ... iter_args(...) -> (out types) { /// ^bb0(block types): /// ... /// scf.yield ... : (yield types) /// } /// clang-format on /// In this example, at scf.yield, control-flow can transfer to two successor /// regions. One is the ^bb0 (for loop body) and the other is the scf.for op /// itself (yield the results). So we update both the block arguments of the /// successor region (i.e. block types) and the result types of the scf.for op /// (i.e. out types). Note that yield types are updated by respective /// producers inside bb0. static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue) { // Only process if the terminator is inside a region branch op. auto branchOp = dyn_cast(terminator->getParentOp()); if (!branchOp) return success(); RegionBranchSuccessorMapping mapping; branchOp.getSuccessorOperandInputMapping(mapping, RegionBranchPoint(terminator)); for (const auto &[successorOperand, successorInputs] : mapping) { for (Value successorInput : successorInputs) { Type inputType = successorInput.getType(); // We only need to operate on tensor descriptor or vector types. if (!isa(inputType)) continue; xegpu::DistributeLayoutAttr successorInputLayout = getLayoutOfValue(successorInput); xegpu::DistributeLayoutAttr successorOperandLayout = getLayoutOfValue(successorOperand->get()); // If either of the layouts is not assigned, we cannot proceed. if (!successorOperandLayout) { LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in " "branch terminator: " << successorOperand->get() << "\n"); return failure(); } // We expect the layouts to match. if (successorInputLayout && successorInputLayout != successorOperandLayout) { LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and " "operand forwarded as the argument: " << successorInputLayout << " vs " << successorOperandLayout << "\n"); return failure(); } // Get tensor descriptor type with the layout. if (auto tdescTy = dyn_cast(inputType)) { auto newTdescTy = xegpu::TensorDescType::get( tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(), tdescTy.getEncoding(), successorOperandLayout); successorInput.setType(newTdescTy); continue; } // If the type is a vector type and this region argument is an OpResult, // set the layout attribute on the OpResult. if (auto result = dyn_cast(successorInput)) xegpu::setDistributeLayoutAttr(result, successorOperandLayout); } } return success(); } /// Update the function arguments and results with the layouts. static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue) { // Only process functions whose type is a standard MLIR FunctionType. // Functions using a different type representation (e.g. llvm.func with // LLVMFunctionType) are not targets for XeGPU layout propagation, and // calling setType(FunctionType{}) on them would corrupt their type. if (!isa(funcOp.getFunctionType())) return success(); SmallVector newArgTypes; // Update the function arguments. for (BlockArgument arg : funcOp.getArguments()) { Type argType = arg.getType(); newArgTypes.push_back(argType); if (!isa(argType)) continue; xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg); if (!layout) { LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg << " but got none.\n"); return failure(); } if (auto tensorDescTy = dyn_cast(argType)) { auto newTdescTy = xegpu::TensorDescType::get( tensorDescTy.getContext(), tensorDescTy.getShape(), tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout); arg.setType(newTdescTy); newArgTypes.back() = newTdescTy; } } // Update the function type with the new argument types. // NOTE: We assume that function results are not expected to have layouts. funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes, funcOp.getResultTypes())); return success(); } namespace { struct XeGPUPropagateLayoutPass final : public xegpu::impl::XeGPUPropagateLayoutBase { XeGPUPropagateLayoutPass() = default; XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default; XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options) : XeGPUPropagateLayoutBase(std::move(options)) {} void runOnOperation() override; }; } // namespace LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly) { RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth); // Print the analysis result and exit. (for debugging purposes) if (printOnly) { auto &os = llvm::outs(); analysis.printAnalysisResult(os); return success(); } // Helper to convert LayoutInfo to xegpu::LayoutAttr. auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr { LayoutInfo layout = analysis.getLayoutInfo(val); if (!layout.isAssigned()) return {}; if (auto opResult = dyn_cast(val)) { Operation *defOp = opResult.getDefiningOp(); if (auto anchorOp = dyn_cast(defOp)) { auto anchorLayout = anchorOp.getAnchorLayout(); if (anchorLayout != nullptr) return anchorLayout; } xegpu::DistributeLayoutAttr requiredResLayoutAttr = xegpu::getTemporaryLayout(opResult); if (requiredResLayoutAttr != nullptr) return requiredResLayoutAttr; } xegpu::DistributeLayoutAttr layoutAttr = cast(layout.get()); if (layout.isSliceLayout()) return cast(layoutAttr); return cast(layoutAttr); }; Operation *op = target; auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult { for (mlir::Operation &op : llvm::reverse(block->getOperations())) { LogicalResult r = success(); TypeSwitch(&op) .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) { r = updateControlFlowOps(builder, branchTermOp, getXeGPULayoutForValue); }) .Case([&](mlir::FunctionOpInterface funcOp) { r = updateFunctionOpInterface(builder, funcOp, getXeGPULayoutForValue); }) .Default([&](Operation *op) { r = updateOp(builder, op, getXeGPULayoutForValue); }); if (failed(r)) { op.emitError("Failed to update operation with the layout."); return WalkResult::interrupt(); } } return WalkResult::advance(); }); if (walkResult.wasInterrupted()) return failure(); return success(); } LogicalResult xegpu::resolveLayoutConflicts(Operation *target) { ResolveLayoutConflicts resolver(target); return resolver.run(); } void XeGPUPropagateLayoutPass::runOnOperation() { xegpu::LayoutKind layoutKind; if (this->layoutKind == "lane") { layoutKind = xegpu::LayoutKind::Lane; } else if (this->layoutKind == "inst") { layoutKind = xegpu::LayoutKind::InstData; } else if (this->layoutKind == "subgroup") { layoutKind = xegpu::LayoutKind::Subgroup; } else { getOperation()->emitError("Unsupported layout kind option: " + this->layoutKind); signalPassFailure(); return; } OpBuilder builder(&getContext()); if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind, this->indexBitWidth, this->printOnly))) { signalPassFailure(); return; } // Resolve layout conflicts if any. if (failed(xegpu::resolveLayoutConflicts(getOperation()))) { signalPassFailure(); return; } }