[MLIR][XeGPU] Enhance XeGPU lane layout to support "wrap-around" distribution (#186958)

This PR extends XeGPU lane layout to support wrap-around distribution,
enabling replication of lane-level tensor tiles across all lanes when
the tile size matches lane_data along a given dimension. Previously,
distribution required the tile size to be a multiple of the number of
lanes × lane_data along each dimension so that it could be partitioned
evenly.

This PR also refactors layout attribute interface functions:

computeDistributedShape() computes the distributed vector shape and is
shared by workgroup-to-subgroup and subgroup-to-lane distribution, which
follow the same distribution rule (even or wrap-around).

computeStaticDistributedCoords() computes compile-time distributed
coordinates of sub-tiles per subgroup/lane. It is the compile-time
counterpart of computeDistributedCoords() and is used by
isCompatibleWith().
This commit is contained in:
Jianhui Li 2026-03-20 17:42:25 -07:00 committed by GitHub
parent 0ae9aaf539
commit 367da15a11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 363 additions and 119 deletions

View File

@ -265,6 +265,60 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"FailureOr<SmallVector<SmallVector<Value>>>",
"computeDistributedCoords",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
InterfaceMethod<[{Statically computes multidimensional coordinates for all dist units
assigned to a compute unit identified by `linearId`. This is the
compile-time counterpart of `computeDistributedCoords`: it performs
the same delinearization and round-robin enumeration but operates
entirely on static integer values. Returns a list of coordinate
vectors, one per dist unit.}],
/*retTy=*/"SmallVector<SmallVector<int64_t>>",
/*methodName=*/"computeStaticDistributedCoords",
/*args=*/(ins "int64_t":$linearId, "ArrayRef<int64_t>":$shape)>,
InterfaceMethod<[{Computes the distributed shape for each compute unit by dividing each
dimension of `shape` by the corresponding layout factor (sg_layout or
lane_layout).
The distributed shape represents the per-compute-unit tile. Each
distribution unit is defined as the combination of layout factors and
per-unit data (`subshape`, e.g., sg_data or lane_data). When `shape`
spans multiple distribution units, the distributed shape may contain
multiple such units.
For wrap-around dimensions where the division is uneven, the tensor tile
is broadcast to all subgroups/lanes.}],
/*retTy=*/"FailureOr<SmallVector<int64_t>>",
/*methodName=*/"computeDistributedShape",
/*args=*/(ins "SmallVector<int64_t>":$shape),
/*methodBody=*/[{
// Select the layout factors and per-unit data for the level this layout
// describes: sg_layout/sg_data for a workgroup layout, lane_layout/lane_data
// for a subgroup layout. Any other layout yields failure().
SmallVector<int64_t> layout;
SmallVector<int64_t> subShape;
if ($_self.isForWorkgroup()) {
layout = $_self.getEffectiveSgLayoutAsInt();
subShape = $_self.getEffectiveSgDataAsInt();
} else if ($_self.isForSubgroup()) {
layout = $_self.getEffectiveLaneLayoutAsInt();
subShape = $_self.getEffectiveLaneDataAsInt();
} else {
return failure();
}
// NOTE(review): the preconditions in this body are asserts only. In a build
// without assertions a violating input still produces a result rather than
// failure() -- confirm that callers never pass such shapes.
assert(
!subShape.empty() &&
"sgdata or lanedata cannot be empty for distributed shape computation");
// NOTE(review): `layout` and `subShape` are indexed below with the indices of
// `shape`, which presumably requires all three to have the same rank --
// confirm at the call sites.
SmallVector<int64_t> distributedShape(shape.size());
for (auto [i, dim] : llvm::enumerate(shape)) {
// Extent covered by one full distribution unit along dimension i.
int64_t distriUnit = layout[i]*subShape[i];
if ((dim % distriUnit) == 0) {
// Evenly divisible case, divide the dimension by the layout factor.
distributedShape[i] = dim / layout[i];
assert((distributedShape[i] % subShape[i] == 0) &&
"Even distribution: sgdata or lanedata must divide the distributed dimension");
} else {
// wrap around case, the dimension size must be equal to subShape value
assert(dim == subShape[i] &&
"Wrap-around distribution: sgdata or lanedata must be same as tensor tile shape");
// The dimension is not divided: the compute unit keeps the full extent.
distributedShape[i] = dim;
}
}
return distributedShape;
}]>,
InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
/*retTy=*/"bool",
/*methodName=*/"isSliceOf",
@ -277,28 +331,12 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"ArrayRef<int64_t>": $perm,
"xegpu::LayoutKind": $kind)>,
InterfaceMethod</*desc=*/[{Check if this layout is compatible with another layout
at a specific level of the layout hierarchy. Unlike isEqualTo,
this compares only the effective (non-sliced) fields at the
requested level.}],
at a specific level of the layout hierarchy regarding a given shape. }],
/*retTy=*/"bool",
/*methodName=*/"isCompatibleWith",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other,
"xegpu::LayoutKind": $level),
/*methodBody=*/[{
if (!other)
return false;
switch (level) {
case xegpu::LayoutKind::Subgroup:
return $_self.getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
$_self.getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt();
case xegpu::LayoutKind::InstData:
return $_self.getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt();
case xegpu::LayoutKind::Lane:
return $_self.getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
$_self.getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
}
return false;
}]>,
"SmallVector<int64_t>": $shape,
"xegpu::LayoutKind": $level)>,
InterfaceMethod</*desc=*/[{Check if this layout is equal to another layout.
For LayoutAttr, this compares all fields.
For SliceAttr, this requires the same parent and same sliced dims.}],
@ -561,12 +599,24 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
FailureOr<SmallVector<SmallVector<Value>>>
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
/// Statically computes multidimensional coordinates for all dist units
/// assigned to a compute unit identified by `linearId`. This is the
/// compile-time counterpart of `computeDistributedCoords`.
SmallVector<SmallVector<int64_t>>
computeStaticDistributedCoords(int64_t linearId, ArrayRef<int64_t> shape);
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
/// Check if this layout is equal to another layout.
bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
/// Check if this layout is compatible with another layout
/// at a specific level of the layout hierarchy regarding a given shape.
bool isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
SmallVector<int64_t> shape,
xegpu::LayoutKind level);
/// Check if this layout is a transpose of another layout.
bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
}];
@ -778,16 +828,27 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// assigned to a subgroup identified by linearId. The shape parameter
/// represents the workgroup-level problem size. Each subgroup may access
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
/// Statically computes multidimensional coordinates for all dist units
/// assigned to a compute unit identified by `linearId`. This is the
/// compile-time counterpart of `computeDistributedCoords`.
SmallVector<SmallVector<int64_t>>
computeStaticDistributedCoords(int64_t linearId, ArrayRef<int64_t> shape);
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
/// Check if this layout is equal to another layout.
bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
/// Check if this layout is compatible with another layout
/// at a specific level of the layout hierarchy regarding a given shape.
bool isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
SmallVector<int64_t> shape,
xegpu::LayoutKind level);
/// Check if this layout is a transpose of another layout.
bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);

View File

@ -96,6 +96,31 @@ genCoordinates(OpBuilder &builder, Location loc,
return coordinates;
}
/// Computes, on static integers, the coordinates of every distribution unit
/// owned by one compute unit.
///
/// \param canonicalIds per-dimension (already delinearized) ID of the compute
///        unit.
/// \param layout       per-dimension layout factor.
/// \param subShape     per-dimension data shape handled by one compute unit.
/// \param shape        shape being distributed.
/// \returns one coordinate vector (of rank `shape.size()`) per distribution
///          unit enumerated over `shape`.
///
/// `canonicalIds`, `layout` and `subShape` are indexed with the dimension
/// indices of `shape`, so each must provide at least `shape.size()` entries.
static SmallVector<SmallVector<int64_t>> genStaticCoordinates(
    llvm::ArrayRef<int64_t> canonicalIds, llvm::ArrayRef<int64_t> layout,
    llvm::ArrayRef<int64_t> subShape, llvm::ArrayRef<int64_t> shape) {
  const size_t rank = shape.size();
  assert(canonicalIds.size() >= rank && layout.size() >= rank &&
         subShape.size() >= rank &&
         "ids, layout and subShape must cover every dimension of shape");

  // distUnitShape: extent of one distribution unit, clamped to `shape` so a
  // unit never exceeds the tensor (the wrap-around case).
  // localOffset: offset of this compute unit's data within a distribution
  // unit.
  SmallVector<int64_t> distUnitShape(rank);
  SmallVector<int64_t> localOffset(rank);
  for (size_t i = 0; i < rank; ++i) {
    distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
    localOffset[i] = canonicalIds[i] * subShape[i];
  }

  // Enumerate all distribution units and compute coordinates. The modulo
  // wraps offsets that fall outside `shape` back into range.
  SmallVector<SmallVector<int64_t>> coordinates;
  for (SmallVector<int64_t> unitOffs :
       StaticTileOffsetRange(shape, distUnitShape)) {
    SmallVector<int64_t> coord(rank);
    for (size_t i = 0; i < rank; ++i)
      coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
    coordinates.push_back(std::move(coord));
  }
  return coordinates;
}
// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
@ -160,7 +185,7 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
// check LaneLayout and LaneData
auto maybeLaneShape =
tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
attr.getEffectiveLaneDataAsInt(), false);
attr.getEffectiveLaneDataAsInt());
return maybeLaneShape.has_value();
}
@ -238,25 +263,17 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
<< lane_layout.size();
}
// sg_data is optional for Workgroup layout, but its presence requires
// sg_layout.
if (sg_data) {
if (!sg_layout)
return emitError() << "expected sg_layout being used with sg_data";
if (sg_data.size() != sg_layout.size())
return emitError()
<< "expected sg_data and sg_layout to have the same rank";
}
if ((sg_layout && !sg_data) || (!sg_layout && sg_data))
return emitError() << "sg_layout and sg_data must be used together";
if (sg_layout && sg_data && sg_layout.size() != sg_data.size())
return emitError()
<< "expected sg_data and sg_layout to have the same rank";
// lane_data is optional for Subgroup layout, but its presence requires
// lane_layout.
if (lane_data) {
if (!lane_layout)
return emitError() << "expected lane_layout being used with lane_data";
if (lane_data.size() != lane_layout.size())
return emitError()
<< "expected lane_data and lane_layout to have the same rank";
}
if ((lane_layout && !lane_data) || (!lane_layout && lane_data))
return emitError() << "lane_layout and lane_data must be used together";
if (lane_layout && lane_data && lane_layout.size() != lane_data.size())
return emitError()
<< "expected lane_data and lane_layout to have the same rank";
if (order) {
if (!sg_layout && !lane_layout)
@ -373,12 +390,8 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
} else {
return failure();
}
if (subShape.empty()) {
if (auto derivedShape = computeShapeRatio(shape, layout))
subShape = derivedShape.value();
else
return failure();
}
assert(!subShape.empty() && "sgdata or lanedata cannot be empty for "
"distributed coordinates computation");
// delinearize Ids
auto maybeIds = delinearizeId(builder, loc, linearId);
@ -396,6 +409,42 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
return *this == dyn_cast<xegpu::LayoutAttr>(other);
}
/// LayoutAttr implementation of
/// DistributeLayoutAttr::computeStaticDistributedCoords.
///
/// Picks sg_layout/sg_data for a workgroup layout or lane_layout/lane_data for
/// a subgroup layout, delinearizes `linearId` following the effective order,
/// and returns the coordinates produced by genStaticCoordinates for `shape`.
SmallVector<SmallVector<int64_t>>
LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
                                           ArrayRef<int64_t> shape) {
  SmallVector<int64_t> unitLayout;
  SmallVector<int64_t> unitData;
  if (isForWorkgroup()) {
    unitLayout = getEffectiveSgLayoutAsInt();
    unitData = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    unitLayout = getEffectiveLaneLayoutAsInt();
    unitData = getEffectiveLaneDataAsInt();
    // When inst_data is present it replaces lane_data as the per-unit shape
    // and the ID is pinned to 0.
    SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
    if (!instData.empty()) {
      linearId = 0;
      unitData = instData;
    }
  }
  assert(!unitData.empty() && "sgdata or lanedata cannot be empty");

  // Peel one per-dimension ID off the linear ID for each entry of the order
  // attribute, in the sequence the order lists the dimensions.
  // NOTE(review): assumes the effective order has one entry per layout
  // dimension -- confirm for layouts that carry only inst_data.
  SmallVector<int64_t> ids(unitLayout.size());
  int64_t rest = linearId;
  for (int64_t dim : getEffectiveOrderAsInt()) {
    const int64_t extent = unitLayout[dim];
    ids[dim] = rest % extent;
    rest /= extent;
  }

  return genStaticCoordinates(ids, unitLayout, unitData, shape);
}
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr
LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
@ -743,6 +792,46 @@ bool LayoutAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
return false;
}
/// Returns whether this layout and `other` are compatible at `level` for the
/// given `shape`.
///
/// - InstData: the effective inst_data of both layouts must be equal.
/// - Subgroup / Lane with equal effective order: the effective layout and data
///   of that level must both be equal.
/// - Subgroup / Lane with different effective order: both layouts must yield
///   the same static distributed coordinates for every ID in
///   [0, product of this layout's sg_layout / lane_layout).
///
/// A null `other` is never compatible.
///
/// NOTE(review): with equal order and differing layout/data this returns
/// false, whereas SliceAttr::isCompatibleWith falls through to the coordinate
/// comparison in that situation -- confirm which behaviour is intended.
bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
                                  SmallVector<int64_t> shape,
                                  xegpu::LayoutKind level) {
  if (!other)
    return false;

  const bool sameOrder =
      getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt();

  // True when both layouts assign identical coordinates to IDs 0..numIds-1.
  auto haveSameCoords = [&](int64_t numIds) {
    for (int64_t id = 0; id < numIds; ++id) {
      auto mine = computeStaticDistributedCoords(id, shape);
      auto theirs = other.computeStaticDistributedCoords(id, shape);
      if (mine != theirs)
        return false;
    }
    return true;
  };

  switch (level) {
  case xegpu::LayoutKind::Subgroup:
    if (sameOrder)
      return getEffectiveSgLayoutAsInt() ==
                 other.getEffectiveSgLayoutAsInt() &&
             getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt();
    return haveSameCoords(computeProduct(getEffectiveSgLayoutAsInt()));
  case xegpu::LayoutKind::InstData:
    return getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt();
  case xegpu::LayoutKind::Lane:
    if (sameOrder)
      return getEffectiveLaneLayoutAsInt() ==
                 other.getEffectiveLaneLayoutAsInt() &&
             getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
    return haveSameCoords(computeProduct(getEffectiveLaneLayoutAsInt()));
  }
  return true;
}
//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
@ -819,12 +908,8 @@ SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
return failure();
}
if (subShape.empty()) {
if (auto derivedShape = computeShapeRatio(shape, layout))
subShape = derivedShape.value();
else
return failure();
}
if (subShape.empty())
return failure();
// delinearize Ids
auto maybeIds = delinearizeId(builder, loc, linearId);
@ -834,10 +919,65 @@ SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
// The effective sgIds for offsets computing correspond
// to the dims that are not sliced.
ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
SmallVector<Value> sgIds =
SmallVector<Value> canonicalIds =
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
}
/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
/// compute multi-dimensional offsets for a given linear ID when distributed by
/// SliceAttr. The ID is delinearized with the layout and order of the parent
/// LayoutAttr of the flattened slice; only the IDs of the non-sliced
/// dimensions are then used for coordinate computation.
SmallVector<SmallVector<int64_t>>
SliceAttr::computeStaticDistributedCoords(int64_t linearId,
                                          ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");

  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  SmallVector<int64_t> instData;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    instData = getEffectiveInstDataAsInt();
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  }
  // When inst_data is present it replaces lane_data as the per-unit shape and
  // the ID is pinned to 0.
  if (!instData.empty()) {
    linearId = 0;
    subShape = instData;
  }
  assert(!subShape.empty() && "sgdata or lanedata cannot be empty");

  // Delinearize the ID using the parent layout (same as the IR version).
  // The parent is dereferenced unconditionally below, so use cast<> (which
  // asserts on a type mismatch) rather than an unchecked dyn_cast<>.
  SliceAttr flattened = flatten();
  auto parent = cast<LayoutAttr>(flattened.getParent());
  SmallVector<int64_t> parentLayoutVec;
  if (parent.isForWorkgroup())
    parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
  else
    parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();

  SmallVector<int64_t> order = parent.getEffectiveOrderAsInt();
  SmallVector<int64_t> allIds(parentLayoutVec.size());
  int64_t remaining = linearId;
  for (int64_t dimIdx : order) {
    allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
    remaining /= parentLayoutVec[dimIdx];
  }

  // The effective IDs for coordinate computation correspond
  // to the dims that are not sliced.
  ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
  SmallVector<int64_t> canonicalIds =
      XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);

  return genStaticCoordinates(canonicalIds, layout, subShape, shape);
}
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
@ -871,6 +1011,50 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
(flattenedThis.getDims() == flattenedOther.getDims()));
}
/// Returns whether this slice layout and `other` are compatible at `level`
/// for the given `shape`.
///
/// - InstData: the effective inst_data of both layouts must be equal.
/// - Subgroup / Lane: compatible if the effective order, layout and data of
///   that level are all equal; otherwise both layouts must yield the same
///   static distributed coordinates for every ID in [0, product of the parent
///   layout's sg_layout / lane_layout).
///
/// A null `other` is never compatible.
bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
                                 SmallVector<int64_t> shape,
                                 xegpu::LayoutKind level) {
  if (!other)
    return false;

  if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
    // short cut when order is the same, no need to compute coords and compare
    if (level == xegpu::LayoutKind::Subgroup &&
        getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
        getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
      return true;
    if (level == xegpu::LayoutKind::Lane &&
        getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
        getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
      return true;
  }

  // inst_data needs no ID enumeration, so decide it before looking up the
  // parent layout.
  if (level == xegpu::LayoutKind::InstData)
    return getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt();

  auto compareCoordsForAllIds = [&](int64_t size) {
    for (int64_t id : llvm::seq<int64_t>(0, size)) {
      auto coords = computeStaticDistributedCoords(id, shape);
      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
      if (coords != otherCoords)
        return false;
    }
    return true;
  };

  // IDs range over the parent layout of the flattened slice. The parent is
  // dereferenced unconditionally, so use cast<> (which asserts on a type
  // mismatch) rather than an unchecked dyn_cast<>.
  auto parent = cast<LayoutAttr>(flatten().getParent());
  if (level == xegpu::LayoutKind::Subgroup)
    return compareCoordsForAllIds(
        computeProduct(parent.getEffectiveSgLayoutAsInt()));
  if (level == xegpu::LayoutKind::Lane)
    return compareCoordsForAllIds(
        computeProduct(parent.getEffectiveLaneLayoutAsInt()));
  return true;
}
xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
if (sliceDimsToDrop.empty())
return *this;

View File

@ -677,18 +677,6 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
auto srcShape = sourceTy.getShape();
auto resShape = resultTy.getShape();
size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
if (dimDiff == 0) {
Operation *srcOp = broadcast.getSource().getDefiningOp();
if (!srcOp)
return;
[[maybe_unused]] bool hasUnitDim =
llvm::any_of(srcShape, [](int64_t dim) { return dim == 1; });
assert(
hasUnitDim && isa<vector::ShapeCastOp>(srcOp) &&
"When broadcasting from unit-dim, the producer op must be shape_cast!");
}
auto resultLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());

View File

@ -838,8 +838,11 @@ struct SgToWiConvertLayout
ConversionPatternRewriter &rewriter) const override {
auto inputLayout = op.getInputLayoutAttr();
auto targetLayout = op.getTargetLayoutAttr();
auto resShape = cast<VectorType>(op.getResult().getType()).getShape();
SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
if (!inputLayout.isCompatibleWith(targetLayout, xegpu::LayoutKind::Lane)) {
if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
xegpu::LayoutKind::Lane)) {
return rewriter.notifyMatchFailure(
op, "lowering incompatible convert_layout not yet supported");
}

View File

@ -2083,11 +2083,14 @@ struct ConvertLayoutDistribution
PatternRewriter &rewriter) const override {
auto inputLayout = op.getInputLayoutAttr();
auto targetLayout = op.getTargetLayoutAttr();
auto resShape = cast<VectorType>(op.getResult().getType()).getShape();
if (!inputLayout || !targetLayout)
return rewriter.notifyMatchFailure(op, "missing layout attributes");
if (!inputLayout.isCompatibleWith(targetLayout, xegpu::LayoutKind::Lane)) {
SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
xegpu::LayoutKind::Lane)) {
return rewriter.notifyMatchFailure(
op, "lowering incompatible convert_layout not yet supported");
}

View File

@ -52,21 +52,13 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
xegpu::DistributeLayoutAttr layout) {
int count = 1;
SmallVector<int64_t> sgShape(shape);
if (layout && layout.isForWorkgroup()) {
SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
if (!layout.getEffectiveSgDataAsInt().empty())
sgShape = layout.getEffectiveSgDataAsInt();
else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
sgShape = *maybeDerivedSgData;
SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
// Clamp distUnit to the original shape to handle cases where data is
// shared among subgroups, which may cause distUnit to exceed the original
// shape.
for (size_t i = 0; i < distUnit.size(); ++i)
distUnit[i] = std::min(shape[i], distUnit[i]);
count = computeProduct(shape) / computeProduct(distUnit);
}
return std::make_pair(sgShape, count);
auto distributedShape = layout.computeDistributedShape(
SmallVector<int64_t>(shape.begin(), shape.end()));
if (failed(distributedShape))
return std::make_pair(sgShape, count);
auto sgData = layout.getEffectiveSgDataAsInt();
count = computeProduct(distributedShape.value()) / computeProduct(sgData);
return std::make_pair(sgData, count);
}
/// Utility helper for deriving a list of offsets for each sub-TensorDescs
@ -626,7 +618,8 @@ struct WgToSgConvertLayoutOp
SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
// Fast path: if the input and target layouts are compatible at the subgroup
// level for this shape, no SLM is needed
if (inputLayout.isCompatibleWith(targetLayout,
SmallVector<int64_t> wgShapeVec(wgShape.begin(), wgShape.end());
if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
xegpu::LayoutKind::Subgroup)) {
inputLayout = inputLayout.dropSgLayoutAndData();
targetLayout = targetLayout.dropSgLayoutAndData();
@ -1632,16 +1625,20 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
converter.addConversion(
[&](xegpu::TensorDescType type,
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
xegpu::LayoutAttr layout = type.getLayoutAttr();
// Only convert WG-level tensor descs. SG-level or layout-less types
// are already legal and should pass through unchanged.
if (!layout || !layout.isForWorkgroup())
return std::nullopt;
Type elemTy = type.getElementType();
ArrayRef<int64_t> shape = type.getShape();
int count;
SmallVector<int64_t> subShape;
xegpu::LayoutAttr layout = type.getLayoutAttr();
std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
if (layout)
layout = layout.dropSgLayoutAndData();
layout = layout.dropSgLayoutAndData();
auto newTy = xegpu::TensorDescType::get(
type.getContext(), subShape, elemTy, type.getEncoding(), layout);

View File

@ -116,6 +116,9 @@ xegpu::getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
effectiveLaneLayout.size() &&
"Rank of the original vector type should be greater or equal to the "
"size of the lane layout to distribute the vector type.");
// TODO: replace the implementation with
// auto distributedShape = layout.computeDistributedShape(
// SmallVector<int64_t>(originalType.getShape()));
SmallVector<int64_t> distributedShape(originalType.getShape());
// Only distribute the last `laneLayout.size()` dimensions. The remaining
// dimensions are not distributed.

View File

@ -657,14 +657,6 @@ func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
return
}
// -----
func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
// expected-error@+1 {{cannot distribute [4, 8] using #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>}}
!xegpu.tensor_desc<4x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>>
return
}
// -----
func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
@ -766,8 +758,8 @@ func.func @tensor_desc_invalid_sg_data(%src: ui64, %offsets: vector<16xindex>) {
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
!xegpu.tensor_desc<16x2xf32,
#xegpu.scatter_tdesc_attr<chunk_size = 2>,
// expected-error@+1 {{expected sg_layout being used with sg_data}}
#xegpu.layout<sg_data = [16, 2], lane_layout = [8, 1], lane_data = [1, 2]>>
// expected-error@+1 {{sg_layout and sg_data must be used together}}
#xegpu.layout<sg_layout = [2, 1], lane_layout = [8, 1], lane_data = [1, 2]>>
return
}
@ -776,8 +768,8 @@ func.func @tensor_desc_rank_mismatch(%src: ui64, %offsets: vector<16xindex>) {
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
!xegpu.tensor_desc<16x2xf32,
#xegpu.scatter_tdesc_attr<chunk_size = 2>,
// expected-error@+1 {{expected lane_layout being used with lane_data}}
#xegpu.layout<inst_data = [16, 2], lane_data = [1, 2]>>
// expected-error@+1 {{lane_layout and lane_data must be used together}}
#xegpu.layout<inst_data = [16, 2], lane_layout = [16, 1]>>
return
}

View File

@ -27,6 +27,16 @@ gpu.func @create_nd_tdesc_subgroup_3(%src: memref<128x128xf32>) {
gpu.return
}
// -----
// CHECK: func.func @create_nd_tdesc_wrap_around_layout(%[[arg0:.*]]: memref<24x32xf32>) {
func.func @create_nd_tdesc_wrap_around_layout(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<4x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>>
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<4x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>>
return
}
// CHECK: gpu.func @create_nd_tdesc_wg_1(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @create_nd_tdesc_wg_1(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [3, 2], sg_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>>

View File

@ -853,7 +853,9 @@ gpu.module @test {
// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf16> to vector<16xf16>
// CHECK-NEXT: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16>
// CHECK-NEXT: vector.broadcast %[[SHAPECAST]]
// CHECK-NEXT: %[[EXP:.*]] = math.exp %[[SHAPECAST]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16>
// CHECK-NEXT: vector.broadcast %[[EXP]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
func.func @vector_broadcast_2d_to_2d_along_column(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
@ -862,8 +864,9 @@ func.func @vector_broadcast_2d_to_2d_along_column(%arg0: !xegpu.tensor_desc<16x1
%3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16>
%5 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16>
%6 = vector.broadcast %5 : vector<16x1xf16> to vector<16x16xf16>
xegpu.store_nd %6, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
%6 = math.exp %5: vector<16x1xf16>
%7 = vector.broadcast %6 : vector<16x1xf16> to vector<16x16xf16>
xegpu.store_nd %7, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
return
}
}

View File

@ -104,28 +104,28 @@ gpu.module @test_distribution {
// CHECK-LABEL: dpas_no_sg_data
gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
// CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
%tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1],
order = [1, 0]>>
%load_a = xegpu.load_nd %tdesc_a[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
%load_a = xegpu.load_nd %tdesc_a[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1],
order = [1, 0]>}
: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1],
order = [1, 0]>>
-> vector<128x128xf16>
%tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1],
order = [1, 0]>>
%load_b = xegpu.load_nd %tdesc_b[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]> }
: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
%load_b = xegpu.load_nd %tdesc_b[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]> }
: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1],
order = [1, 0]>>
-> vector<128x128xf16>
%dpas = xegpu.dpas %load_a, %load_b
{layout_a = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
{layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1],
order = [1, 0]>,
layout_b = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1],
order = [1, 0]>,
layout_cd = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}

View File

@ -107,22 +107,22 @@ gpu.module @test_1_1_assignment {
// CHECK-LABEL: dpas_no_sg_data
gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1],
order = [1, 0]>>
%load_a = xegpu.load_nd %tdesc_a {layout = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
%load_a = xegpu.load_nd %tdesc_a {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1],
order = [1, 0]>>
-> vector<128x128xf16>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
-> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1],
order = [1, 0]>>
%load_b = xegpu.load_nd %tdesc_b {layout = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>} : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
%load_b = xegpu.load_nd %tdesc_b {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>} : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1],
order = [1, 0]>>
-> vector<128x128xf16>
// CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
%dpas = xegpu.dpas %load_a, %load_b
{layout_a = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>,
layout_b = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>,
layout_cd = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
{layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>,
layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>,
layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}