[MLIR][XeGPU] Refactor layout propagation utilities (#179016)

This PR refactors layout propagation into two distinct components:
result/anchor layout setup and source layout inference from the result.

For operations that require a specific result layout due to semantic or
hardware constraints, the propagation logic explicitly sets up the
result or anchor layout. Otherwise, it infers the source layout from the
backward-propagated consumer layout.

The result or anchor layout may differ from the backward-propagated
consumer layout; any such discrepancies are resolved via the existing
layout-conflict mechanism.

**This PR introduces the following utility functions:**

Source layout inference:

> inferBroadcastSourceLayout()
> inferMultiReductionSourceLayout()
> inferBitCastSourceLayout()
> inferShapeCastSourceLayout()
> inferInsertStridedSliceSourceLayout()

Result / anchor layout setup:

> setupMultiReductionResultLayout()
> setupBitCastResultLayout()
> setupInsertStridedSliceResultLayout()
> setupLoadMatrixAnchorLayout()
> setupStoreMatrixAnchorLayout()
> setupLoadGatherAnchorLayout()
> setupStoreScatterAnchorLayout()

Part of subgroup distribution related code changes are separated and
created as PR https://github.com/llvm/llvm-project/pull/179018/changes.
This commit is contained in:
Jianhui Li 2026-02-05 19:26:25 -08:00 committed by GitHub
parent 15a30e3acf
commit 61b8a57839
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 1902 additions and 548 deletions

View File

@ -226,16 +226,31 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Derive a new layout with sg_data, inst_data and lane_data set to 1 for the specified unit dims",
"xegpu::DistributeLayoutAttr",
"setUnitDimData",
/*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
/*args=*/(ins "const SmallVector<int64_t>": $unitDims)>,
InterfaceMethod<"Derive a new layout with sg_lane and lane_layout set to 1 for the specified unit dims",
"xegpu::DistributeLayoutAttr",
"setUnitDimLayout",
/*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
/*args=*/(ins "const SmallVector<int64_t>": $unitDims)>,
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
indices based on the effective layout level.}],
"FailureOr<SmallVector<Value>>",
"delinearizeId",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
InterfaceMethod<[{Derive a new layout with sg_data, inst_data and lane_data set to the
specified values for the given dimension. Passing -1 for any parameter
preserves its original value.}],
"xegpu::DistributeLayoutAttr",
"setDimData",
(ins "int64_t": $dim,
"int64_t": $sgData,
"int64_t": $instData,
"int64_t": $laneData)>,
InterfaceMethod<[{Derive a new layout by collapsing dimensions.
`dimGroup` specifies a group of adjacent dimensions that are collapsed into
a single dimension in the derived layout.}],
"xegpu::DistributeLayoutAttr",
"collapseDims",
(ins "SmallVector<int64_t>": $dimGroup)>,
InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
assigned to a level identified by linearId. The shape parameter
represents the higher-level problem size. Each level may access
@ -501,10 +516,20 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
}
//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
DistributeLayoutAttr setUnitDimData(SmallVector<int64_t> unitDims) const;
//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
DistributeLayoutAttr setUnitDimLayout(SmallVector<int64_t> unitDims) const;
// Derive a new layout with sg_data, inst_data and lane_data set to the
// specified values for the given dimension. Passing -1 for any parameter
// preserves its original value.
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
// Derive a new layout by collapsing dimensions.
// `dimGroup` specifies a group of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr collapseDims(SmallVector<int64_t> dimGroup);
/// Delinearizes a linear ID into its multidimensional indices
/// based on the effective level of the layout.
@ -672,10 +697,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
}
//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
DistributeLayoutAttr setUnitDimData(SmallVector<int64_t> unitDims) const;
//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
DistributeLayoutAttr setUnitDimLayout(SmallVector<int64_t> unitDims) const;
// Derive a new layout with sg_data, inst_data and lane_data set to the
// specified values for the given dimension. Passing -1 for any parameter
// preserves its original value.
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
// Derive a new layout by collapsing dimensions.
// `dimGroup` specifies a group of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr collapseDims(SmallVector<int64_t> dimGroup);
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>

View File

@ -103,12 +103,6 @@ void populateXeGPUSgToWiDistributeTypeConversionAndLegality(
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
const UnrollOptions &options);
enum class LayoutKind { Lane, InstData, Subgroup };
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
LayoutKind layoutKind, bool printOnly = false);
LogicalResult resolveLayoutConflicts(Operation *target);
} // namespace xegpu
} // namespace mlir

View File

@ -0,0 +1,168 @@
//===- XeGPULayoutImpl.h - Layout utility functions ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
#define MLIR_DIALECT_XEGPU_UTILS_XeGPULayoutImpl_H_
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
class VectorType;
class OpOperand;
class OpResult;
class OpBuilder;
class ValueRange;
class TypeConverter;
class OpFoldResult;
namespace xegpu {
class DistributeLayoutAttr;
class LayoutAttr;
class TensorDescType;
} // namespace xegpu
namespace xegpu {
enum class LayoutKind { Lane, InstData, Subgroup };
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
LayoutKind layoutKind, bool printOnly = false);
LogicalResult resolveLayoutConflicts(Operation *target);
/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
/// OpResult of of the given operation. If the operation contains regions, it is
/// also applied recursively to the contained operations operation.
/// TODO: To be replaced by recoverTemporaryLayouts()
void recoverTemporaryLayoutsDeprecated(Operation *op);
/// Attach layout attributes to all vector-type operands of operations within
/// the given operation's nested region. Reports an error if any vector operand
/// lacks a layout attribute.
bool recoverTemporaryLayouts(Operation *rootOp);
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
std::is_same_v<T, OpResult>>>
void removeLayoutAttr(const T &operandOrResult);
/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
/// given operation if they exist. If the operation contains regions, it is also
/// applied recursively to the contained operations
void removeLayoutAttrs(Operation *op);
/// Updates the NamedAttribute sequence by dropping sg-layout and
/// sg-data information from any DistributeLayoutAttr found.
SmallVector<NamedAttribute>
dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs);
/// Updates the NamedAttribute sequence by dropping inst-data information from
/// any DistributeLayoutAttr found.
SmallVector<NamedAttribute> dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs);
/// Infers the source layout attribute for a broadcast operation given the
/// result layout attribute, result shape, and source shape.
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
/// Infers the source layout attribute for a reduction operation given the
/// result layout attribute and reduced dims.
DistributeLayoutAttr
inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout,
SmallVector<int64_t> reduceDims);
/// Infers the source layout attribute for a bitcast operation given the
/// result layout attribute, result element type bitwidth, and source element
/// type bitwidth.
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout,
int resElemTyBitWidth,
int srcElemTyBitWidth);
/// Infers the source layout attribute for a shape cast operation given the
/// result layout attribute, result shape, and source shape.
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
/// Infers the source layout attribute for an insert strided slice operation
/// given the result layout attribute, result shape, and source shape. Removes
/// leading dimensions from the result layout to match the source shape size.
DistributeLayoutAttr
inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
/// Sets up layout for reduction operations by creating a SliceAttr for the
/// result.
///
/// This function first attempts to construct a source layout that, when
/// sliced along reduction dimensions, produces a result layout compatible
/// with the consumer's preferred layout. This minimizes data redistribution
/// overhead. The SliceAttr for the result is then created based on the
/// derived source layout and the specified reduction dimensions.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind,
VectorType srcVectorTy,
DistributeLayoutAttr consumerLayout,
SmallVector<int64_t> reductionDims,
const uArch::uArch *uArch);
/// Setup the result layout attribute for a bitcast operation based on element
/// type bitwidths. This ensures the source layout can always be derived from
/// the result layout.
///
/// When casting from a narrower to a wider element type (srcElemTyBitWidth <
/// resElemTyBitWidth), the result layout's innermost dimension data sizes
/// (inst_data, lane_data) are scaled up by the bitwidth ratio. This maintains
/// the invariant that the source layout can be recovered by adjusting the
/// result layout based on bitwidth ratio of input vs output.
DistributeLayoutAttr setupBitCastResultLayout(
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
/// Sets up the result layout for an insert strided slice operation.
/// Creates a result layout based on the specified layout kind (InstData or
/// Lane).
DistributeLayoutAttr setupInsertStridedSliceResultLayout(
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
/// Sets up the anchor layout for a load gather operation.
DistributeLayoutAttr
setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
int chunkSize, DistributeLayoutAttr consumerLayout,
const uArch::uArch *uArch);
/// Sets up the anchor layout for load matrix operation.
DistributeLayoutAttr
setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
DistributeLayoutAttr consumerLayout,
const uArch::uArch *uArch);
/// Sets up the anchor layout for a store scatter operation.
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind,
VectorType vectorTy,
int chunkSize,
const uArch::uArch *uArch);
/// Sets up the anchor layout for a store matrix operation.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
VectorType vectorTy,
const uArch::uArch *uArch);
} // namespace xegpu
} // namespace mlir
#endif // MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_

View File

@ -137,12 +137,6 @@ template <typename T>
int getLargestDivisor(T dim, ArrayRef<T> candidates,
ArrayRef<T> candidateMultiples = {});
/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
std::string getTemporaryLayoutName(const OpOperand &operand);
/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
std::string getTemporaryLayoutName(const OpResult result);
/// Retrieves the DistributeLayoutAttr associated with a given Value. For
/// TensorDescType values, the DistributeLayoutAttr is extracted from the
/// TensorDescType itself. For other values, it is obtained from the attributes
@ -155,26 +149,6 @@ DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
/// found, it will check the operand itself and its defining op.
DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
std::is_same_v<T, OpResult>>>
void removeLayoutAttr(const T &operandOrResult);
/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
/// given operation if they exist. If the operation contains regions, it is also
/// applied recursively to the contained operations
void removeLayoutAttrs(Operation *op);
/// Updates the NamedAttribute sequence by dropping sg-layout and
/// sg-data information from any DistributeLayoutAttr found.
SmallVector<NamedAttribute>
dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs);
/// Updates the NamedAttribute sequence by dropping inst-data information from
/// any DistributeLayoutAttr found.
SmallVector<NamedAttribute> dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs);
/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
/// user should use setAnchorLayout instead
void setDistributeLayoutAttr(const OpResult &Result,
@ -185,6 +159,12 @@ void setDistributeLayoutAttr(const OpResult &Result,
void setDistributeLayoutAttr(const OpOperand &opr,
const DistributeLayoutAttr layout);
/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
std::string getTemporaryLayoutName(const OpOperand &operand);
/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
std::string getTemporaryLayoutName(const OpResult result);
/// get and set distribute layout attribute for non-anchor operations
/// (and offsets/masks of load/store ops before we get rid of their temp attrs)
template <typename T,
@ -198,17 +178,6 @@ template <typename T,
void setTemporaryLayout(const T &operandOrResult,
const DistributeLayoutAttr layout);
/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
/// OpResult of of the given operation. If the operation contains regions, it is
/// also applied recursively to the contained operations operation.
/// TODO: To be replaced by recoverTemporaryLayouts()
void recoverTemporaryLayoutsDeprecated(Operation *op);
/// Attach layout attributes to all vector-type operands of operations within
/// the given operation's region. Reports an error if any vector operand lacks
/// a layout attribute.
bool recoverTemporaryLayouts(Operation *rootOp);
/// Helper function to check if the layout is packed. Layout is packed if it is
/// 2D and lane_data[0] != 1 (data packed from col dimension).
/// TODO: Move to target info.
@ -217,6 +186,15 @@ bool requirePacked(const LayoutAttr layout);
/// Helper function to check if the layout requires a transpose effect.
bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch);
// Check if dst shape is an expansion of src shape by inserting unit dimensions.
bool matchUnitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
SmallVector<int64_t> &expandedUnitDims);
// Checks if dst shape is an expansion of src shape where each dimension in src
// is split into one or more consecutive dimensions in dst
bool matchSplitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
SmallVector<SmallVector<int64_t>> &splitDimGroups);
} // namespace xegpu
} // namespace mlir

View File

@ -216,15 +216,19 @@ protected:
};
struct SpirvLoadGatherInstruction : public LoadGatherInstructionInterface {
int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const override {
return 16;
}
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override { return 16; }
};
struct SpirvStoreScatterInstruction : public StoreScatterInstructionInterface {
int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const override {
return 16;
}
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override { return 16; }
};
struct LoadMatrixInstruction : public LoadMatrixInstructionInterface {
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override { return 16; }
};
struct StoreMatrixInstruction : public StoreMatrixInstructionInterface {
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override { return 16; }
};
//===----------------------------------------------------------------------===//
@ -239,9 +243,11 @@ struct PVCuArch final : public Xe2Plus {
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
static const SpirvStoreScatterInstruction storeScatterInst;
static const SpirvLoadGatherInstruction loadGatherInst;
static const Instruction *arr[] = {&dpasInst, &loadNdInst,
&storeNdInst, &prefetchNdInst,
&storeScatterInst, &loadGatherInst};
static const StoreMatrixInstruction storeMatrixInst;
static const LoadMatrixInstruction loadMatrixInst;
static const Instruction *arr[] = {
&dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst,
&storeScatterInst, &loadGatherInst, &storeMatrixInst, &loadMatrixInst};
return arr;
}

View File

@ -40,7 +40,9 @@ enum class InstructionKind {
Subgroup2DBlockLoad, // Subgroup-level 2D block load instruction
Subgroup2DBlockPrefetch, // Subgroup-level 2D block prefetch instruction
StoreScatter, // Lane-level store (scalar, vector)
LoadGather // Lane-level load (scalar, vector)
LoadGather, // Lane-level load (scalar, vector)
StoreMatrix, // Lane-level matrix store to slm
LoadMatrix // Lane-level matrix load to slm
// @TODO: Add more instructions as needed
};
@ -71,6 +73,10 @@ struct Instruction {
return "store";
case InstructionKind::LoadGather:
return "load";
case InstructionKind::StoreMatrix:
return "store_matrix";
case InstructionKind::LoadMatrix:
return "load_matrix";
}
llvm_unreachable("Unknown InstructionKind");
}
@ -254,17 +260,6 @@ struct MMAInstructionInterface {
// Common instructions (shared across architectures)
//===----------------------------------------------------------------------===//
struct StoreScatterInstructionInterface : public Instruction {
StoreScatterInstructionInterface()
: Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
static bool classof(const Instruction *B) {
return B->getInstructionKind() == InstructionKind::StoreScatter;
}
virtual int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const = 0;
virtual ~StoreScatterInstructionInterface() = default;
};
struct LoadGatherInstructionInterface : public Instruction {
LoadGatherInstructionInterface()
: Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
@ -272,10 +267,43 @@ struct LoadGatherInstructionInterface : public Instruction {
return B->getInstructionKind() == InstructionKind::LoadGather;
}
virtual int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const = 0;
virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const = 0;
virtual ~LoadGatherInstructionInterface() = default;
};
struct StoreScatterInstructionInterface : public Instruction {
StoreScatterInstructionInterface()
: Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
static bool classof(const Instruction *B) {
return B->getInstructionKind() == InstructionKind::StoreScatter;
}
virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const = 0;
virtual ~StoreScatterInstructionInterface() = default;
};
struct LoadMatrixInstructionInterface : public Instruction {
LoadMatrixInstructionInterface()
: Instruction(InstructionKind::LoadMatrix, InstructionScope::Lane) {}
static bool classof(const Instruction *B) {
return B->getInstructionKind() == InstructionKind::LoadMatrix;
}
virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const = 0;
virtual ~LoadMatrixInstructionInterface() = default;
};
struct StoreMatrixInstructionInterface : public Instruction {
StoreMatrixInstructionInterface()
: Instruction(InstructionKind::StoreMatrix, InstructionScope::Lane) {}
static bool classof(const Instruction *B) {
return B->getInstructionKind() == InstructionKind::StoreMatrix;
}
virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const = 0;
virtual ~StoreMatrixInstructionInterface() = default;
};
} // namespace uArch
} // namespace xegpu
} // namespace mlir

View File

@ -398,7 +398,7 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr
LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
auto sgDataOpt = getSgData();
auto instDataOpt = getInstData();
auto laneDataOpt = getLaneData();
@ -407,15 +407,14 @@ LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
SmallVector<int32_t> instData;
SmallVector<int32_t> laneData;
if (sgDataOpt) {
if (sgDataOpt)
sgData = llvm::to_vector(sgDataOpt.asArrayRef());
}
if (instDataOpt) {
if (instDataOpt)
instData = llvm::to_vector(instDataOpt.asArrayRef());
}
if (laneDataOpt) {
if (laneDataOpt)
laneData = llvm::to_vector(laneDataOpt.asArrayRef());
}
for (auto dim : unitDims) {
if (dim < static_cast<int64_t>(sgData.size()))
@ -440,19 +439,17 @@ LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
DistributeLayoutAttr
LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
LayoutAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
auto sgLayoutOpt = getSgLayout();
auto laneLayoutOpt = getLaneLayout();
SmallVector<int32_t> sgLayout;
SmallVector<int32_t> laneLayout;
if (sgLayoutOpt) {
if (sgLayoutOpt)
sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
}
if (laneLayoutOpt) {
if (laneLayoutOpt)
laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
}
for (auto dim : unitDims) {
if (dim < static_cast<int64_t>(sgLayout.size()))
@ -471,6 +468,174 @@ LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
getLaneData(), getOrder());
}
// Derive a new layout with sg_data, inst_data and lane_data set to the
// specified values for the given dimension
DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
int64_t instData,
int64_t laneData) {
SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();
if (dim < static_cast<int64_t>(sgDataVec.size()) && sgData != -1)
sgDataVec[dim] = sgData;
if (dim < static_cast<int64_t>(instDataVec.size()) && instData != -1)
instDataVec[dim] = instData;
if (dim < static_cast<int64_t>(laneDataVec.size()) && laneData != -1)
laneDataVec[dim] = laneData;
SmallVector<int32_t> sgDataVec32(sgDataVec.begin(), sgDataVec.end());
SmallVector<int32_t> instDataVec32(instDataVec.begin(), instDataVec.end());
SmallVector<int32_t> laneDataVec32(laneDataVec.begin(), laneDataVec.end());
return LayoutAttr::get(
getContext(), getSgLayout(),
sgDataVec.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), sgDataVec32),
instDataVec.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), instDataVec32),
getLaneLayout(),
laneDataVec.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), laneDataVec32),
getOrder());
}
// Derive a new layout by collapsing dimensions.
// `dimGroup` specifies a group of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
DenseI32ArrayAttr orderAttr = getOrder();
SmallVector<int32_t> orderVec;
if (orderAttr && !orderAttr.empty()) {
orderVec = llvm::to_vector(
llvm::map_range(orderAttr.asArrayRef(),
[](int32_t idx) { return static_cast<int32_t>(idx); }));
}
SmallVector<int64_t> sortedDimGroup = dimGroup;
llvm::sort(sortedDimGroup);
int64_t dimBeforeCurrent = -1;
for (auto dimIdx : sortedDimGroup) {
// when order is present, adjacency dims are on order values like [3, 2, 1,
// 0] in decreasing order otherwise based on dim indices like [0, 1, 2, 3]
// in increasing order
if (dimBeforeCurrent >= 0) {
if (!orderVec.empty()) {
int64_t orderBefore = orderVec[dimBeforeCurrent];
int64_t orderCurrent = orderVec[dimIdx];
if (orderBefore != (orderCurrent - 1))
llvm::report_fatal_error(
"dimensions being collapsed must be adjacent in order");
} else {
if (dimIdx != (dimBeforeCurrent + 1))
llvm::report_fatal_error(
"dimensions being collapsed must be adjacent");
}
}
dimBeforeCurrent = dimIdx;
}
int firstDim = sortedDimGroup.front();
// collapse the dimensions in dimGroup into one dimension by multiplying their
// sizes together
if (!sgLayout.empty()) {
int64_t collapsedSglayout = 1, collapsedSgData = 1;
for (auto dimIdx : dimGroup) {
collapsedSglayout *= sgLayout[dimIdx];
collapsedSgData *= sgData[dimIdx];
}
for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
}
sgLayout.insert(sgLayout.begin() + firstDim, collapsedSglayout);
sgData.insert(sgData.begin() + firstDim, collapsedSgData);
}
if (!instData.empty()) {
int64_t collapsedInstData = 1;
for (auto dimIdx : dimGroup)
collapsedInstData *= instData[dimIdx];
for (auto dimIdx : llvm::reverse(sortedDimGroup))
instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
instData.insert(instData.begin() + firstDim, collapsedInstData);
}
if (!laneLayout.empty()) {
int64_t collapsedLaneLayout = 1, collapsedLaneData = 1;
for (auto dimIdx : dimGroup) {
collapsedLaneLayout *= laneLayout[dimIdx];
collapsedLaneData *= laneData[dimIdx];
}
for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
laneLayout.erase(laneLayout.begin() + dimIdx,
laneLayout.begin() + dimIdx + 1);
laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
}
laneLayout.insert(laneLayout.begin() + firstDim, collapsedLaneLayout);
laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
}
// go through the values inside collapsedOrder, and re-map the order values
// to be in range of [0, N-1] where N is the number of dimensions in
// collapsed shape for exmaple, collapse dim group {2, 3} of order[1, 2, 3,
// 4] to new order[1, 3, 4]. the loop below remaps it to [1, 2, 3].
SmallVector<int32_t> collapsedOrder;
if (!orderVec.empty()) {
for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
if (dimIdx != firstDim)
orderVec.erase(orderVec.begin() + dimIdx,
orderVec.begin() + dimIdx + 1);
}
// say we have orderVec = {5, 3, 2, 1, 0}
// Create indices [0, 1, 2, 3, 4]
SmallVector<size_t> indices =
llvm::to_vector(llvm::seq<size_t>(0, orderVec.size()));
// Sort indices based on corresponding values
llvm::sort(indices,
[&](size_t a, size_t b) { return orderVec[a] < orderVec[b]; });
collapsedOrder = llvm::to_vector(llvm::map_range(
indices, [&](size_t i) { return static_cast<int32_t>(i); }));
}
// Create collapsed layout
SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
SmallVector<int32_t> instData32(instData.begin(), instData.end());
SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
auto collapsedLayout = xegpu::LayoutAttr::get(
getContext(),
sgLayout32.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), sgLayout32),
sgData32.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), sgData32),
instData32.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), instData32),
laneLayout32.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), laneLayout32),
laneData32.empty() ? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), laneData32),
collapsedOrder.empty()
? DenseI32ArrayAttr()
: DenseI32ArrayAttr::get(getContext(), collapsedOrder));
return collapsedLayout;
}
//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
@ -624,12 +789,12 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps
// to dim [0] in parent space; if we want to set unit dim [1] in sliced space,
// it maps to dim [2] in parent space.
static SetVector<int64_t>
mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
static SmallVector<int64_t>
mapSlicedDimsToParentSpace(const SmallVector<int64_t> &dimsToMap,
ArrayRef<int64_t> sliceDims) {
// Rather than recovering the exact parent rank, we compute a safe upper bound
// so that dimsToMap can be adjusted safely. This upper bound is defined as
// max(dimsToMap, sliceDims) + 1 + sliceDims.size().
// Rather than recovering the exact parent rank, we compute a safe upper
// bound so that dimsToMap can be adjusted safely. This upper bound is
// defined as max(dimsToMap, sliceDims) + 1 + sliceDims.size().
int64_t maxDim = -1;
maxDim =
std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
@ -648,10 +813,10 @@ mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
}
// Map unit dims from sliced space to parent space
SetVector<int64_t> adjustUnitDims;
SmallVector<int64_t> adjustUnitDims;
for (auto dim : dimsToMap) {
int64_t mappedDim = remainingDims[dim];
adjustUnitDims.insert(mappedDim);
adjustUnitDims.push_back(mappedDim);
}
return adjustUnitDims;
@ -659,12 +824,12 @@ mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr
SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
SliceAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
DistributeLayoutAttr parentLayout = getParent();
ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
SetVector<int64_t> adjustUnitDims =
SmallVector<int64_t> adjustUnitDims =
mapSlicedDimsToParentSpace(unitDims, sliceDims);
return SliceAttr::get(getContext(),
@ -673,18 +838,51 @@ SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
DistributeLayoutAttr
SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
SliceAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
DistributeLayoutAttr parentLayout = getParent();
ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
SetVector<int64_t> adjustUnitDims =
SmallVector<int64_t> adjustUnitDims =
mapSlicedDimsToParentSpace(unitDims, sliceDims);
return SliceAttr::get(
getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
}
// Derive a new layout with sg_data, inst_data and lane_data set to the
// specified values for the given dimension
DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
int64_t instData, int64_t laneData) {
ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
auto parent = getParent();
SmallVector<int64_t> dimSet;
dimSet.push_back(dim);
SmallVector<int64_t> adjustDims =
mapSlicedDimsToParentSpace(dimSet, sliceDims);
return SliceAttr::get(
getContext(),
parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
}
// Derive a new layout by collapsing dimensions.
// `dimGroup` specifies a group of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
// Map the sliced dims from parent space to collapsed space
SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
SmallVector<int64_t> dimsInParentSpace =
mapSlicedDimsToParentSpace(dimGroup, sliceDims);
auto collapsedParent = getParent().collapseDims(dimsInParentSpace);
return SliceAttr::get(getContext(), collapsedParent,
DenseI64ArrayAttr::get(getContext(), sliceDims));
}
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
@ -820,7 +1018,8 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
return emitError() << "unsupported element type " << elementType
<< ": expected integer or float";
// for gather and scatter ops, Low-precision types are packed in 32-bit units.
// for gather and scatter ops, Low-precision types are packed in 32-bit
// units.
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
int chunkAlignmentFactor =
bitWidth < xegpu::uArch::generalPackedFormatBitSize

View File

@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUPropagateLayout.cpp
XeGPUVectorLinearize.cpp
XeGPUPeepHoleOptimizer.cpp
XeGPULayoutImpl.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/PassManager.h"

View File

@ -0,0 +1,851 @@
//===---- XeGPULayoutImpl.cpp - MLIR Utilities for XeGPUOps
//------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements layout utility functions for XeGPU dialect
// transformation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>
using namespace mlir;
void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
op->walk([&](Operation *nestOp) {
for (OpOperand &opr : nestOp->getOpOperands()) {
auto layout = getDistributeLayoutAttr(opr.get());
setDistributeLayoutAttr(opr, layout);
}
for (OpResult result : nestOp->getOpResults()) {
auto layout = getDistributeLayoutAttr(result);
setDistributeLayoutAttr(result, layout);
}
});
}
SmallVector<NamedAttribute>
xegpu::dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
SmallVector<NamedAttribute> out;
out.reserve(attrs.size());
for (auto attr : attrs) {
if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
auto newLayout = dist.dropSgLayoutAndData();
if (newLayout)
out.emplace_back(attr.getName(), newLayout);
} else {
out.push_back(attr);
}
}
return out;
}
SmallVector<NamedAttribute>
xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
SmallVector<NamedAttribute> out;
out.reserve(attrs.size());
for (auto attr : attrs) {
if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
auto newLayout = dist.dropInstData();
if (newLayout)
out.emplace_back(attr.getName(), newLayout);
} else {
out.push_back(attr);
}
}
return out;
}
// Attach layout attributes to all vector-type operands of operations within
// the given operation's region. Reports an error if any vector operand lacks
// a layout attribute.
bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
auto result = rootOp->walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
// Layouts are needed for vector type only.
if (!isa<VectorType>(operand.get().getType()))
continue;
auto layout = xegpu::getDistributeLayoutAttr(operand.get());
if (!layout) {
op->emitError("Could not find layout attribute for operand ")
<< operand.getOperandNumber() << " of operation " << op->getName();
return WalkResult::interrupt();
}
xegpu::setDistributeLayoutAttr(operand, layout);
}
return WalkResult::advance();
});
return !result.wasInterrupted();
}
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->removeAttr(name);
}
// Explicit instantiation for OpResult
template void
xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
// Explicit instantiation for OpOperand
template void
xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
void xegpu::removeLayoutAttrs(Operation *op) {
op->walk([&](Operation *nestOp) {
// Remove all attributes of DistributeLayoutAttr type
SmallVector<StringAttr> attrsToRemove;
for (auto namedAttr : nestOp->getAttrs()) {
if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
attrsToRemove.push_back(namedAttr.getName());
}
for (auto attrName : attrsToRemove)
nestOp->removeAttr(attrName);
});
}
/// Infers the source layout attribute for a broadcast operation given the
/// result layout attribute, result shape, source shape.
xegpu::DistributeLayoutAttr
xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape) {
SmallVector<int64_t> bcastDims;
auto returnLayout = resLayout;
// Handling broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
int dimDiff = resShape.size() - srcShape.size();
if (dimDiff > 0) {
// Adding the missing leading dims
for (int i = 0; i < dimDiff; i++)
bcastDims.push_back(i);
// Create a slice layout for the source
returnLayout = xegpu::SliceAttr::get(
resLayout.getContext(), resLayout,
DenseI64ArrayAttr::get(resLayout.getContext(), bcastDims));
}
return returnLayout;
}
/// Infers the source layout attribute for a reduction operation given the
/// result layout attribute and reduced dims.
xegpu::DistributeLayoutAttr
xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
SmallVector<int64_t> reduceDims) {
assert(isa<xegpu::SliceAttr>(resLayout) &&
"reduction result layout must be slice layout");
xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
auto sliceDims = sliceLayout.getDims().asArrayRef();
assert(reduceDims == sliceDims &&
"reduction dims must match with slice dims");
return sliceLayout.getParent();
}
/// Infers the source layout attribute for a bitcast operation given the
/// result layout attribute, result element type bitwidth, and source element
/// type bitwidth.
xegpu::DistributeLayoutAttr
xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
int resElemTyBitWidth, int srcElemTyBitWidth) {
SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
size_t sgDataSize = sgData.size();
size_t instDataSize = instData.size();
size_t laneDataSize = laneData.size();
int64_t sgDataValue = -1;
int64_t instDataValue = -1;
int64_t laneDataValue = -1;
int64_t dim = resLayout.getRank() - 1;
if (srcElemTyBitWidth <= resElemTyBitWidth) {
int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
if (sgDataSize)
sgDataValue = sgData.back() * bitWidthRatio;
if (instDataSize)
instDataValue = instData.back() * bitWidthRatio;
if (laneDataSize)
laneDataValue = laneData.back() * bitWidthRatio;
} else {
int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
if (sgDataSize) {
assert((sgData.back() % bitWidthRatio) == 0 &&
"sgData not divisible by bitWidthRatio");
sgDataValue = sgData.back() / bitWidthRatio;
}
if (instDataSize) {
assert((instData.back() % bitWidthRatio) == 0 &&
"instData not divisible by bitWidthRatio");
instDataValue = instData.back() / bitWidthRatio;
}
if (laneDataSize) {
assert((laneData.back() % bitWidthRatio) == 0 &&
"laneData not divisible by bitWidthRatio");
laneDataValue = laneData.back() / bitWidthRatio;
}
}
xegpu::DistributeLayoutAttr finalSrcLayout;
finalSrcLayout =
resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
return finalSrcLayout;
}
/// Infers the source layout attribute for an insert strided slice operation
/// given the result layout attribute, result shape, and source shape. Removes
/// leading dimensions from the result layout to match the source shape size.
xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape) {
int srcShapeSize = srcShape.size();
int resShapeSize = resShape.size();
int dimDiff = resShapeSize - srcShapeSize;
assert(isa<xegpu::LayoutAttr>(resLayout) &&
"insertStridedSlice result layout must be plain layout");
auto context = resLayout.getContext();
auto resInstData = resLayout.getEffectiveInstDataAsInt();
auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
if (resInstData.size() != 0) {
SmallVector<int> inferredInstData(srcShapeSize);
for (int i = 0; i < srcShapeSize; i++)
inferredInstData[i] = resInstData[i + dimDiff];
return xegpu::LayoutAttr::get(context, inferredInstData);
}
if (resLaneLayout.size() != 0) {
SmallVector<int> inferredLaneLayout(srcShapeSize);
SmallVector<int> inferredLaneData(srcShapeSize);
for (int i = 0; i < srcShapeSize; i++) {
inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
inferredLaneData[i] = resLaneData[i + dimDiff];
}
return xegpu::LayoutAttr::get(context, inferredLaneLayout,
inferredLaneData);
}
return nullptr;
}
/// Infers the source layout attribute for a shape cast operation given the
/// result layout attribute, result shape, and source shape.
xegpu::DistributeLayoutAttr
xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape) {
// There are three use cases:
// 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
// tensor before broadcast
// 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
// for multi-stage reduction
// 3. combines all dims to a single dim and put in the innermost dim in 2d as
// [1, combinedData] or [combinedData]. Say, [2, 4, 8] -> [1, 64] or [64]
// Use cases are only supported after workgroup distribution,
// like cross-sg reduction saves multidimension data to
// 1D slm buffer, shapecast inserted by cse/canonicalization passes.
// Use case 1: Shapes only differ by expanding unit dimensions, for broadcast
SmallVector<int64_t> expandedUnitDims;
if (xegpu::matchUnitDimExpansion(srcShape, resShape, expandedUnitDims)) {
// create a slice layout for the source by removing the expanded unit dims
auto sliceDimsAttr = DenseI64ArrayAttr::get(
resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
auto srcLayout =
xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
return srcLayout;
}
// Use case 2: Dim split from source to result, for multi-stage reduction
SmallVector<SmallVector<int64_t>> splitDimGroups;
if (xegpu::matchSplitDimExpansion(srcShape, resShape, splitDimGroups)) {
auto srcLayout = resLayout;
for (const auto &dimGroup : splitDimGroups)
srcLayout = srcLayout.collapseDims(dimGroup);
return srcLayout;
}
// Use case 3: Collaspse to innermost dim, for cross-sg reduction to SLM
auto matchCollapseToInnermostDim = [&](ArrayRef<int64_t> src,
ArrayRef<int64_t> dst) -> bool {
// only one non-unit dim in dst which is the innermost dim
if ((dst.size() != 2) && (dst.size() != 1))
return false;
int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
std::multiplies<int64_t>());
if (dst.size() == 1)
return (dst[0] == srcSize);
return (dst[0] == 1) && (dst[1] == srcSize);
};
if (matchCollapseToInnermostDim(srcShape, resShape)) {
int srcShapeSize = srcShape.size();
int resShapeSize = resShape.size();
auto context = resLayout.getContext();
auto resInstData = resLayout.getEffectiveInstDataAsInt();
auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
// Extract layout info from result's innermost dimension and apply to
// source's innermost dimension while setting all other dimensions to 1.
// The inferred layout is restricted by srcShape to ensure it fits within
// the source dimensions.
// Examples 1:
// srcShape=[8, 16, 32], resShape=[1, 4096]
// resInstData=[1, 16]
// -> inferredInstData=[1, 1, min(16, 32)]=[1, 1, 16]
// Examples 2:
// srcShape=[4, 8, 64], resShape=[2048]
// resLaneLayout=[16], resLaneData=[2]
// -> inferredLaneLayout=[1, 1, 16]
// -> inferredLaneData=[1, 1, min(2, 64/16)]=[1, 1, 2]
if (resInstData.size() != 0) {
// assert resInstData must be 1 for all but the innermost dim
for (int i = 0; i < resShapeSize - 1; i++) {
assert(resInstData[i] == 1 &&
"only innermost dim can have non-unit instData");
}
SmallVector<int> inferredInstData(srcShapeSize, 1);
inferredInstData[srcShapeSize - 1] =
std::min(resInstData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
return xegpu::LayoutAttr::get(context, inferredInstData);
}
if (resLaneLayout.size() != 0) {
for (int i = 0; i < resShapeSize - 1; i++) {
assert(resLaneData[i] == 1 &&
"only innermost dim can have non-unit instData");
}
assert(srcShape.back() % resLaneLayout.back() == 0 &&
"source innermost dim must be >= result lane layout");
SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
SmallVector<int> inferredLaneData(srcShapeSize, 1);
inferredLaneLayout.back() = resLaneLayout.back();
inferredLaneData.back() = std::min(
resLaneData.back(), srcShape.back() / inferredLaneLayout.back());
return xegpu::LayoutAttr::get(context, inferredLaneLayout,
inferredLaneData);
}
}
llvm_unreachable("running into unsupported shape cast scenarios");
return nullptr;
}
/// Sets up layout for reduction operations by creating a SliceAttr for the
/// result.
///
/// Algorithm Overview:
/// This function attempts to construct a source layout that, when sliced along
/// reduction dimensions, produces a result layout compatible with the
/// consumer layout.
///
/// For subgroup layouts, it first tries to align the source layout's subgroup
/// layout and data with the consumer's layout on non-reduction dimensions.
/// Then, it distributes remaining subgroups across reduction dimensions. This
/// avoids subgroup data redistribution overhead between the reduced result and
/// its consumer.
///
/// InstData requries {1, ..., min(maxReduceVectorSize, srcShape),subgroupSize}
/// Lane Layout requires {1, ..., 1, subgroupSize}
/// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
///
/// Examples:
/// 1. Subgroup layout - Row reduction on 2D tensor:
/// srcShape=[32, 64], reductionDims=[1], resShape=[32], subgroupSize=16,
/// workgroupSize=32
/// Consumer Layout:
/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 8]>, dims =
/// [1]>} Result: srcLayout with sgLayout=[4, 8], sgData=[8, 8] (matches
/// consumer on non-reduction dim, minimizing data redistribution on
/// reduction dim)
/// 2. Subgroup layout - Same example above but consumer has different layout:
/// sgLayout=[32], sgData=[1]
/// Result: srcLayout with sgLayout=[32,1], sgData=[1, 64]
/// (distributes all subgroups on non reduction dim)
///
/// 2. InstData layout - Column reduction:
/// srcShape=[32, 64], reductionDims=[0], subgroupSize=16
/// Result: instData=[1, 16] (maxReduceVectorSize=1, subgroupSize on
/// innermost)
///
/// 3. Lane layout - Multi-dimensional reduction:
/// srcShape=[16, 32, 64], reductionDims=[1], subgroupSize=16
/// Result: laneLayout=[1, 1, 16], laneData=[1, 1, 1]
/// (subgroupSize on innermost dim, max vector size on reduction dim)
xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy,
DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
const xegpu::uArch::uArch *uArch) {
auto srcShape = srcVecTy.getShape();
int srcRank = srcShape.size();
auto context = consumerLayout.getContext();
// Reduction layout requires at least 2D tensors
if (srcRank < 2)
return nullptr;
// Helper lambda to convert int64 vectors to int32 DenseArrayAttr
auto toInt32Attr = [&](ArrayRef<int64_t> vec) {
SmallVector<int32_t> vec32(vec.begin(), vec.end());
return DenseI32ArrayAttr::get(context, vec32);
};
// Extract original plain layout for workgroup/subgroup size recovery
xegpu::SliceAttr consumerSliceLayout =
dyn_cast<xegpu::SliceAttr>(consumerLayout);
DistributeLayoutAttr plainLayout =
consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
: consumerLayout;
const int subgroupSize = uArch->getSubgroupSize();
int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
xegpu::DistributeLayoutAttr srcLayout;
if (layoutKind == xegpu::LayoutKind::Subgroup) {
auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
const int workgroupSize = std::accumulate(
sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
SmallVector<int64_t> consumerSgLayout =
consumerLayout.getEffectiveSgLayoutAsInt();
int remainingSgCount = workgroupSize;
int consumerIdx = consumerSgLayout.size() - 1;
// First pass: Match consumer's layout on non-reduction dimensions
for (int i = srcRank - 1; i >= 0; i--) {
if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
sgLayout[i] = consumerSgLayout[consumerIdx];
assert((srcShape[i] % sgLayout[i] == 0) &&
"source shape not divisible by consumer sg_layout");
sgData[i] = srcShape[i] / sgLayout[i];
remainingSgCount /= sgLayout[i];
consumerIdx--;
}
}
// Second pass: Distribute remaining subgroups across reduction dimensions
for (int i = srcRank - 1; i >= 0; i--) {
if (llvm::is_contained(reductionDims, i)) {
sgLayout[i] =
std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
assert((srcShape[i] % sgLayout[i] == 0) &&
"source shape not divisible by sg_layout");
sgData[i] = srcShape[i] / sgLayout[i];
remainingSgCount /= sgLayout[i];
}
}
assert(remainingSgCount == 1 && "not all subgroups distributed");
srcLayout = xegpu::LayoutAttr::get(
context, toInt32Attr(sgLayout), toInt32Attr(sgData),
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
/*lane_data =*/nullptr, /*order =*/nullptr);
} else if (layoutKind == xegpu::LayoutKind::InstData) {
SmallVector<int64_t> instData(srcRank, 1);
instData[srcRank - 2] =
std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
instData[srcRank - 1] = subgroupSize;
srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
} else if (layoutKind == xegpu::LayoutKind::Lane) {
SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
laneLayout[srcRank - 1] = subgroupSize;
laneData[srcRank - 2] =
std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
toInt32Attr(laneData),
consumerLayout.getOrder());
}
return xegpu::SliceAttr::get(context, srcLayout,
DenseI64ArrayAttr::get(context, reductionDims));
}
/// Sets up the result layout for a bitcast operation.
/// When casting to a smaller bitwidth, adjusts the layout dimensions (sgData,
/// instData, or laneData) by multiplying by the bitwidth ratio to ensure the
/// result layout can be correctly divided back to the source layout during
/// inference.
///
/// Examples:
/// 1. Casting f32 -> f16 (32-bit to 16-bit, bitWidthRatio = 2):
/// Consumer layout: instData=[1, 16], subgroupSize=16
/// Source shape: [8, 32]
/// Result layout: instData=[1, 32] (16 * 2)
/// The innermost dimension is multiplied by 2 to maintain consistency.
///
/// 2. Casting f32 -> i8 (32-bit to 8-bit, bitWidthRatio = 4):
/// Consumer instData=[1, 16], subgroupSize=16
/// Source shape: [4, 128]
/// adjust the instData from [1, 16] to [1, 16 * 4 = 64]
///
/// 3. Casting i8 -> i32 (8-bit to 32-bit, bitWidthRatio = 1/4):
/// Consumer layout: laneLayout=[1, 16], laneData=[1, 4]
/// No adjustment needed - returns consumer layout directly.
///
xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
ArrayRef<int64_t> srcShape = srcVecTy.getShape();
SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
size_t dim = srcShape.size() - 1;
int64_t sgDataValue = -1;
int64_t instDataValue = -1;
int64_t laneDataValue = -1;
const int subgroupSize = uArch->getSubgroupSize();
if (srcElemTyBitWidth > resElemTyBitWidth) {
// When casting to a smaller bitwidth, multiply the result layout
// accordingly to ensure it can be divided by the ratio back to the
// source layout.
int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
int innermostDimLaneLayout = subgroupSize;
if (layoutKind == xegpu::LayoutKind::Subgroup) {
assert(sgData.size() == srcShape.size() &&
"sgData must be available for all dimensions");
sgDataValue = sgData[dim];
} else if (layoutKind == xegpu::LayoutKind::InstData) {
assert(instData.size() == srcShape.size() &&
"instData must be available for all dimensions");
instDataValue = instData[dim];
// Adjust instDataValue so it still fits within an instruction after
// dividing by bitWidthRatio
while ((instDataValue <= srcShape[dim]) &&
(instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
instDataValue *= 2;
assert((srcShape[dim] % instDataValue) == 0 &&
"srcShape, instData, and lanelayout for innermost must be 2^n !");
} else if (layoutKind == xegpu::LayoutKind::Lane) {
assert(laneData.size() == srcShape.size() &&
"laneData must be available for all dimensions");
laneDataValue = laneData[dim];
while ((laneDataValue <= srcShape[dim]) &&
(laneDataValue % bitWidthRatio != 0))
laneDataValue *= 2;
}
// Now set only instData and laneData, preserving sgData
xegpu::DistributeLayoutAttr resLayout;
resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
laneDataValue);
return resLayout;
}
return consumerLayout;
}
/// Sets up the result layout for an insert strided slice operation.
/// Creates a result layout based on the specified layout kind (InstData or
/// Lane).
/// Subgroup layout is currently not supported for this operation.
/// InstData layout is first set to be {1, .., subgroupSize}.
/// Lane layout is first set to be {1, ..., subgroupSize} with lane data {1,
/// ..., 1}. The instData and laneData is then adjusted to contain packed data,
/// by checking if the consumerLayout's innermost dimension.
///
/// Examples:
/// 1. InstData layout without packing:
/// resShape=[8, 32], subgroupSize=16, bitwidth=32
/// packingFactor=1, packedDataSize=16
/// consumerLayout: instData=[1, 16]
/// Result: instData=[1, 16]
///
/// 2. InstData layout with packing:
/// resShape=[8, 64], subgroupSize=16, bitwidth=8, packingFactor=4
/// consumerLayout: instData=[1, 64]
/// Result: instData=[1, 64] (adjusted for packed data)
///
/// 3. Lane layout without packing:
/// resShape=[4, 64], subgroupSize=16, bitwidth=32
/// consumerLayout: laneLayout=[1, 16], laneData=[1, 1]
/// Result: laneLayout=[1, 16], laneData=[1, 1]
///
/// 4. Lane layout with packing:
/// resShape=[4, 64], subgroupSize=16, bitwidth=16, packingFactor=2
/// consumerLayout: laneLayout=[1, 16], laneData=[1, 2]
/// Result: laneLayout=[1, 16], laneData=[1, 2] (adjusted for packed data)
xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
const xegpu::uArch::uArch *uArch) {
xegpu::DistributeLayoutAttr requiredResLayout;
auto subgroupSize = uArch->getSubgroupSize();
auto context = resVectorTy.getContext();
auto resShape = resVectorTy.getShape();
int resShapeSize = resShape.size();
auto srcShape = srcVectorTy.getShape();
SmallVector<int64_t> consumerInstData =
consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> consumerLaneData =
consumerLayout.getEffectiveLaneDataAsInt();
SmallVector<int> instData(resShapeSize, 1);
SmallVector<int> laneLayout(resShapeSize, 1);
SmallVector<int> laneData(resShapeSize, 1);
const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
int packedDataSize = subgroupSize * packingFactor;
if (layoutKind == xegpu::LayoutKind::Subgroup) {
assert(true &&
"subgroup layout assignment not supported for insertStridedSlice.");
} else if (layoutKind == xegpu::LayoutKind::InstData) {
assert(srcShape.back() >= subgroupSize &&
"source innermost dim must be >= subgroupSize");
instData.back() = subgroupSize;
if (consumerInstData.back() == packedDataSize &&
srcShape.back() >= packedDataSize)
instData.back() = packedDataSize;
requiredResLayout = xegpu::LayoutAttr::get(context, instData);
} else if (layoutKind == xegpu::LayoutKind::Lane) {
laneLayout.back() = subgroupSize;
laneData.back() = 1;
if (consumerLaneData.back() == packingFactor &&
srcShape.back() >= packedDataSize)
laneData.back() = packingFactor;
requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
return requiredResLayout;
}
/// Sets up the anchor layout for load gather and load matrix operation.
/// load matrix lowers to load gather and 1d block load. All of them share the
/// same layout setup logic.
/// For Subgroup layout, uses the consumer layout directly.
/// non-chunked loads:
/// InstData = {1, ..., min(consumer, maxLaneLoadSize * subgroupSize)}
/// LaneLayout = {1, ..., subgroupSize}
/// lane_data = {1, ..., min(consumer, maxLaneLoadSize)}
/// chunked loads:
/// InstData = {subgroupSize, min(consumer, maxLaneLoadSize)}
/// LaneLayout = {subgroupSize, 1}
/// lane_data={1,min(consumer, maxLaneLoadSize)}
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
int maxChunkSize, int valShapeSize, int subgroupSize) {
if (layoutKind == xegpu::LayoutKind::Subgroup)
return consumerLayout;
SmallVector<int64_t> consumerInstData =
consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> consumerLaneData =
consumerLayout.getEffectiveLaneDataAsInt();
SmallVector<int> instData(valShapeSize, 1);
SmallVector<int> laneLayout(valShapeSize, 1);
SmallVector<int> laneData(valShapeSize, 1);
if (!isChunkedLoad) {
if (layoutKind == xegpu::LayoutKind::InstData) {
instData[valShapeSize - 1] =
std::min(static_cast<int>(consumerInstData[valShapeSize - 1]),
maxChunkSize * subgroupSize);
return xegpu::LayoutAttr::get(context, instData);
} else if (layoutKind == xegpu::LayoutKind::Lane) {
laneLayout.back() = subgroupSize;
laneData.back() =
std::min(static_cast<int>(consumerLaneData.back()), maxChunkSize);
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
} else {
assert(valShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
if (layoutKind == xegpu::LayoutKind::InstData) {
instData[0] = subgroupSize;
instData[1] =
std::min(static_cast<int>(consumerInstData[1]), maxChunkSize);
return xegpu::LayoutAttr::get(context, instData);
} else if (layoutKind == xegpu::LayoutKind::Lane) {
laneLayout[0] = subgroupSize;
laneData[1] =
std::min(static_cast<int>(consumerLaneData[1]), maxChunkSize);
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
}
return nullptr;
}
/// Sets up the anchor layout for a load gather operation.
xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
xegpu::LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
xegpu::DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
const int subgroupSize = uArch->getSubgroupSize();
int resShapeSize = resVecTy.getShape().size();
auto context = resVecTy.getContext();
auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::SpirvLoadGatherInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
(chunkSize > 1), maxChunkSize,
resShapeSize, subgroupSize);
}
/// Sets up the anchor layout for load matrix operation.
/// TODO: enhance load matrix to indicate lowering to chunked load or not.
xegpu::DistributeLayoutAttr
xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
VectorType resVecTy,
xegpu::DistributeLayoutAttr consumerLayout,
const xegpu::uArch::uArch *uArch) {
const int subgroupSize = uArch->getSubgroupSize();
int resShapeSize = resVecTy.getShape().size();
auto context = resVecTy.getContext();
auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
const auto *uArchInstruction = dyn_cast<xegpu::uArch::LoadMatrixInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadMatrix));
int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
false, maxChunkSize, resShapeSize,
subgroupSize);
}
/// Sets up the anchor layout for store scatter and store matrix operation.
/// store matrix lowers to store scatter and 1d block store. All of them share
/// the same layout setup logic. For Subgroup layout, not support yet.
/// non-chunked stores:
/// InstData = {1, ..., subgroupSize}
/// LaneLayout = {1, ..., subgroupSize}
/// lane_data = {1, ..., 1}
/// chunked stores:
/// InstData = {subgroupSize, min(srcVec, maxLaneStoreSize)}
/// LaneLayout = {subgroupSize, 1}
/// lane_data={1,min(srcVec, maxLaneStoreSize)}
static xegpu::DistributeLayoutAttr
setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind,
mlir::MLIRContext *context, bool isChunkedStore,
int maxChunkSize, ArrayRef<int64_t> srcShape,
int subgroupSize) {
int srcShapeSize = srcShape.size();
SmallVector<int> instData(srcShapeSize, 1);
SmallVector<int> laneLayout(srcShapeSize, 1);
SmallVector<int> laneData(srcShapeSize, 1);
if (layoutKind == xegpu::LayoutKind::Subgroup) {
assert(true &&
"subgroup layout assignment not supported for storeScatter.");
return nullptr;
}
if (!isChunkedStore) {
if (layoutKind == xegpu::LayoutKind::InstData) {
instData[srcShapeSize - 1] = subgroupSize;
return xegpu::LayoutAttr::get(context, instData);
} else if (layoutKind == xegpu::LayoutKind::Lane) {
laneLayout[srcShapeSize - 1] = subgroupSize;
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
} else {
assert(srcShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
if (layoutKind == xegpu::LayoutKind::InstData) {
instData[0] = subgroupSize;
instData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
return xegpu::LayoutAttr::get(context, instData);
} else if (layoutKind == xegpu::LayoutKind::Lane) {
laneLayout[0] = subgroupSize;
laneData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
}
return nullptr;
}
/// Sets up the anchor layout for a store scatter operation.
xegpu::DistributeLayoutAttr
xegpu::setupStoreScatterAnchorLayout(xegpu::LayoutKind layoutKind,
VectorType srcVecTy, int chunkSize,
const uArch::uArch *uArch) {
const int subgroupSize = uArch->getSubgroupSize();
ArrayRef<int64_t> srcShape = srcVecTy.getShape();
auto context = srcVecTy.getContext();
auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::SpirvStoreScatterInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
maxChunkSize, srcShape, subgroupSize);
}
/// Sets up the anchor layout for a store matrix operation.
xegpu::DistributeLayoutAttr
xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
VectorType srcVecTy,
const xegpu::uArch::uArch *uArch) {
const int subgroupSize = uArch->getSubgroupSize();
ArrayRef<int64_t> srcShape = srcVecTy.getShape();
auto context = srcVecTy.getContext();
auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
const auto *uArchInstruction = dyn_cast<xegpu::uArch::StoreMatrixInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreMatrix));
int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
srcShape, subgroupSize);
}

View File

@ -16,6 +16,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"

View File

@ -15,7 +15,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Attributes.h"
@ -127,6 +127,7 @@ public:
}
Attribute get() { return storage; }
void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
};
SmallVector<int> LayoutInfo::getLaneLayout() const {
@ -307,27 +308,6 @@ static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
}
/// Helper to get the default layout for a vector type.
static LayoutInfo getSIMTLayoutInfoScatterIO(VectorType vectorTy,
const xegpu::uArch::uArch *uArch) {
// Expecting a 1D or 2D vector.
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
"Expected 1D or 2D vector.");
// Expecting int or float element type.
assert(vectorTy.getElementType().isIntOrFloat() &&
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
if (vectorTy.getRank() == 1)
return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
{uArch->getSubgroupSize(), 1},
{1, packingFactor}));
}
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
/// is set according to the following criteria:
/// * For A operand, the data must be packed in minimum
@ -417,11 +397,27 @@ private:
void visitShapeCastOp(vector::ShapeCastOp shapeCast,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void
visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitLoadGatherOp(xegpu::LoadMatrixOp load,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitStoreScatterOp(xegpu::StoreMatrixOp store,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
public:
@ -497,6 +493,12 @@ LogicalResult LayoutInfoPropagation::visitOperation(
.Case([&](vector::ShapeCastOp shapeCastOp) {
visitShapeCastOp(shapeCastOp, operands, results);
})
.Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
})
.Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
visitLoadMatrixOp(loadMatrixOp, operands, results);
})
.Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
visitStoreMatrixOp(storeMatrixOp, operands, results);
})
@ -646,32 +648,45 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout of the result must be present.
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
LayoutInfo resLayoutInfo = results[0]->getValue();
if (!resLayoutInfo.isAssigned())
return;
// We only consider 2D -> 1D reductions at this point.
VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
if (!resultTy || resultTy.getRank() != 1) {
reduction.emitWarning("Expecting output type to be 1D vector.");
return;
}
VectorType sourceTy = reduction.getSourceVectorType();
SmallVector<int64_t> reductionDims(reduction.getReductionDims());
auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
// Given that the result is 1D, the layout of the operand should be 2D with
// default layout.
LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
reduction->getContext(), 2, uArch->getSubgroupSize());
propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
// The result layout represents the layout requirements of the operation.
// it is recorded to anchor layout or temporary layout.
// it must be honored for current op and may conflict with the layout
// propagated from consumer op, the conflict is resolved in later phase by
// converting the required result layout to the consumer layout
auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
// derive the source layout from the dominant layout and reduction dims
auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
requiredResLayoutAttr, reductionDims);
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Accumulator should have the same layout as the result.
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
propagateIfChanged(operands[1],
operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
}
void LayoutInfoPropagation::visitVectorBroadCastOp(
vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout of the result must be present.
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
LayoutInfo resLayoutInfo = results[0]->getValue();
if (!resLayoutInfo.isAssigned())
return;
// Only consider vector to vector broadcasts for now.
VectorType resultTy = broadcast.getResultVectorType();
VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
@ -679,55 +694,41 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
if (!sourceTy)
return;
// Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
if (sourceTy.getRank() != resultTy.getRank()) {
auto sourceDims = sourceTy.getShape();
auto resultDims = resultTy.getShape();
SmallVector<int64_t> bcastDims;
auto dimDiff = resultTy.getRank() - sourceTy.getRank();
// adding the missing leading dims
for (int i = 0; i < dimDiff; i++)
bcastDims.push_back(i);
auto srcShape = sourceTy.getShape();
auto resShape = resultTy.getShape();
// for the rest dims in the resultTy, if sourceTy dim is 1, then it's
// broadcasted dim
for (size_t i = 0; i < sourceDims.size(); i++)
if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
bcastDims.push_back(i + dimDiff);
size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
for (size_t i = 0; i < srcShape.size(); i++)
if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
broadcast.emitWarning("broadcast must either from low-rank or same-rank "
"with unit-dim, mixed scenario is not supported!");
// create a slice layout for the source
xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
broadcast->getContext(),
cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
auto resultLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
return;
}
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
xegpu::DistributeLayoutAttr srcLayoutAttr =
xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
return;
}
void LayoutInfoPropagation::visitShapeCastOp(
vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout of the result must be present.
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
LayoutInfo resLayoutInfo = results[0]->getValue();
if (!resLayoutInfo.isAssigned())
return;
VectorType sourceTy = shapeCast.getSourceVectorType();
VectorType resultTy = shapeCast.getResultVectorType();
// Shape cast layout propagation only supports 1D -> 2D shape casts.
// TODO: Support kD -> nD shape casts (k < n, n >= 2) where expanded dims are
// unit dimensions and non-unit dims match.
if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
return;
}
int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
auto resultLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
xegpu::DistributeLayoutAttr srcLayoutAttr =
xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape);
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
/// Propagate the layout of the result tensor to the source tensor descriptor
@ -748,7 +749,6 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp(
void LayoutInfoPropagation::visitDpasOp(
xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
LayoutInfo dpasALayout;
LayoutInfo dpasBLayout;
LayoutInfo dpasCDLayout;
@ -945,7 +945,6 @@ void LayoutInfoPropagation::visitDpasOp(
void LayoutInfoPropagation::visitStoreNdOp(
xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
LayoutInfo storeLayout;
xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
if (hasParamsOfLayoutKind(anchorLayout)) {
@ -986,7 +985,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
storeLayout =
getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
uArchInstruction->getPackedFormatBitSize());
else { // LayoutKind::Subgroup
else { // xegpu::LayoutKind::Subgroup
auto sgSize = uArch->getSubgroupSize();
auto numSgOrErr = getNumSg(store, sgSize);
if (failed(numSgOrErr)) {
@ -1026,7 +1025,6 @@ void LayoutInfoPropagation::visitStoreNdOp(
void LayoutInfoPropagation::visitLoadNdOp(
xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
LayoutInfo loadLayout;
xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
if (hasParamsOfLayoutKind(anchorLayout)) {
@ -1072,66 +1070,60 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// Need the layout of bitcast result to propagate to the operands.
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
LayoutInfo resLayoutInfo = results[0]->getValue();
if (!resLayoutInfo.isAssigned())
return;
int inElemTyBitWidth =
bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
int outElemTyBitWidth =
bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
// If the element bit widths are the same, then the layout does not change.
if (inElemTyBitWidth == outElemTyBitWidth) {
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
return;
}
// Check if the result layout is valid. i.e. result vector can be distributed.
auto resultLaneLayout = resultLayout.getLaneLayout();
auto resultLaneData = resultLayout.getLaneData();
if (failed(xegpu::getDistributedVectorType(
bitcast.getResultVectorType(),
xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
resultLaneData)))) {
bitcast.emitWarning(
"Result vector type can not be evenly distributed across lanes.");
return;
}
int64_t rank = bitcast.getSourceVectorType().getRank();
// Bitcast is a `narrowing` if the input element type bit width larger than
// the output element type bit width. eg. f32 -> f16 is a narrowing bitcast.
bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
: outElemTyBitWidth / inElemTyBitWidth;
SmallVector<int> sourceLaneLayout =
resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
SmallVector<int> outData = resultLayout.getLaneData();
// TODO: Currently we assume that bitcasts does not require cross lane
// communication. So each lane must own the required number of elements to
// perform the bitcast locally without cross-lane communication.
int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
if (outInnerBitsPerLane < inElemTyBitWidth) {
bitcast.emitWarning(
"Narrowing bitcast with cross lane communication is not supported.");
return;
}
// Check if each lane owns a single element in all dimensions except the
// innermost dimension.
SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
bitcast.emitWarning("Each lane must not own multiple elements in any "
"dimension other than "
"the innermost dimension.");
return;
}
// Decide lane data based on whether the bitcast is narrowing or widening.
int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
: outData[rank - 1] * bitCastRatio;
sourceLaneData.push_back(innerMostLaneData);
auto srcVecType = bitcast.getSourceVectorType();
auto resVecType = bitcast.getResultVectorType();
propagateIfChanged(
operands[0],
operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
auto requiredResLayoutAttr = setupBitCastResultLayout(
layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
// derive the source layout from the dominant layout and reduction dims
auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
void LayoutInfoPropagation::visitInsertStridedSliceOp(
vector::InsertStridedSliceOp insertStridedSlice,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout of the result must be present.
LayoutInfo resLayoutInfo = results[0]->getValue();
if (!resLayoutInfo.isAssigned())
return;
auto srcVecType = insertStridedSlice.getSourceVectorType();
auto resVecType = insertStridedSlice.getDestVectorType();
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
auto uArch = getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
requiredResLayoutAttr);
auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
propagateIfChanged(operands[1],
operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
return;
}
/// Propagate the layout of the result to the tensor descriptor, mask and offset
@ -1139,97 +1131,56 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
LayoutInfo loadLayout;
LayoutInfo maskLayout;
xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
auto uArch = getUArch(getChipStr(load).value_or(""));
const int subgroupSize = uArch->getSubgroupSize();
xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
if (hasParamsOfLayoutKind(anchorLayout)) {
loadLayout = LayoutInfo(anchorLayout);
maskLayout = loadLayout;
auto subgroupSize = uArch->getSubgroupSize();
VectorType resVecTy = load.getValueType();
int chunkSize = load.getChunkSize().value_or(1);
LayoutInfo resLayoutInfo = results[0]->getValue();
if (!resLayoutInfo.isAssigned())
return;
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
requiredAnchorLayoutAttr = anchorLayoutAttr;
} else {
LayoutInfo valueLayout = results[0]->getValue();
// Need the layout of the value to propagate to the tensor descriptor.
if (!valueLayout.isAssigned())
return;
auto resAttr = dyn_cast<xegpu::DistributeLayoutAttr>(valueLayout.get());
auto instDataIncoming = resAttr.getEffectiveInstDataAsInt();
if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
instDataIncoming = SmallVector<int64_t>(
cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent())
.getInstData()
.asArrayRef());
VectorType payloadTy = load.getValueType();
if (!payloadTy) {
if (!resVecTy) {
load.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
// Check if value inst_data complies with uArch
if (layoutKind == xegpu::LayoutKind::InstData) {
// Each lane loads either one element
SmallVector<int> instDataUarch{subgroupSize};
// Or multiple elements as 2D with lane's elements in the inner dimension
if (payloadTy.getRank() != 1) {
if (payloadTy.getRank() != 2) {
load.emitWarning("Expected 2D payload for LoadGatherOp.");
return;
}
int elemBitWidth = payloadTy.getElementTypeBitWidth();
instDataUarch.push_back((
std::min(static_cast<int>(payloadTy.getShape().back()),
uArchInstruction->getMaxLaneLoadStoreSize(elemBitWidth))));
}
// If inst data does not match, enforce the uArch-based one
if (!llvm::equal(instDataIncoming, instDataUarch)) {
xegpu::LayoutAttr sourceAttr = dyn_cast<xegpu::LayoutAttr>(resAttr);
if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr)) {
sourceAttr = cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent());
}
assert(sourceAttr);
xegpu::DistributeLayoutAttr updatedLayoutAttr = xegpu::LayoutAttr::get(
load.getContext(), sourceAttr.getSgLayout(), sourceAttr.getSgData(),
DenseI32ArrayAttr::get(load.getContext(), instDataUarch),
sourceAttr.getLaneLayout(), sourceAttr.getLaneData(),
sourceAttr.getOrder());
if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
updatedLayoutAttr = xegpu::SliceAttr::get(
load.getContext(), updatedLayoutAttr, sliceAttr.getDims());
valueLayout = LayoutInfo(updatedLayoutAttr);
}
}
loadLayout = valueLayout;
load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
load.setLayoutAttr(requiredAnchorLayoutAttr);
}
// If no user-defined anchor or we deal with a chunked op, set the default
// mask layout.
// Rank 1 data : Keep the mask layout aligned with data.
// Rank >1 data: Enforce the default xegpu 1D layout for mask.
if (!hasParamsOfLayoutKind(anchorLayout) ||
load.getValueType().getRank() > 1) {
auto maskLayoutAttr = requiredAnchorLayoutAttr;
// Special handling mask layout for chunked ops: Enforce the default xegpu 1D
// layout for mask.
if (chunkSize > 1) {
if (layoutKind == xegpu::LayoutKind::InstData)
maskLayout = LayoutInfo(
xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}));
maskLayoutAttr =
xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
else if (layoutKind == xegpu::LayoutKind::Lane)
maskLayout =
getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
maskLayoutAttr =
xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
else
assert(false &&
"chunked StoreScatterOp should not be used at workgroup level");
}
LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
if (load.getOffsets())
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
}
/// Propagate the layout of the descriptor to the vector offset operand in
@ -1254,109 +1205,97 @@ void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
LayoutInfo payloadLayout;
LayoutInfo maskLayout;
xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
const int subgroupSize = uArch->getSubgroupSize();
auto subgroupSize = uArch->getSubgroupSize();
VectorType srcVecTy = storeScatter.getValueType();
int chunkSize = storeScatter.getChunkSize().value_or(1);
if (hasParamsOfLayoutKind(anchorLayout)) {
payloadLayout = LayoutInfo(anchorLayout);
maskLayout = payloadLayout;
if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
requiredAnchorLayoutAttr = anchorLayoutAttr;
} else {
// Currently, for 2D StoreScatterOp we expect that the height dimension of
// the tensor descriptor is equal to the subgroup size. This is ensured by
// the op verifier.
VectorType payloadTy = storeScatter.getValueType();
if (!payloadTy) {
if (!srcVecTy) {
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
if (layoutKind == xegpu::LayoutKind::InstData) {
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
uArch->getInstruction(
xegpu::uArch::InstructionKind::StoreScatter));
const int subgroupSize = uArch->getSubgroupSize();
SmallVector<int> instDataUarch{subgroupSize};
if (payloadTy.getRank() != 1) {
if (payloadTy.getRank() != 2) {
storeScatter.emitWarning("Expected 2D payload for StoreScatterOp.");
return;
}
int elemBitWidth = payloadTy.getElementTypeBitWidth();
instDataUarch.push_back((
std::min(static_cast<int>(payloadTy.getShape().back()),
uArchInstruction->getMaxLaneLoadStoreSize(elemBitWidth))));
}
payloadLayout = LayoutInfo(
xegpu::LayoutAttr::get(storeScatter.getContext(), instDataUarch));
} else {
auto payloadShape = payloadTy.getShape();
if (payloadShape.size() > 1)
assert(payloadShape[0] == subgroupSize &&
"Expected the first dimension of 2D tensor descriptor to be "
"equal to "
"subgroup size.");
payloadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
}
storeScatter.setLayoutAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout(
layoutKind, srcVecTy, chunkSize, uArch);
storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
}
// If no user-defined anchor or we deal with a chunked op, set the default
// mask layout.
// Rank 1 data : Keep the mask layout aligned with data.
// Rank >1 data: Enforce the default xegpu 1D layout for mask.
if (!hasParamsOfLayoutKind(anchorLayout) ||
storeScatter.getValueType().getRank() > 1) {
LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
auto maskLayoutAttr = requiredAnchorLayoutAttr;
// Special handling mask layout for chunked ops: Enforce the default xegpu 1D
// layout for mask.
if (chunkSize > 1) {
if (layoutKind == xegpu::LayoutKind::InstData)
maskLayout = LayoutInfo(
xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}));
maskLayoutAttr =
xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
else if (layoutKind == xegpu::LayoutKind::Lane)
maskLayout =
getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
{subgroupSize}, {1});
else
assert(false &&
"chunked StoreScatterOp should not be used at workgroup level");
}
LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
// Propagate the destination (if tdesc) operand layout
if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
if (storeScatter.getOffsets())
propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
}
void LayoutInfoPropagation::visitLoadMatrixOp(
xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
LayoutInfo resLayoutInfo = results[0]->getValue();
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
// only need to set anchor layout, no need to porpagate to memdesc and
// offset
if (!hasParamsOfLayoutKind(anchorLayout)) {
VectorType resVecTy =
llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
layoutKind, resVecTy, consumerLayoutAttr, uArch);
loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
}
}
// Store matrix is a flavor of scattered store for 2D shapes.
void LayoutInfoPropagation::visitStoreMatrixOp(
xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
Value operand = storeMatrix.getData();
unsigned index =
std::distance(storeMatrix.operand_begin(),
llvm::find(storeMatrix->getOperands(), operand));
xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
LayoutInfo layout;
if (hasParamsOfLayoutKind(anchorLayout)) {
layout = LayoutInfo(anchorLayout);
} else {
VectorType payloadTy = llvm::cast<VectorType>(operand.getType());
assert(payloadTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
VectorType srcVecTy =
llvm::cast<VectorType>(storeMatrix.getData().getType());
assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
SmallVector<int> instData = {1, uArch->getSubgroupSize()};
if (layoutKind == xegpu::LayoutKind::InstData)
layout = LayoutInfo(
xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
else
layout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
auto requiredAnchorLayoutAttr =
xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
layout = LayoutInfo(requiredAnchorLayoutAttr);
}
propagateIfChanged(operands[index], operands[index]->meet(layout));
propagateIfChanged(operands[0], operands[0]->meet(layout));
}
namespace {
@ -1736,10 +1675,24 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
LayoutInfo layout = analysis.getLayoutInfo(val);
if (!layout.isAssigned())
return {};
if (auto opResult = dyn_cast<OpResult>(val)) {
Operation *defOp = opResult.getDefiningOp();
if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
auto anchorLayout = anchorOp.getAnchorLayout();
if (anchorLayout != nullptr)
return anchorLayout;
}
xegpu::DistributeLayoutAttr requiredResLayoutAttr =
xegpu::getTemporaryLayout(opResult);
if (requiredResLayoutAttr != nullptr)
return requiredResLayoutAttr;
}
xegpu::DistributeLayoutAttr layoutAttr =
cast<xegpu::DistributeLayoutAttr>(layout.get());
if (layout.isSliceLayout())
return cast<xegpu::SliceAttr>(layoutAttr);
return cast<xegpu::LayoutAttr>(layoutAttr);
};

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/AffineMap.h"
@ -1532,8 +1533,9 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
}
// case 2: source and result have same rank
if (rankDiff == 0) {
SetVector<int64_t> broadcastUnitDims =
broadcastOp.computeBroadcastedUnitDims();
auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
broadcastUnitDimsSet.end());
bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
if (!isEqualTo)
return rewriter.notifyMatchFailure(

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

View File

@ -19,6 +19,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>
@ -1113,27 +1114,10 @@ struct WgToSgVectorShapeCastOp
return failure();
ArrayRef<int64_t> srcShape = srcType.getShape();
llvm::SetVector<int64_t> expandedUnitDims;
// Check if shapes only differ by expanding unit dimensions (like
// expand_dims)
auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
ArrayRef<int64_t> dst) -> bool {
// All unit dimensions in dst that don't appear in src are the expanded
// unit dimensions
size_t srcIdx = 0;
for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
srcIdx++;
else if (dst[dstIdx] == 1)
expandedUnitDims.insert(dstIdx);
else
return false;
return srcIdx == src.size();
};
xegpu::DistributeLayoutAttr layoutToDistribute = layout;
if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
SmallVector<int64_t> expandedUnitDims;
if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
xegpu::DistributeLayoutAttr sourceLayout =
xegpu::getTemporaryLayout(op->getOpOperand(0));
@ -1488,15 +1472,8 @@ struct WgToSgMultiDimReductionOp
SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
auto storeMatrixLayout = xegpu::SliceAttr::get(
rewriter.getContext(),
xegpu::LayoutAttr::get(rewriter.getContext(), /*sg_layout =*/nullptr,
/*sg_data =*/nullptr,
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
/*lane_data =*/nullptr, /*order =*/nullptr),
dyn_cast<xegpu::SliceAttr>(layout).getDims());
xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
storeOffsets2D, /*layout=*/storeMatrixLayout);
storeOffsets2D, /*layout=*/nullptr);
gpu::BarrierOp::create(rewriter, loc);

View File

@ -366,111 +366,6 @@ template void xegpu::setTemporaryLayout<mlir::OpOperand>(
const mlir::OpOperand &operand,
const mlir::xegpu::DistributeLayoutAttr layout);
void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
op->walk([&](Operation *nestOp) {
for (OpOperand &opr : nestOp->getOpOperands()) {
auto layout = getDistributeLayoutAttr(opr.get());
setDistributeLayoutAttr(opr, layout);
}
for (OpResult result : nestOp->getOpResults()) {
auto layout = getDistributeLayoutAttr(result);
setDistributeLayoutAttr(result, layout);
}
});
}
/// Attach layout attributes to all vector-type operands of operations within
/// the given operation's region. Reports an error if any vector operand lacks
/// a layout attribute.
bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
auto result = rootOp->walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
// Layouts are needed for vector type only.
if (!isa<VectorType>(operand.get().getType()))
continue;
auto layout = xegpu::getDistributeLayoutAttr(operand.get());
if (!layout) {
op->emitWarning("Could not find layout attribute for operand ")
<< operand.getOperandNumber() << " of operation " << op->getName();
continue;
}
xegpu::setDistributeLayoutAttr(operand, layout);
}
return WalkResult::advance();
});
return !result.wasInterrupted();
}
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->removeAttr(name);
}
SmallVector<NamedAttribute>
xegpu::dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
SmallVector<NamedAttribute> out;
out.reserve(attrs.size());
for (auto attr : attrs) {
if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
auto newLayout = dist.dropSgLayoutAndData();
if (newLayout)
out.emplace_back(attr.getName(), newLayout);
} else {
out.push_back(attr);
}
}
return out;
}
SmallVector<NamedAttribute>
xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
SmallVector<NamedAttribute> out;
out.reserve(attrs.size());
for (auto attr : attrs) {
if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
auto newLayout = dist.dropInstData();
if (newLayout)
out.emplace_back(attr.getName(), newLayout);
} else {
out.push_back(attr);
}
}
return out;
}
// Explicit instantiation for OpResult
template void
xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
// Explicit instantiation for OpOperand
template void
xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
void xegpu::removeLayoutAttrs(Operation *op) {
op->walk([&](Operation *nestOp) {
for (OpOperand &opr : nestOp->getOpOperands())
removeLayoutAttr(opr);
for (OpResult result : nestOp->getOpResults())
removeLayoutAttr(result);
if (op->hasAttrOfType<DistributeLayoutAttr>("layout"))
op->removeAttr("layout");
if (op->hasAttrOfType<DistributeLayoutAttr>("layout_a"))
op->removeAttr("layout_a");
if (op->hasAttrOfType<DistributeLayoutAttr>("layout_b"))
op->removeAttr("layout_b");
if (op->hasAttrOfType<DistributeLayoutAttr>("layout_cd"))
op->removeAttr("layout_cd");
});
}
SmallVector<Value>
xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
Value value, ArrayRef<int64_t> shape) {
@ -786,3 +681,58 @@ bool xegpu::requireTranspose(const xegpu::LayoutAttr layout,
return false;
return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
}
// Check if dst shape is an expansion of src shape by inserting unit dimensions.
// Returns true if all dimensions in src match corresponding dimensions in dst
// (after skipping unit dimensions), and populates expandedUnitDims with the
// indices of the unit dimensions in dst that were added (not present in src).
// Example: src=[2,3], dst=[1,2,3,1] -> true, expandedUnitDims=[0,3]
bool xegpu::matchUnitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
SmallVector<int64_t> &expandedUnitDims) {
// All unit dimensions in dst that don't appear in src are the expanded
// unit dimensions
size_t srcIdx = 0;
for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
srcIdx++;
else if (dst[dstIdx] == 1)
expandedUnitDims.push_back(dstIdx);
else
return false;
return srcIdx == src.size();
}
// Checks if dst shape is an expansion of src shape where each dimension in src
// is split into one or more consecutive dimensions in dst whose product equals
// the original dimension. Populates splitDimGroups with groups of dst indices
// that correspond to each src dimension. Example: src=[6,4], dst=[2,3,2,2] ->
// true
bool xegpu::matchSplitDimExpansion(
ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
SmallVector<SmallVector<int64_t>> &splitDimGroups) {
// each dim in src can be mapped to one or more dims in dst whose product
// equals to the src dim
size_t srcIdx = 0;
int64_t accumulatedSize = 1;
SmallVector<int64_t> currentDstDims;
splitDimGroups.clear();
for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
if (srcIdx >= src.size())
return false;
accumulatedSize *= dst[dstIdx];
currentDstDims.push_back(dstIdx);
if (accumulatedSize == src[srcIdx]) {
// Record the mapping: srcIdx -> currentDstDims
splitDimGroups.push_back(currentDstDims);
// move to next src dim
srcIdx++;
accumulatedSize = 1;
currentDstDims.clear();
} else if (accumulatedSize > src[srcIdx]) {
return false;
}
}
return srcIdx == src.size();
}

View File

@ -217,7 +217,7 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}> :
// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{layout = #xegpu.layout<inst_data = [16]>}> :
// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[LOADED]] {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf32> to vector<16x16xf32>
// CHECK: xegpu.store %[[BCAST]], %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
@ -234,3 +234,89 @@ func.func @scatter_ops_chunksize_slice(%src: memref<1024xf32>) {
return
}
}
// -----
gpu.module @test {
// CHECK-LABEL: func.func @insert_strided_slice_inst_data_no_packing(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>) {
// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<1.000000e+00> : vector<4x16xf32>
// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<0.000000e+00> : vector<8x32xf32>
// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
// CHECK: xegpu.store_nd %[[INSERT]], %[[TDESC]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
%c0 = arith.constant 0 : index
%cst_small = arith.constant dense<1.0> : vector<4x16xf32>
%cst_large = arith.constant dense<0.0> : vector<8x32xf32>
%insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
%tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32>
xegpu.store_nd %insert, %tdesc : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32>
return
}
}
// -----
gpu.module @test {
// CHECK-LABEL: func.func @insert_strided_slice_inst_data_with_packing(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xi8>) {
// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>} dense<1> : vector<4x64xi8>
// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>} dense<0> : vector<8x64xi8>
// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
func.func @insert_strided_slice_inst_data_with_packing(%arg0: memref<8x64xi8>) {
%c0 = arith.constant 0 : index
%cst_small = arith.constant dense<1> : vector<4x64xi8>
%cst_large = arith.constant dense<0> : vector<8x64xi8>
%insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
%tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x64xi8> -> !xegpu.tensor_desc<8x64xi8, #xegpu.layout<inst_data = [8, 64]>>
xegpu.store_nd %insert, %tdesc <{layout = #xegpu.layout<inst_data = [8, 64]>}>: vector<8x64xi8>, !xegpu.tensor_desc<8x64xi8, #xegpu.layout<inst_data = [8, 64]>>
return
}
}
// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_shape_cast_expand_non_unit_dims(
// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[STEP:.*]]], %[[CST:.*]] <{layout = #xegpu.layout<inst_data = [16]>}> : memref<1024xf16>, vector<1024xindex>, vector<1024xi1> -> vector<1024xf16>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} : vector<1024xf16> to vector<8x8x16xf16>
// CHECK: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [1, 1, 16]>, dims = [0]>} dense<0.000000e+00> : vector<8x16xf16>
// CHECK: %[[CST_1:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [1, 16]>, dims = [0]>} dense<0.000000e+00> : vector<16xf16>
// CHECK: %[[REDUCE_0:.*]] = vector.multi_reduction <add>, %[[CAST]], %[[CST_0]] {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [1, 1, 16]>, dims = [0]>} [0] : vector<8x8x16xf16> to vector<8x16xf16>
// CHECK: %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[REDUCE_0]], %[[CST_1]] {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [1, 16]>, dims = [0]>} [0] : vector<8x16xf16> to vector<16xf16>
func.func @vector_shape_cast_expand_non_unit_dims(%arg0: memref<1024xf16>, %arg1: memref<16xf16>) {
%cst = arith.constant dense<true> : vector<1024xi1>
%0 = vector.step : vector<1024xindex>
%1 = xegpu.load %arg0[%0], %cst : memref<1024xf16>, vector<1024xindex>, vector<1024xi1> -> vector<1024xf16>
%2 = vector.shape_cast %1 : vector<1024xf16> to vector<8x8x16xf16>
%cst_0 = arith.constant dense<0.000000e+00> : vector<8x16xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<16xf16>
%3 = vector.multi_reduction <add>, %2, %cst_0 [0] : vector<8x8x16xf16> to vector<8x16xf16>
%4 = vector.multi_reduction <add>, %3, %cst_1 [0] : vector<8x16xf16> to vector<16xf16>
%cst_2 = arith.constant dense<true> : vector<16xi1>
%cst_3 = arith.constant dense<1> : vector<16xindex>
xegpu.store %4, %arg1[%cst_3], %cst_2 <{layout = #xegpu.layout<inst_data = [16]>}> : vector<16xf16>, memref<16xf16>, vector<16xindex>, vector<16xi1>
return
}
}
// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_shape_cast_expand_and_merge(
// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [32]>} dense<true> : vector<256xi1>
// CHECK: %[[STEP:.*]] = vector.step {layout_result_0 = #xegpu.layout<inst_data = [32]>} : vector<256xindex>
// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[STEP]]], %[[CST]] <{layout = #xegpu.layout<inst_data = [32]>}> : memref<256xf16>, vector<256xindex>, vector<256xi1> -> vector<256xf16>
// CHECK: %[[CAST_0:.*]] = vector.shape_cast %[[LOAD]] {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 32]>} : vector<256xf16> to vector<2x4x32xf16>
// CHECK: %[[CAST_1:.*]] = vector.shape_cast %[[CAST_0]] {layout_result_0 = #xegpu.layout<inst_data = [1, 32]>} : vector<2x4x32xf16> to vector<1x256xf16>
// CHECK: %[[CAST_2:.*]] = vector.shape_cast %[[CAST_1]] {layout_result_0 = #xegpu.layout<inst_data = [32]>} : vector<1x256xf16> to vector<256xf16>
// CHECK: xegpu.store %[[CAST_2]], %arg1[%[[STEP]]], %[[CST]] <{layout = #xegpu.layout<inst_data = [32]>}> : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
func.func @vector_shape_cast_expand_and_merge(%arg0: memref<256xf16>, %arg1: memref<256xf16>) {
%cst = arith.constant dense<true> : vector<256xi1>
%0 = vector.step : vector<256xindex>
%1 = xegpu.load %arg0[%0], %cst : memref<256xf16>, vector<256xindex>, vector<256xi1> -> vector<256xf16>
%2 = vector.shape_cast %1 : vector<256xf16> to vector<2x4x32xf16>
%4 = vector.shape_cast %2 : vector<2x4x32xf16> to vector<1x256xf16>
%5 = vector.shape_cast %4 : vector<1x256xf16> to vector<256xf16>
xegpu.store %5, %arg1[%0], %cst <{layout = #xegpu.layout<inst_data = [32] >}> : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
return
}
}

View File

@ -123,3 +123,44 @@ gpu.module @test {
gpu.return
}
}
// -----
gpu.module @test {
// CHECK-LABEL: vector_row_reduction
// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [1, 64]>, dims = [1]>}
gpu.func @vector_row_reduction(%src: memref<32x64xf32>, %dst: memref<32xf32>) kernel attributes
{known_block_size = array<i32: 1, 32, 1>} {
%cst = arith.constant dense<0.000000e+00> : vector<32xf32>
%tdesc_src = xegpu.create_nd_tdesc %src : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32>
%load = xegpu.load_nd %tdesc_src : !xegpu.tensor_desc<32x64xf32> -> vector<32x64xf32>
%reduce = vector.multi_reduction <add>, %load, %cst [1] : vector<32x64xf32> to vector<32xf32>
%tdesc_dst = xegpu.create_nd_tdesc %dst : memref<32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<sg_layout = [32], sg_data = [1]>>
xegpu.store_nd %reduce, %tdesc_dst <{layout = #xegpu.layout<sg_layout = [32], sg_data = [1]>}>
: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.layout<sg_layout = [32], sg_data = [1]>>
gpu.return
}
}
// -----
gpu.module @test {
// CHECK-LABEL: vector_nest_reduction
gpu.func @vector_nest_reduction(%src: memref<32x128xf32>, %dst: memref<32xf32>) kernel attributes
{known_block_size = array<i32: 1, 32, 1>} {
%cst = arith.constant dense<0.000000e+00> : vector<32xf32>
%cst1 = arith.constant dense<0.000000e+00> : vector<32x128xf32>
%tdesc_src = xegpu.create_nd_tdesc %src : memref<32x128xf32> -> !xegpu.tensor_desc<32x128xf32>
%load = xegpu.load_nd %tdesc_src : !xegpu.tensor_desc<32x128xf32> -> vector<32x128xf32>
%bcast1 = vector.broadcast %load: vector<32x128xf32> to vector<4x32x128xf32>
// CHECK: %[[BCAST1:.*]] = vector.broadcast %{{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>} : vector<32x128xf32> to vector<4x32x128xf32>
// CHECK: %[[BCAST:.*]] = vector.multi_reduction <add>, %[[BCAST1]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>} [0] : vector<4x32x128xf32> to vector<32x128xf32>
// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[BCAST]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [8, 16]>, dims = [1]>} [1] : vector<32x128xf32> to vector<32xf32>
%bcast = vector.multi_reduction <add>, %bcast1, %cst1 [0]: vector<4x32x128xf32> to vector<32x128xf32>
%reduce = vector.multi_reduction <add>, %bcast, %cst [1] : vector<32x128xf32> to vector<32xf32>
%mask = arith.constant dense<1>: vector<32xi1>
%offset = vector.step : vector<32xindex>
xegpu.store %reduce, %dst[%offset], %mask {layout = #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 16]>, dims = [1]>} : vector<32xf32>, memref<32xf32>, vector<32xindex>, vector<32xi1>
gpu.return
}
}

View File

@ -104,21 +104,18 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
gpu.module @test {
// CHECK-LABEL: func.func @load_gather_with_chunksize(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
// CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
// CHECK-SAME: dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
// CHECK-NEXT: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK-NEXT: %{{.*}} = xegpu.load %arg1[%[[OFFSET]]], %[[MASK]] <{chunk_size = 16 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x16xf16>
func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
%cst_0 = arith.constant dense<true> : vector<16xi1>
%2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
%3 = xegpu.load %2, %cst_0 : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
%offset = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
%mask = arith.constant dense<true> : vector<16xi1>
%3 = xegpu.load %arg1[%offset], %mask <{chunk_size=16}>
: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x16xf16>
%4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
%5 = xegpu.dpas %1, %4 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
%6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
@ -151,16 +148,15 @@ func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf
gpu.module @test {
// CHECK-LABEL: func.func @store_scatter_with_chunksize(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<128xf32>) {
// CHECK: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %{{.*}} : memref<128xf32>, vector<16xindex> ->
// CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>,
// CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<16xi1>
// CHECK-NEXT: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>} dense<1.000000e+00> : vector<16x8xf32>
// CHECK-NEXT: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK-NEXT: %[[CST_1:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
// CHECK-NEXT: xegpu.store %[[CST]], %[[ARG0]][%[[CST_1]]], %[[CST_0]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>}> : vector<16x8xf32>, memref<128xf32>, vector<16xindex>, vector<16xi1>
func.func @store_scatter_with_chunksize(%arg0: memref<128xf32>) {
%cst = arith.constant dense<1.000000e+00> : vector<16x8xf32>
%cst_0 = arith.constant dense<true> : vector<16xi1>
%cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
%0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
xegpu.store %cst, %0, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
%val = arith.constant dense<1.000000e+00> : vector<16x8xf32>
%mask = arith.constant dense<true> : vector<16xi1>
%offset = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
xegpu.store %val, %arg0[%offset], %mask <{chunk_size = 8}>: vector<16x8xf32>, memref<128xf32>, vector<16xindex>, vector<16xi1>
return
}
}
@ -184,9 +180,9 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>}>
// CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
%1 = arith.constant dense<1>: vector<16xi1>
%offset = arith.constant dense<12> : vector<16xindex>
@ -320,8 +316,9 @@ func.func @vector_bitcast_i16_to_i32(%arg0: memref<8x32xi16>, %arg1: memref<8x16
// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_bitcast_require_cross_lane_shuffle(
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
// CHECK: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}>
// CHECK-SAME: !xegpu.tensor_desc<8x16xi32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
// CHECK-SAME: vector<8x16xi32> to vector<8x32xi16>
func.func @vector_bitcast_require_cross_lane_shuffle(%arg0: memref<8x16xi32>, %arg1: memref<8x32xi16>) {
%c0 = arith.constant 0 : index
@ -483,7 +480,7 @@ func.func @if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
gpu.module @test {
// CHECK-LABEL: func.func @vector_outer_reduction(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [0] : vector<16x16xf32> to vector<16xf32>
// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf32> to vector<16xf32>
func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<16xf32>
%0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
@ -495,7 +492,7 @@ func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor
gpu.module @test {
// CHECK-LABEL: func.func @vector_inner_reduction(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [1] : vector<16x16xf32> to vector<16xf32>
// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf32> to vector<16xf32>
func.func @vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<16xf32>
%0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
@ -642,6 +639,52 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc
}
// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_shape_cast_expand_non_unit_dims(
// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[STEP:.*]]], %[[CST:.*]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<1024xf16>, vector<1024xindex>, vector<1024xi1> -> vector<1024xf16>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>} : vector<1024xf16> to vector<8x8x16xf16>
// CHECK: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [0]>} dense<0.000000e+00> : vector<8x16xf16>
// CHECK: %[[CST_1:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} dense<0.000000e+00> : vector<16xf16>
// CHECK: %[[REDUCE_0:.*]] = vector.multi_reduction <add>, %[[CAST]], %[[CST_0]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [0]>} [0] : vector<8x8x16xf16> to vector<8x16xf16>
// CHECK: %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[REDUCE_0]], %[[CST_1]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<8x16xf16> to vector<16xf16>
func.func @vector_shape_cast_expand_non_unit_dims(%arg0: memref<1024xf16>, %arg1: memref<16xf16>) {
%cst = arith.constant dense<true> : vector<1024xi1>
%0 = vector.step : vector<1024xindex>
%1 = xegpu.load %arg0[%0], %cst : memref<1024xf16>, vector<1024xindex>, vector<1024xi1> -> vector<1024xf16>
%2 = vector.shape_cast %1 : vector<1024xf16> to vector<8x8x16xf16>
%cst_0 = arith.constant dense<0.000000e+00> : vector<8x16xf16>
%cst_1 = arith.constant dense<0.000000e+00> : vector<16xf16>
%3 = vector.multi_reduction <add>, %2, %cst_0 [0] : vector<8x8x16xf16> to vector<8x16xf16>
%4 = vector.multi_reduction <add>, %3, %cst_1 [0] : vector<8x16xf16> to vector<16xf16>
%cst_2 = arith.constant dense<true> : vector<16xi1>
%cst_3 = arith.constant dense<1> : vector<16xindex>
xegpu.store %4, %arg1[%cst_3], %cst_2 <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1] >}> : vector<16xf16>, memref<16xf16>, vector<16xindex>, vector<16xi1>
return
}
}
// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_shape_cast_expand_and_merge(
// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [2]>} dense<true> : vector<256xi1>
// CHECK: %[[STEP:.*]] = vector.step {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [2]>} : vector<256xindex>
// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[STEP]]], %[[CST]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [2]>}> : memref<256xf16>, vector<256xindex>, vector<256xi1> -> vector<256xf16>
// CHECK: %[[CAST_0:.*]] = vector.shape_cast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 2]>} : vector<256xf16> to vector<2x4x32xf16>
// CHECK: %[[CAST_1:.*]] = vector.shape_cast %[[CAST_0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} : vector<2x4x32xf16> to vector<1x256xf16>
// CHECK: %[[CAST_2:.*]] = vector.shape_cast %[[CAST_1]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [2]>} : vector<1x256xf16> to vector<256xf16>
// CHECK: xegpu.store %[[CAST_2]], %arg1[%[[STEP]]], %[[CST]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [2]>}> : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
func.func @vector_shape_cast_expand_and_merge(%arg0: memref<256xf16>, %arg1: memref<256xf16>) {
%cst = arith.constant dense<true> : vector<256xi1>
%0 = vector.step : vector<256xindex>
%1 = xegpu.load %arg0[%0], %cst : memref<256xf16>, vector<256xindex>, vector<256xi1> -> vector<256xf16>
%2 = vector.shape_cast %1 : vector<256xf16> to vector<2x4x32xf16>
%4 = vector.shape_cast %2 : vector<2x4x32xf16> to vector<1x256xf16>
%5 = vector.shape_cast %4 : vector<1x256xf16> to vector<256xf16>
xegpu.store %5, %arg1[%0], %cst <{layout = #xegpu.layout<lane_layout = [16], lane_data = [2] >}> : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
return
}
}
// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_broadcast_1d_to_2d_broadcast_along_row(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
@ -702,12 +745,50 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
// -----
gpu.module @test {
// CHECK-LABEL: func.func @store_matrix(
// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} dense<0.000000e+00> : vector<16x16xf16>
// CHECK-NEXT: xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
// CHECK-NEXT: xegpu.store_matrix %[[CST]], %arg0[8, 8] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}>
func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
%cst = arith.constant dense<0.0000> : vector<16x16xf16>
xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
return
}
}
// -----
gpu.module @test {
// CHECK-LABEL: func.func @insert_strided_slice_lane_layout_no_packing(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x64xf32>) {
// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<2x32xf32>
// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<4x64xf32>
// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, offsets = [0, 0], strides = [1, 1]} : vector<2x32xf32> into vector<4x64xf32>
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<4x64xf32> -> !xegpu.tensor_desc<4x64xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: xegpu.store_nd %[[INSERT]], %[[TDESC]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<4x64xf32>, !xegpu.tensor_desc<4x64xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
func.func @insert_strided_slice_lane_layout_no_packing(%arg0: memref<4x64xf32>) {
%c0 = arith.constant 0 : index
%cst_small = arith.constant dense<1.0> : vector<2x32xf32>
%cst_large = arith.constant dense<0.0> : vector<4x64xf32>
%insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<2x32xf32> into vector<4x64xf32>
%tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<4x64xf32> -> !xegpu.tensor_desc<4x64xf32>
xegpu.store_nd %insert, %tdesc : vector<4x64xf32>, !xegpu.tensor_desc<4x64xf32>
return
}
}
// -----
gpu.module @test {
// CHECK-LABEL: func.func @insert_strided_slice_lane_layout_with_packing(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x64xf16>) {
// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} dense<1.000000e+00> : vector<2x32xf16>
// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} dense<0.000000e+00> : vector<4x64xf16>
// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, offsets = [0, 0], strides = [1, 1]} : vector<2x32xf16> into vector<4x64xf16>
func.func @insert_strided_slice_lane_layout_with_packing(%arg0: memref<4x64xf16>) {
%c0 = arith.constant 0 : index
%cst_small = arith.constant dense<1.0> : vector<2x32xf16>
%cst_large = arith.constant dense<0.0> : vector<4x64xf16>
%insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<2x32xf16> into vector<4x64xf16>
%tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<4x64xf16> -> !xegpu.tensor_desc<4x64xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>>
xegpu.store_nd %insert, %tdesc <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}>: vector<4x64xf16>, !xegpu.tensor_desc<4x64xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>>
return
}
}

View File

@ -674,7 +674,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
// CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL3]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [1]>}>: vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
@ -717,7 +717,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
// CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL4]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD1]], %[[C32:.*]] : index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [0]>}>: vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
// CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
@ -766,7 +766,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
// CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C1:.*]] : index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [2, 3]>}>: vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x4xf32>, index, index -> vector<16x1xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
@ -810,7 +810,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
// CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C256:.*]] : index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [2, 3]>}>: vector<1x256xf32>, !xegpu.mem_desc<16x1024xf32>, index, index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x256xf32>, !xegpu.mem_desc<16x1024xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x1024xf32>, index, index -> vector<16x256xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<256xf32>

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"