This PR enhances insert_strided_slice layout rules to handle slice layout and adjust the layout to fit the src shape. It adds dropDims as layout utility function.
1129 lines
47 KiB
C++
1129 lines
47 KiB
C++
//===---- XeGPULayoutImpl.cpp - MLIR Utilities for XeGPUOps
|
|
//------------------===//
|
|
//
|
|
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements layout utility functions for XeGPU dialect
|
|
// transformation.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
|
|
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/ValueRange.h"
|
|
#include "mlir/Interfaces/LoopLikeInterface.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include <cstdint>
|
|
#include <numeric>
|
|
|
|
using namespace mlir;
|
|
|
|
/// Visits every operation reached by `op->walk` and re-attaches distribute
/// layouts: for each operand and each result, the layout returned by
/// getDistributeLayoutAttr is handed straight to setDistributeLayoutAttr.
/// The looked-up layout is forwarded as-is, without a null check.
void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
  op->walk([&](Operation *visited) {
    // Operands: look the layout up from the operand's value.
    for (OpOperand &use : visited->getOpOperands())
      setDistributeLayoutAttr(use, getDistributeLayoutAttr(use.get()));

    // Results: look the layout up from the result itself.
    for (OpResult res : visited->getOpResults())
      setDistributeLayoutAttr(res, getDistributeLayoutAttr(res));
  });
}
|
|
|
|
/// Returns a copy of `attrs` in which every DistributeLayoutAttr value has
/// been replaced by the result of dropSgLayoutAndData(). Entries whose
/// stripped layout is null are omitted; all other attributes are copied
/// through unchanged and in order.
SmallVector<NamedAttribute>
xegpu::dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
  SmallVector<NamedAttribute> filtered;
  filtered.reserve(attrs.size());

  for (NamedAttribute entry : attrs) {
    auto layout = dyn_cast<xegpu::DistributeLayoutAttr>(entry.getValue());
    if (!layout) {
      // Not a layout attribute: keep it verbatim.
      filtered.push_back(entry);
      continue;
    }
    // Keep the entry only if something is left after stripping.
    if (auto stripped = layout.dropSgLayoutAndData())
      filtered.emplace_back(entry.getName(), stripped);
  }

  return filtered;
}
|
|
|
|
/// Returns a copy of `attrs` in which every DistributeLayoutAttr value has
/// been replaced by the result of dropInstData(). Entries whose stripped
/// layout is null are omitted; all other attributes are copied through
/// unchanged and in order.
SmallVector<NamedAttribute>
xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
  SmallVector<NamedAttribute> filtered;
  filtered.reserve(attrs.size());

  for (NamedAttribute entry : attrs) {
    auto layout = dyn_cast<xegpu::DistributeLayoutAttr>(entry.getValue());
    if (!layout) {
      // Not a layout attribute: keep it verbatim.
      filtered.push_back(entry);
      continue;
    }
    // Keep the entry only if something is left after stripping.
    if (auto stripped = layout.dropInstData())
      filtered.emplace_back(entry.getName(), stripped);
  }

  return filtered;
}
|
|
|
|
// Attach layout attributes to all vector-type operands of operations within
|
|
// the given operation's region. Reports an error if any vector operand lacks
|
|
// a layout attribute.
|
|
bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
|
|
auto result = rootOp->walk([&](Operation *op) {
|
|
for (OpOperand &operand : op->getOpOperands()) {
|
|
// Layouts are needed for vector type only.
|
|
if (!isa<VectorType>(operand.get().getType()))
|
|
continue;
|
|
// Skip block arguments since they don't have defining ops to attach
|
|
// layout attributes to.
|
|
if (isa<BlockArgument>(operand.get()))
|
|
continue;
|
|
auto layout = xegpu::getDistributeLayoutAttr(operand.get());
|
|
if (!layout) {
|
|
op->emitWarning("Could not find layout attribute for operand ")
|
|
<< operand.getOperandNumber() << " of operation " << op->getName();
|
|
continue;
|
|
}
|
|
xegpu::setTemporaryLayout(operand, layout);
|
|
}
|
|
return WalkResult::advance();
|
|
});
|
|
return !result.wasInterrupted();
|
|
}
|
|
|
|
template <typename T, typename>
|
|
void xegpu::removeLayoutAttr(const T &operandOrResult) {
|
|
Operation *owner = operandOrResult.getOwner();
|
|
std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
|
|
if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
|
|
owner->removeAttr(name);
|
|
}
|
|
|
|
// Explicit instantiation for OpResult
|
|
template void
|
|
xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
|
|
|
|
// Explicit instantiation for OpOperand
|
|
template void
|
|
xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
|
|
|
|
/// Strips every attribute whose value is a DistributeLayoutAttr from each
/// operation visited by walking `op`.
void xegpu::removeLayoutAttrs(Operation *op) {
  op->walk([&](Operation *visited) {
    // Gather the names first and remove afterwards, so the attribute list is
    // not modified while it is being iterated.
    SmallVector<StringAttr> layoutAttrNames;
    for (NamedAttribute entry : visited->getAttrs())
      if (isa<DistributeLayoutAttr>(entry.getValue()))
        layoutAttrNames.push_back(entry.getName());

    for (StringAttr layoutAttrName : layoutAttrNames)
      visited->removeAttr(layoutAttrName);
  });
}
|
|
|
|
/// Infers the source layout attribute for a broadcast operation given the
|
|
/// result layout attribute, result shape, source shape.
|
|
xegpu::DistributeLayoutAttr
|
|
xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
|
|
ArrayRef<int64_t> resShape,
|
|
ArrayRef<int64_t> srcShape) {
|
|
|
|
SmallVector<int64_t> bcastDims;
|
|
auto returnLayout = resLayout;
|
|
|
|
// Handling broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
|
|
int dimDiff = resShape.size() - srcShape.size();
|
|
|
|
if (dimDiff > 0) {
|
|
// Adding the missing leading dims
|
|
for (int i = 0; i < dimDiff; i++)
|
|
bcastDims.push_back(i);
|
|
|
|
// Create a slice layout for the source
|
|
returnLayout = xegpu::SliceAttr::get(
|
|
resLayout.getContext(), resLayout,
|
|
DenseI64ArrayAttr::get(resLayout.getContext(), bcastDims));
|
|
}
|
|
return returnLayout;
|
|
}
|
|
|
|
/// Infers the source layout attribute for a reduction operation given the
|
|
/// result layout attribute and reduced dims.
|
|
xegpu::DistributeLayoutAttr
|
|
xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
|
|
SmallVector<int64_t> reduceDims) {
|
|
|
|
assert(isa<xegpu::SliceAttr>(resLayout) &&
|
|
"reduction result layout must be slice layout");
|
|
|
|
xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
|
|
|
|
assert((reduceDims == sliceLayout.getDims().asArrayRef()) &&
|
|
"reduction dims must match with slice dims");
|
|
|
|
return sliceLayout.getParent();
|
|
}
|
|
|
|
/// Infers the source layout attribute for a transpose operation given the
|
|
/// result layout attribute and permutation.
|
|
xegpu::DistributeLayoutAttr
|
|
xegpu::inferTransposeSourceLayout(xegpu::DistributeLayoutAttr resLayout,
|
|
ArrayRef<int64_t> permutation) {
|
|
return resLayout.transposeDims(permutation);
|
|
}
|
|
|
|
/// Infers the source layout attribute for a bitcast operation given the
|
|
/// result layout attribute, result element type bitwidth, and source element
|
|
/// type bitwidth.
|
|
xegpu::DistributeLayoutAttr
|
|
xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
|
|
int resElemTyBitWidth, int srcElemTyBitWidth) {
|
|
|
|
SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
|
|
SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
|
|
SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
|
|
size_t sgDataSize = sgData.size();
|
|
size_t instDataSize = instData.size();
|
|
size_t laneDataSize = laneData.size();
|
|
int64_t sgDataValue = -1;
|
|
int64_t instDataValue = -1;
|
|
int64_t laneDataValue = -1;
|
|
int64_t dim = resLayout.getRank() - 1;
|
|
|
|
if (srcElemTyBitWidth <= resElemTyBitWidth) {
|
|
int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
|
|
if (sgDataSize)
|
|
sgDataValue = sgData.back() * bitWidthRatio;
|
|
if (instDataSize)
|
|
instDataValue = instData.back() * bitWidthRatio;
|
|
if (laneDataSize)
|
|
laneDataValue = laneData.back() * bitWidthRatio;
|
|
} else {
|
|
int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
|
|
if (sgDataSize) {
|
|
assert((sgData.back() % bitWidthRatio) == 0 &&
|
|
"sgData not divisible by bitWidthRatio");
|
|
sgDataValue = sgData.back() / bitWidthRatio;
|
|
}
|
|
if (instDataSize) {
|
|
assert((instData.back() % bitWidthRatio) == 0 &&
|
|
"instData not divisible by bitWidthRatio");
|
|
instDataValue = instData.back() / bitWidthRatio;
|
|
}
|
|
if (laneDataSize) {
|
|
assert((laneData.back() % bitWidthRatio) == 0 &&
|
|
"laneData not divisible by bitWidthRatio");
|
|
laneDataValue = laneData.back() / bitWidthRatio;
|
|
}
|
|
}
|
|
|
|
xegpu::DistributeLayoutAttr finalSrcLayout;
|
|
finalSrcLayout =
|
|
resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
|
|
|
|
return finalSrcLayout;
|
|
}
|
|
|
|
/// Infers the source layout attribute for an insert strided slice operation
|
|
/// given the result layout attribute, result shape, and source shape. Removes
|
|
/// leading dimensions from the result layout to match the source shape size.
|
|
xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
|
|
xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
|
|
ArrayRef<int64_t> srcShape) {
|
|
|
|
int srcShapeSize = srcShape.size();
|
|
int resShapeSize = resShape.size();
|
|
int dimDiff = resShapeSize - srcShapeSize;
|
|
|
|
if (dimDiff > 0) {
|
|
// assert that the leading dimensions being sliced off are not distributed
|
|
// (i.e. sg_layout and lane_layout for those dimensions are all 1)
|
|
auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
|
|
auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
|
|
for (int i = 0; i < dimDiff; i++) {
|
|
assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
|
|
(resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
|
|
"Leading dimensions being sliced off must not be distributed");
|
|
}
|
|
return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
|
|
}
|
|
return resLayout;
|
|
}
|
|
|
|
/// Infers the source layout attribute for a shape cast operation given the
|
|
/// result layout attribute, result shape, and source shape.
|
|
xegpu::DistributeLayoutAttr
|
|
xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
|
|
ArrayRef<int64_t> resShape,
|
|
ArrayRef<int64_t> srcShape) {
|
|
|
|
// There are three use cases:
|
|
// 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
|
|
// tensor before broadcast
|
|
// 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
|
|
// for multi-stage reduction
|
|
// 3. combines all dims to a single dim and put in the innermost dim in 2d as
|
|
// [1, combinedData] or [combinedData]. Say, [2, 4, 8] -> [1, 64] or [64]
|
|
// Use cases are only supported after workgroup distribution,
|
|
// like cross-sg reduction saves multidimension data to
|
|
// 1D slm buffer, shapecast inserted by cse/canonicalization passes.
|
|
|
|
// Use case 1: Shapes only differ by expanding unit dimensions, for broadcast
|
|
SmallVector<int64_t> expandedUnitDims;
|
|
|
|
if (xegpu::matchUnitDimExpansion(srcShape, resShape, expandedUnitDims)) {
|
|
// create a slice layout for the source by removing the expanded unit dims
|
|
auto sliceDimsAttr = DenseI64ArrayAttr::get(
|
|
resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
|
|
auto srcLayout =
|
|
xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
|
|
return srcLayout;
|
|
}
|
|
|
|
// Use case 2: Dim split from source to result, for multi-stage reduction
|
|
SmallVector<SmallVector<int64_t>> splitDimGroups;
|
|
if (xegpu::matchSplitDimExpansion(srcShape, resShape, splitDimGroups)) {
|
|
auto srcLayout = resLayout;
|
|
for (const auto &dimGroup : splitDimGroups)
|
|
srcLayout = srcLayout.collapseDims(dimGroup);
|
|
|
|
return srcLayout;
|
|
}
|
|
|
|
// Use case 3: Collaspse to innermost dim, for cross-sg reduction to SLM
|
|
auto matchCollapseToInnermostDim = [&](ArrayRef<int64_t> src,
|
|
ArrayRef<int64_t> dst) -> bool {
|
|
// only one non-unit dim in dst which is the innermost dim
|
|
if ((dst.size() != 2) && (dst.size() != 1))
|
|
return false;
|
|
int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
|
|
std::multiplies<int64_t>());
|
|
if (dst.size() == 1)
|
|
return (dst[0] == srcSize);
|
|
return (dst[0] == 1) && (dst[1] == srcSize);
|
|
};
|
|
|
|
if (matchCollapseToInnermostDim(srcShape, resShape)) {
|
|
int srcShapeSize = srcShape.size();
|
|
int resShapeSize = resShape.size();
|
|
auto context = resLayout.getContext();
|
|
auto resInstData = resLayout.getEffectiveInstDataAsInt();
|
|
auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
|
|
auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
|
|
|
|
// Extract layout info from result's innermost dimension and apply to
|
|
// source's innermost dimension while setting all other dimensions to 1.
|
|
// The inferred layout is restricted by srcShape to ensure it fits within
|
|
// the source dimensions.
|
|
// Examples 1:
|
|
// srcShape=[8, 16, 32], resShape=[1, 4096]
|
|
// resInstData=[1, 16]
|
|
// -> inferredInstData=[1, 1, min(16, 32)]=[1, 1, 16]
|
|
// Examples 2:
|
|
// srcShape=[4, 8, 64], resShape=[2048]
|
|
// resLaneLayout=[16], resLaneData=[2]
|
|
// -> inferredLaneLayout=[1, 1, 16]
|
|
// -> inferredLaneData=[1, 1, min(2, 64/16)]=[1, 1, 2]
|
|
|
|
if (resInstData.size() != 0) {
|
|
// assert resInstData must be 1 for all but the innermost dim
|
|
for (int i = 0; i < resShapeSize - 1; i++) {
|
|
assert(resInstData[i] == 1 &&
|
|
"only innermost dim can have non-unit instData");
|
|
}
|
|
SmallVector<int> inferredInstData(srcShapeSize, 1);
|
|
inferredInstData[srcShapeSize - 1] =
|
|
std::min(resInstData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
|
|
return xegpu::LayoutAttr::get(context, inferredInstData);
|
|
}
|
|
|
|
if (resLaneLayout.size() != 0) {
|
|
for (int i = 0; i < resShapeSize - 1; i++) {
|
|
assert(resLaneData[i] == 1 &&
|
|
"only innermost dim can have non-unit instData");
|
|
}
|
|
assert(srcShape.back() % resLaneLayout.back() == 0 &&
|
|
"source innermost dim must be >= result lane layout");
|
|
SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
|
|
SmallVector<int> inferredLaneData(srcShapeSize, 1);
|
|
inferredLaneLayout.back() = resLaneLayout.back();
|
|
inferredLaneData.back() = std::min(
|
|
resLaneData.back(), srcShape.back() / inferredLaneLayout.back());
|
|
return xegpu::LayoutAttr::get(context, inferredLaneLayout,
|
|
inferredLaneData);
|
|
}
|
|
}
|
|
llvm_unreachable("running into unsupported shape cast scenarios");
|
|
return nullptr;
|
|
}
|
|
|
|
/// Sets up layout for reduction operations by creating a SliceAttr for the
|
|
/// result.
|
|
///
|
|
/// Algorithm Overview:
|
|
/// This function attempts to construct a source layout that, when sliced along
|
|
/// reduction dimensions, produces a result layout compatible with the
|
|
/// consumer layout.
|
|
///
|
|
/// For subgroup layouts, it first tries to align the source layout's subgroup
|
|
/// layout and data with the consumer's layout on non-reduction dimensions.
|
|
/// Then, it distributes remaining subgroups across reduction dimensions. This
|
|
/// avoids subgroup data redistribution overhead between the reduced result and
|
|
/// its consumer.
|
|
///
|
|
/// InstData requries {1, ..., min(maxReduceVectorSize, srcShape),subgroupSize}
|
|
/// Lane Layout requires {1, ..., 1, subgroupSize}
|
|
/// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
|
|
///
|
|
/// Examples:
|
|
/// 1. Subgroup layout - Row reduction on 2D tensor:
|
|
/// srcShape=[32, 64], reductionDims=[1], resShape=[32], subgroupSize=16,
|
|
/// workgroupSize=32
|
|
/// Consumer Layout:
|
|
/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 8]>, dims =
|
|
/// [1]>} Result: srcLayout with sgLayout=[4, 8], sgData=[8, 8] (matches
|
|
/// consumer on non-reduction dim, minimizing data redistribution on
|
|
/// reduction dim)
|
|
/// 2. Subgroup layout - Same example above but consumer has different layout:
|
|
/// sgLayout=[32], sgData=[1]
|
|
/// Result: srcLayout with sgLayout=[32,1], sgData=[1, 64]
|
|
/// (distributes all subgroups on non reduction dim)
|
|
///
|
|
/// 2. InstData layout - Column reduction:
|
|
/// srcShape=[32, 64], reductionDims=[0], subgroupSize=16
|
|
/// Result: instData=[1, 16] (maxReduceVectorSize=1, subgroupSize on
|
|
/// innermost)
|
|
///
|
|
/// 3. Lane layout - Multi-dimensional reduction:
|
|
/// srcShape=[16, 32, 64], reductionDims=[1], subgroupSize=16
|
|
/// Result: laneLayout=[1, 1, 16], laneData=[1, 1, 1]
|
|
/// (subgroupSize on innermost dim, max vector size on reduction dim)
|
|
|
|
xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
|
|
xegpu::LayoutKind layoutKind, VectorType srcVecTy,
|
|
DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
|
|
const xegpu::uArch::uArch *uArch) {
|
|
|
|
auto srcShape = srcVecTy.getShape();
|
|
int srcRank = srcShape.size();
|
|
auto context = consumerLayout.getContext();
|
|
|
|
// Reduction layout requires at least 2D tensors
|
|
if (srcRank < 2)
|
|
return nullptr;
|
|
|
|
// Helper lambda to convert int64 vectors to int32 DenseArrayAttr
|
|
auto toInt32Attr = [&](ArrayRef<int64_t> vec) {
|
|
SmallVector<int32_t> vec32(vec.begin(), vec.end());
|
|
return DenseI32ArrayAttr::get(context, vec32);
|
|
};
|
|
|
|
// Extract original plain layout for workgroup/subgroup size recovery
|
|
xegpu::SliceAttr consumerSliceLayout =
|
|
dyn_cast<xegpu::SliceAttr>(consumerLayout);
|
|
DistributeLayoutAttr plainLayout =
|
|
consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
|
|
: consumerLayout;
|
|
|
|
const int subgroupSize = uArch->getSubgroupSize();
|
|
int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
|
|
|
|
xegpu::DistributeLayoutAttr srcLayout;
|
|
|
|
if (layoutKind == xegpu::LayoutKind::Subgroup) {
|
|
auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
|
|
const int workgroupSize = std::accumulate(
|
|
sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
|
|
SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
|
|
SmallVector<int64_t> consumerSgLayout =
|
|
consumerLayout.getEffectiveSgLayoutAsInt();
|
|
int remainingSgCount = workgroupSize;
|
|
int consumerIdx = consumerSgLayout.size() - 1;
|
|
|
|
// First pass: Match consumer's layout on non-reduction dimensions
|
|
for (int i = srcRank - 1; i >= 0; i--) {
|
|
if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
|
|
sgLayout[i] = consumerSgLayout[consumerIdx];
|
|
assert((srcShape[i] % sgLayout[i] == 0) &&
|
|
"source shape not divisible by consumer sg_layout");
|
|
sgData[i] = srcShape[i] / sgLayout[i];
|
|
remainingSgCount /= sgLayout[i];
|
|
consumerIdx--;
|
|
}
|
|
}
|
|
|
|
// Second pass: Distribute remaining subgroups across reduction dimensions
|
|
for (int i = srcRank - 1; i >= 0; i--) {
|
|
if (llvm::is_contained(reductionDims, i)) {
|
|
sgLayout[i] =
|
|
std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
|
|
assert((srcShape[i] % sgLayout[i] == 0) &&
|
|
"source shape not divisible by sg_layout");
|
|
sgData[i] = srcShape[i] / sgLayout[i];
|
|
remainingSgCount /= sgLayout[i];
|
|
}
|
|
}
|
|
|
|
assert(remainingSgCount == 1 && "not all subgroups distributed");
|
|
srcLayout = xegpu::LayoutAttr::get(
|
|
context, toInt32Attr(sgLayout), toInt32Attr(sgData),
|
|
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
|
|
/*lane_data =*/nullptr, /*order =*/nullptr);
|
|
|
|
} else if (layoutKind == xegpu::LayoutKind::InstData) {
|
|
|
|
SmallVector<int64_t> instData(srcRank, 1);
|
|
instData[srcRank - 2] =
|
|
std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
|
|
instData[srcRank - 1] = subgroupSize;
|
|
srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
|
|
|
|
} else if (layoutKind == xegpu::LayoutKind::Lane) {
|
|
|
|
SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
|
|
laneLayout[srcRank - 1] = subgroupSize;
|
|
laneData[srcRank - 2] =
|
|
std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
|
|
srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
|
|
toInt32Attr(laneData),
|
|
consumerLayout.getOrder());
|
|
}
|
|
|
|
return xegpu::SliceAttr::get(context, srcLayout,
|
|
DenseI64ArrayAttr::get(context, reductionDims));
|
|
}
|
|
|
|
/// Sets up the result layout for a bitcast operation.
|
|
/// When casting to a smaller bitwidth, adjusts the layout dimensions (sgData,
|
|
/// instData, or laneData) by multiplying by the bitwidth ratio to ensure the
|
|
/// result layout can be correctly divided back to the source layout during
|
|
/// inference.
|
|
///
|
|
/// Examples:
|
|
/// 1. Casting f32 -> f16 (32-bit to 16-bit, bitWidthRatio = 2):
|
|
/// Consumer layout: instData=[1, 16], subgroupSize=16
|
|
/// Source shape: [8, 32]
|
|
/// Result layout: instData=[1, 32] (16 * 2)
|
|
/// The innermost dimension is multiplied by 2 to maintain consistency.
|
|
///
|
|
/// 2. Casting f32 -> i8 (32-bit to 8-bit, bitWidthRatio = 4):
|
|
/// Consumer instData=[1, 16], subgroupSize=16
|
|
/// Source shape: [4, 128]
|
|
/// adjust the instData from [1, 16] to [1, 16 * 4 = 64]
|
|
///
|
|
/// 3. Casting i8 -> i32 (8-bit to 32-bit, bitWidthRatio = 1/4):
|
|
/// Consumer layout: laneLayout=[1, 16], laneData=[1, 4]
|
|
/// No adjustment needed - returns consumer layout directly.
|
|
///
|
|
xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
|
|
xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
|
|
DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
|
|
|
|
int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
|
|
int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
|
|
|
|
ArrayRef<int64_t> srcShape = srcVecTy.getShape();
|
|
SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
|
|
SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
|
|
SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
|
|
assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
|
|
"laneData must be available for all dimensions");
|
|
size_t dim = srcShape.size() - 1;
|
|
int64_t sgDataValue = -1;
|
|
int64_t instDataValue = -1;
|
|
int64_t laneDataValue = -1;
|
|
const int subgroupSize = uArch->getSubgroupSize();
|
|
|
|
if (srcElemTyBitWidth > resElemTyBitWidth) {
|
|
// When casting to a smaller bitwidth, multiply the result layout
|
|
// accordingly to ensure it can be divided by the ratio back to the
|
|
// source layout.
|
|
int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
|
|
int innermostDimLaneLayout = subgroupSize;
|
|
if (layoutKind == xegpu::LayoutKind::Subgroup) {
|
|
sgDataValue = sgData[dim];
|
|
} else if (layoutKind == xegpu::LayoutKind::InstData) {
|
|
instDataValue = instData[dim];
|
|
// Adjust instDataValue so it still fits within an instruction after
|
|
// dividing by bitWidthRatio
|
|
while ((instDataValue <= srcShape[dim]) &&
|
|
(instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
|
|
instDataValue *= 2;
|
|
assert((srcShape[dim] % instDataValue) == 0 &&
|
|
"srcShape, instData, and lanelayout for innermost must be 2^n !");
|
|
} else if (layoutKind == xegpu::LayoutKind::Lane) {
|
|
laneDataValue = laneData[dim];
|
|
while ((laneDataValue <= srcShape[dim]) &&
|
|
(laneDataValue % bitWidthRatio != 0))
|
|
laneDataValue *= 2;
|
|
}
|
|
// Now set only instData and laneData, preserving sgData
|
|
xegpu::DistributeLayoutAttr resLayout;
|
|
resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
|
|
laneDataValue);
|
|
return resLayout;
|
|
}
|
|
return consumerLayout;
|
|
}
|
|
|
|
/// Sets up the result layout for an insert strided slice operation.
|
|
/// Creates a result layout based on the specified layout kind (InstData or
|
|
/// Lane).
|
|
xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
|
|
xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
|
|
VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
|
|
const xegpu::uArch::uArch *uArch) {
|
|
|
|
xegpu::DistributeLayoutAttr requiredResLayout;
|
|
SmallVector<int64_t> consumerInstData =
|
|
consumerLayout.getEffectiveInstDataAsInt();
|
|
SmallVector<int64_t> consumerLaneData =
|
|
consumerLayout.getEffectiveLaneDataAsInt();
|
|
SmallVector<int64_t> consumerLaneLayout =
|
|
consumerLayout.getEffectiveLaneLayoutAsInt();
|
|
ArrayRef<int64_t> srcShape = srcVectorTy.getShape();
|
|
int64_t instDataValue = -1;
|
|
int64_t laneDataValue = -1;
|
|
|
|
requiredResLayout = consumerLayout;
|
|
int srcRank = srcShape.size();
|
|
|
|
if (layoutKind == xegpu::LayoutKind::Subgroup) {
|
|
assert(true &&
|
|
"subgroup layout assignment not supported for insertStridedSlice.");
|
|
} else if (layoutKind == xegpu::LayoutKind::InstData) {
|
|
for (int dim = 0; dim < srcRank; dim++) {
|
|
instDataValue = std::min(srcShape[dim], consumerInstData[dim]);
|
|
requiredResLayout =
|
|
requiredResLayout.setDimData(dim, -1, instDataValue, -1);
|
|
}
|
|
} else if (layoutKind == xegpu::LayoutKind::Lane) {
|
|
for (int dim = 0; dim < srcRank; dim++) {
|
|
assert(srcShape[dim] % consumerLaneLayout[dim] == 0 &&
|
|
"srcShape must be divisible by laneLayout for all dimensions");
|
|
laneDataValue = std::min(srcShape[dim] / consumerLaneLayout[dim],
|
|
consumerLaneData[dim]);
|
|
|
|
requiredResLayout =
|
|
requiredResLayout.setDimData(dim, -1, -1, laneDataValue);
|
|
}
|
|
}
|
|
return requiredResLayout;
|
|
}
|
|
|
|
/// Sets up the anchor layout for load gather and load matrix operation.
|
|
/// load matrix lowers to load gather and 1d block load. All of them share the
|
|
/// same layout setup logic.
|
|
/// For Subgroup layout, uses the consumer layout directly.
|
|
/// non-chunked loads:
|
|
/// InstData = {1, ..., min(consumer, maxLaneLoadSize * subgroupSize)}
|
|
/// LaneLayout = {1, ..., subgroupSize}
|
|
/// lane_data = {1, ..., min(consumer, maxLaneLoadSize)}
|
|
/// chunked loads:
|
|
/// InstData = {subgroupSize, min(consumer, maxLaneLoadSize)}
|
|
/// LaneLayout = {subgroupSize, 1}
|
|
/// lane_data={1,min(consumer, maxLaneLoadSize)}
|
|
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
|
|
xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
|
|
xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
|
|
int maxChunkSize, ArrayRef<int64_t> resShape, int subgroupSize) {
|
|
|
|
if (layoutKind == xegpu::LayoutKind::Subgroup)
|
|
return consumerLayout;
|
|
|
|
SmallVector<int64_t> consumerInstData =
|
|
consumerLayout.getEffectiveInstDataAsInt();
|
|
SmallVector<int64_t> consumerLaneData =
|
|
consumerLayout.getEffectiveLaneDataAsInt();
|
|
|
|
SmallVector<int> instData(resShape.size(), 1);
|
|
SmallVector<int> laneLayout(resShape.size(), 1);
|
|
SmallVector<int> laneData(resShape.size(), 1);
|
|
|
|
if (!isChunkedLoad) {
|
|
if (layoutKind == xegpu::LayoutKind::InstData) {
|
|
instData.back() = std::min(static_cast<int>(consumerInstData.back()),
|
|
maxChunkSize * subgroupSize);
|
|
return xegpu::LayoutAttr::get(context, instData);
|
|
} else if (layoutKind == xegpu::LayoutKind::Lane) {
|
|
laneData.back() =
|
|
std::min(static_cast<int>(consumerLaneData.back()), maxChunkSize);
|
|
laneLayout.back() = std::min(static_cast<int64_t>(subgroupSize),
|
|
resShape.back() / laneData.back());
|
|
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
|
|
}
|
|
} else {
|
|
assert(resShape.size() == 2 && "Chunked Store must access 2D tensor tile.");
|
|
if (layoutKind == xegpu::LayoutKind::InstData) {
|
|
instData[0] = subgroupSize;
|
|
instData[1] =
|
|
std::min(static_cast<int>(consumerInstData[1]), maxChunkSize);
|
|
return xegpu::LayoutAttr::get(context, instData);
|
|
} else if (layoutKind == xegpu::LayoutKind::Lane) {
|
|
laneLayout[0] = subgroupSize;
|
|
laneData[1] =
|
|
std::min(static_cast<int>(consumerLaneData[1]), maxChunkSize);
|
|
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/// Sets up the anchor layout for a load gather operation.
|
|
xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
|
|
xegpu::LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
|
|
xegpu::DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
|
|
|
|
const int subgroupSize = uArch->getSubgroupSize();
|
|
ArrayRef<int64_t> resShape = resVecTy.getShape();
|
|
auto context = resVecTy.getContext();
|
|
auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
|
|
|
|
const auto *uArchInstruction =
|
|
dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
|
|
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
|
|
int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
|
|
|
|
return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
|
|
(chunkSize > 1), maxChunkSize, resShape,
|
|
subgroupSize);
|
|
}
|
|
|
|
/// Sets up the anchor layout for load matrix operation.
|
|
/// TODO: enhance load matrix to indicate lowering to chunked load or not.
|
|
xegpu::DistributeLayoutAttr
|
|
xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
|
|
VectorType resVecTy,
|
|
xegpu::DistributeLayoutAttr consumerLayout,
|
|
const xegpu::uArch::uArch *uArch) {
|
|
|
|
const int subgroupSize = uArch->getSubgroupSize();
|
|
ArrayRef<int64_t> resShape = resVecTy.getShape();
|
|
auto context = resVecTy.getContext();
|
|
auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
|
|
|
|
const auto *uArchInstruction =
|
|
dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
|
|
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
|
|
int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
|
|
return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
|
|
false, maxChunkSize, resShape,
|
|
subgroupSize);
|
|
}
|
|
|
|
/// Sets up the anchor layout for store scatter and store matrix operation.
|
|
/// store matrix lowers to store scatter and 1d block store. All of them share
|
|
/// the same layout setup logic. For Subgroup layout, not support yet.
|
|
/// non-chunked stores:
|
|
/// InstData = {1, ..., subgroupSize}
|
|
/// LaneLayout = {1, ..., subgroupSize}
|
|
/// lane_data = {1, ..., 1}
|
|
/// chunked stores:
|
|
/// InstData = {subgroupSize, min(srcVec, maxLaneStoreSize)}
|
|
/// LaneLayout = {subgroupSize, 1}
|
|
/// lane_data={1,min(srcVec, maxLaneStoreSize)}
|
|
static xegpu::DistributeLayoutAttr
|
|
setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind,
|
|
mlir::MLIRContext *context, bool isChunkedStore,
|
|
int maxChunkSize, ArrayRef<int64_t> srcShape,
|
|
int subgroupSize) {
|
|
|
|
int srcShapeSize = srcShape.size();
|
|
SmallVector<int> instData(srcShapeSize, 1);
|
|
SmallVector<int> laneLayout(srcShapeSize, 1);
|
|
SmallVector<int> laneData(srcShapeSize, 1);
|
|
|
|
if (layoutKind == xegpu::LayoutKind::Subgroup) {
|
|
assert(true &&
|
|
"subgroup layout assignment not supported for storeScatter.");
|
|
return nullptr;
|
|
}
|
|
|
|
if (!isChunkedStore) {
|
|
if (layoutKind == xegpu::LayoutKind::InstData) {
|
|
instData[srcShapeSize - 1] =
|
|
std::min(subgroupSize, static_cast<int>(srcShape.back()));
|
|
return xegpu::LayoutAttr::get(context, instData);
|
|
} else if (layoutKind == xegpu::LayoutKind::Lane) {
|
|
laneLayout[srcShapeSize - 1] =
|
|
std::min(subgroupSize, static_cast<int>(srcShape.back()));
|
|
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
|
|
}
|
|
} else {
|
|
assert(srcShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
|
|
if (layoutKind == xegpu::LayoutKind::InstData) {
|
|
instData[0] = subgroupSize;
|
|
instData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
|
|
return xegpu::LayoutAttr::get(context, instData);
|
|
} else if (layoutKind == xegpu::LayoutKind::Lane) {
|
|
laneLayout[0] = subgroupSize;
|
|
laneData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
|
|
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/// Sets up the anchor layout for a store scatter operation.
|
|
xegpu::DistributeLayoutAttr
|
|
xegpu::setupStoreScatterAnchorLayout(xegpu::LayoutKind layoutKind,
|
|
VectorType srcVecTy, int chunkSize,
|
|
const uArch::uArch *uArch) {
|
|
|
|
const int subgroupSize = uArch->getSubgroupSize();
|
|
ArrayRef<int64_t> srcShape = srcVecTy.getShape();
|
|
auto context = srcVecTy.getContext();
|
|
auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
|
|
|
|
const auto *uArchInstruction =
|
|
dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
|
|
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
|
|
int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
|
|
return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
|
|
maxChunkSize, srcShape, subgroupSize);
|
|
}
|
|
|
|
/// Sets up the anchor layout for a store matrix operation.
|
|
xegpu::DistributeLayoutAttr
|
|
xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
|
|
VectorType srcVecTy,
|
|
const xegpu::uArch::uArch *uArch) {
|
|
|
|
const int subgroupSize = uArch->getSubgroupSize();
|
|
ArrayRef<int64_t> srcShape = srcVecTy.getShape();
|
|
auto context = srcVecTy.getContext();
|
|
auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
|
|
|
|
const auto *uArchInstruction =
|
|
dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
|
|
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
|
|
int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
|
|
|
|
return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
|
|
srcShape, subgroupSize);
|
|
}
|
|
|
|
// This function returns the default lane layout for a given vector type.
|
|
// - `packingSize` means multiple consecutive elements can be accessed
|
|
// together as a single unit.
|
|
// - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
|
|
// 1x2xf16 w/o vnni).
|
|
template <typename RankedTy>
|
|
static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(
|
|
RankedTy ty, const xegpu::uArch::uArch *uArch,
|
|
std::optional<unsigned> packingSize = std::nullopt, bool vnni = false) {
|
|
// Expecting a 1D or 2D vector.
|
|
assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
|
|
"Expected 1D non-vnni or 2D vector.");
|
|
// Expecting int or float element type.
|
|
assert(ty.getElementType().isIntOrFloat() &&
|
|
"Expected int or float element type.");
|
|
|
|
auto context = ty.getContext();
|
|
auto rank = ty.getRank();
|
|
SmallVector<int> laneLayout(rank, 1);
|
|
SmallVector<int> laneData(rank, 1);
|
|
if (packingSize.has_value()) {
|
|
unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
|
|
int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
|
|
laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
|
|
}
|
|
laneLayout.back() = uArch->getSubgroupSize();
|
|
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
|
|
}
|
|
|
|
// This function returns all layouts for the given sgCount, whose sgData:
|
|
// 1. Evenly divides the wgShape.
|
|
// 2. Is a multiple of instData.
|
|
// Example:
|
|
// wgShape = [128, 64], instData = [8, 16], sgCount = 32
|
|
// Returns layouts:
|
|
// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
|
|
using LayoutRepresentation = std::pair<int64_t, int64_t>;
|
|
static SmallVector<LayoutRepresentation>
|
|
getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int64_t> instData,
|
|
int64_t sgCount) {
|
|
SmallVector<LayoutRepresentation> candidates;
|
|
for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
|
|
if (sgCount % sgLayout0)
|
|
continue;
|
|
int64_t sgLayout1 = sgCount / sgLayout0;
|
|
int64_t sgData0 = wgShape[0] / sgLayout0;
|
|
int64_t sgData1 = wgShape[1] / sgLayout1;
|
|
if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
|
|
(sgData0 % instData[0] || sgData1 % instData[1]))
|
|
continue;
|
|
candidates.emplace_back(sgLayout0, sgLayout1);
|
|
}
|
|
// Sort primarily by how balanced they are
|
|
// (i.e., minimize the absolute difference between the two dimensions), and
|
|
// secondarily by the first dimension in ascending order.
|
|
llvm::sort(candidates, [](const LayoutRepresentation &lhs,
|
|
const LayoutRepresentation &rhs) {
|
|
int diffLhs = std::abs(lhs.first - lhs.second);
|
|
int diffRhs = std::abs(rhs.first - rhs.second);
|
|
if (diffLhs != diffRhs)
|
|
return diffLhs < diffRhs;
|
|
return lhs.first < rhs.first;
|
|
});
|
|
return candidates;
|
|
}
|
|
|
|
/// Sets up the anchor layouts for dpas operands (A, B, and C/D).
|
|
/// The numSg and consumerLayout (optional) are only used by sg layout
|
|
/// creation.
|
|
std::optional<
|
|
std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
|
|
xegpu::DistributeLayoutAttr>>
|
|
xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
|
|
VectorType bTy, VectorType cdTy,
|
|
xegpu::DistributeLayoutAttr consumerLayout,
|
|
const xegpu::uArch::uArch *uArch, int numSg) {
|
|
auto context = aTy.getContext();
|
|
const auto *uArchInstruction =
|
|
dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
|
|
xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
|
|
|
|
auto getInstDataVectors = [&]()
|
|
-> std::optional<std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
|
|
SmallVector<int64_t>>> {
|
|
const int subgroupSize = uArch->getSubgroupSize();
|
|
const unsigned dataALen = aTy.getShape().front();
|
|
auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
|
|
const int maxALen =
|
|
xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
|
|
|
|
const unsigned dataBLen = bTy.getShape().back();
|
|
auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
|
|
const int maxBLen =
|
|
xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
|
|
|
|
auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
|
|
const int maxCLen =
|
|
xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedCLen));
|
|
if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
|
|
return std::nullopt;
|
|
|
|
SmallVector<int64_t> instDataA(aTy.getRank(), 1);
|
|
instDataA[aTy.getRank() - 2] = maxALen;
|
|
instDataA[aTy.getRank() - 1] = subgroupSize;
|
|
SmallVector<int64_t> instDataB(bTy.getRank(), 1);
|
|
instDataB[bTy.getRank() - 2] = subgroupSize;
|
|
instDataB[bTy.getRank() - 1] = maxBLen;
|
|
SmallVector<int64_t> instDataCD(cdTy.getRank(), 1);
|
|
instDataCD[cdTy.getRank() - 2] = maxALen;
|
|
instDataCD[cdTy.getRank() - 1] = maxCLen;
|
|
return std::make_tuple(instDataA, instDataB, instDataCD);
|
|
};
|
|
|
|
if (layoutKind == xegpu::LayoutKind::Subgroup) {
|
|
assert(numSg > 0 &&
|
|
"Number of subgroups must be provided for sg layout creation.");
|
|
auto instDataVecs = getInstDataVectors();
|
|
if (!instDataVecs)
|
|
return std::nullopt;
|
|
auto [instDataA, instDataB, instDataCD] = *instDataVecs;
|
|
assert(instDataA.size() == 2 && instDataB.size() == 2 &&
|
|
instDataCD.size() == 2 &&
|
|
"Sg layout creation expects valid 2D inst data");
|
|
|
|
std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
|
|
if (consumerLayout && consumerLayout.isForWorkgroup()) {
|
|
SmallVector<int64_t> sgLayoutD =
|
|
consumerLayout.getEffectiveSgLayoutAsInt();
|
|
consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
|
|
}
|
|
|
|
// Step 1. Get all valid layouts for A, B and C/D operands.
|
|
// Order them from most balanced to least balanced.
|
|
auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
|
|
auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
|
|
auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
|
|
if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
|
|
return std::nullopt;
|
|
|
|
// Step 2. If the consumer layout can be reused for all operands, that
|
|
// layout is chosen. Otherwise, pick the most balanced subgroup layout
|
|
// that is valid for A, B and C (if present) operands
|
|
llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
|
|
llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
|
|
layoutsCD.end());
|
|
std::optional<LayoutRepresentation> bestPick;
|
|
for (auto &sgLayout : layoutsB) {
|
|
if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
|
|
// Is in (A and B and CD) and matches consumer -> best pick
|
|
if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
|
|
bestPick = sgLayout;
|
|
break;
|
|
}
|
|
// Is in (A and B and CD) layoutsB is ordered from most
|
|
// balanced to least. So the first one we see is the most balanced
|
|
// one, remember it and later only update if there is one that matches
|
|
// the consumer.
|
|
if (!bestPick)
|
|
bestPick = sgLayout;
|
|
}
|
|
}
|
|
// Step 3. If there is no subgroup layout compatible with A, B and C (if
|
|
// present) operands, we fail.
|
|
if (!bestPick)
|
|
return std::nullopt;
|
|
SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
|
|
static_cast<int>(bestPick->second)};
|
|
SmallVector<int> sgDataA = {
|
|
static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
|
|
static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
|
|
SmallVector<int> sgDataB = {
|
|
static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
|
|
static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
|
|
SmallVector<int> sgDataCD = {
|
|
static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
|
|
static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
|
|
|
|
auto dpasALayout = xegpu::LayoutAttr::get(
|
|
context, DenseI32ArrayAttr::get(context, sgLayout),
|
|
DenseI32ArrayAttr::get(context, sgDataA),
|
|
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
|
|
/*lane_data =*/nullptr, /*order =*/nullptr);
|
|
|
|
auto dpasBLayout = xegpu::LayoutAttr::get(
|
|
context, DenseI32ArrayAttr::get(context, sgLayout),
|
|
DenseI32ArrayAttr::get(context, sgDataB),
|
|
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
|
|
/*lane_data =*/nullptr, /*order =*/nullptr);
|
|
|
|
auto dpasCDLayout = xegpu::LayoutAttr::get(
|
|
context, DenseI32ArrayAttr::get(context, sgLayout),
|
|
DenseI32ArrayAttr::get(context, sgDataCD),
|
|
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
|
|
/*lane_data =*/nullptr, /*order =*/nullptr);
|
|
return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
|
|
} else if (layoutKind == xegpu::LayoutKind::InstData) {
|
|
auto instDataVecs = getInstDataVectors();
|
|
if (!instDataVecs)
|
|
return std::nullopt;
|
|
auto [instDataA, instDataB, instDataCD] = *instDataVecs;
|
|
return std::make_tuple(
|
|
xegpu::LayoutAttr::get(
|
|
context, SmallVector<int>(instDataA.begin(), instDataA.end())),
|
|
xegpu::LayoutAttr::get(
|
|
context, SmallVector<int>(instDataB.begin(), instDataB.end())),
|
|
xegpu::LayoutAttr::get(
|
|
context, SmallVector<int>(instDataCD.begin(), instDataCD.end())));
|
|
} else if (layoutKind == xegpu::LayoutKind::Lane) {
|
|
auto aLayout = getDefaultLaneLayout2DBlockIo(
|
|
aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
|
|
auto bLayout = getDefaultLaneLayout2DBlockIo(
|
|
bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
|
|
auto cdLayout = getDefaultLaneLayout2DBlockIo(
|
|
cdTy, uArch, uArchInstruction->getPackedFormatBitSizeB());
|
|
return std::make_tuple(aLayout, bLayout, cdLayout);
|
|
}
|
|
return std::nullopt;
|
|
}
|
|
|
|
/// Returns the layout that the owner of `operand` expects for that operand.
///
/// When the owner has exactly one result and that result is a vector, the
/// result's layout is used to derive the operand layout for the ops handled
/// below (broadcast, multi-dim reduction, bitcast, shape cast, insert strided
/// slice, transpose and single-result elementwise ops). For those ops an
/// empty attribute is returned when no result layout is available. In every
/// other case the layout currently attached to the operand's value is
/// returned.
xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
  Operation *owner = operand.getOwner();
  const unsigned operandIdx = operand.getOperandNumber();

  // Layout of the owner's single vector result; stays empty otherwise.
  xegpu::DistributeLayoutAttr resultLayout;
  if (owner->getNumResults() == 1) {
    OpResult onlyResult = owner->getResult(0);
    if (isa<VectorType>(onlyResult.getType()))
      resultLayout = xegpu::getDistributeLayoutAttr(onlyResult);
  }

  // vector::BroadcastOp: the source layout follows from the result layout
  // and the two shapes. A non-vector source gets no layout.
  if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(owner)) {
    auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType());
    if (!resultLayout || !srcVecTy)
      return xegpu::DistributeLayoutAttr();
    return xegpu::inferBroadcastSourceLayout(
        resultLayout, broadcastOp.getResultVectorType().getShape(),
        srcVecTy.getShape());
  }

  // vector::MultiDimReductionOp: operand 0 (source) is inferred through the
  // reduction dims; operand 1 (acc) shares the result layout. Any other
  // operand index falls through to the checks below.
  if (auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(owner)) {
    if (!resultLayout)
      return xegpu::DistributeLayoutAttr();
    switch (operandIdx) {
    case 0: {
      SmallVector<int64_t> reductionDims(reductionOp.getReductionDims());
      return xegpu::inferMultiReductionSourceLayout(resultLayout,
                                                    reductionDims);
    }
    case 1:
      return resultLayout;
    default:
      break;
    }
  }

  // vector::BitCastOp: the source layout depends on the ratio of the result
  // and source element bit widths.
  if (auto bitcastOp = dyn_cast<vector::BitCastOp>(owner)) {
    if (!resultLayout)
      return xegpu::DistributeLayoutAttr();
    int resElemBitWidth = bitcastOp.getResultVectorType()
                              .getElementType()
                              .getIntOrFloatBitWidth();
    int srcElemBitWidth = bitcastOp.getSourceVectorType()
                              .getElementType()
                              .getIntOrFloatBitWidth();
    return xegpu::inferBitCastSourceLayout(resultLayout, resElemBitWidth,
                                           srcElemBitWidth);
  }

  // vector::ShapeCastOp: inferred from the result layout and both shapes.
  if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(owner)) {
    if (!resultLayout)
      return xegpu::DistributeLayoutAttr();
    return xegpu::inferShapeCastSourceLayout(
        resultLayout, shapeCastOp.getResultVectorType().getShape(),
        shapeCastOp.getSourceVectorType().getShape());
  }

  // vector::InsertStridedSliceOp: operand 0 (the inserted slice) is inferred
  // from the result layout; operand 1 (dest) must match the result layout.
  // Any other operand index falls through to the checks below.
  if (auto insertSliceOp = dyn_cast<vector::InsertStridedSliceOp>(owner)) {
    if (!resultLayout)
      return xegpu::DistributeLayoutAttr();
    switch (operandIdx) {
    case 0:
      return xegpu::inferInsertStridedSliceSourceLayout(
          resultLayout, insertSliceOp.getDestVectorType().getShape(),
          insertSliceOp.getSourceVectorType().getShape());
    case 1:
      return resultLayout;
    default:
      break;
    }
  }

  // vector::TransposeOp: inferred from the result layout and permutation.
  if (auto transposeOp = dyn_cast<vector::TransposeOp>(owner)) {
    if (!resultLayout)
      return xegpu::DistributeLayoutAttr();
    return xegpu::inferTransposeSourceLayout(resultLayout,
                                             transposeOp.getPermutation());
  }

  // Elementwise ops with one result: every operand shares the result layout
  // (which is empty when the result carries none).
  if (OpTrait::hasElementwiseMappableTraits(owner) &&
      owner->getNumResults() == 1)
    return resultLayout;

  // TODO: Handle more cases as needed here.
  // By default, assume no layout conflict and return the current layout of
  // the operand.
  return xegpu::getDistributeLayoutAttr(operand.get());
}
|