[mlir][XeGPU] add WgToSg distribution pattern for load_matrix and store_matrix. (#154403)
parent 32a5adbd42
commit 68d6866428
@@ -11,6 +11,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -23,6 +24,7 @@
 namespace mlir {
 namespace xegpu {
 class TensorDescType;
+class DistributeLayoutAttr;
 class LayoutAttr;
 class SliceAttr;
 } // namespace xegpu
@@ -175,22 +175,36 @@ def XeGPU_FenceScopeAttr:
  let assemblyFormat = "$value";
}

-def LayoutTrait: AttrInterface<"LayoutTrait"> {
+def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
  let cppNamespace = "::mlir::xegpu";
  let description = [{
    Common trait for all XeGPU layouts.
  }];

  let methods = [
    InterfaceMethod<"Check the availability of workgroup level layouts",
                    "bool",
                    "isForWorkgroup">,
    InterfaceMethod<"Get the rank of attribute",
                    "int64_t",
                    "getRank">,
    InterfaceMethod<"Get the num of effective subgroups",
                    "int64_t",
                    "getNumSubgroups", (ins), [{
                      std::optional<SmallVector<int64_t>> sgLayout = llvm::cast<ConcreteAttr>(tablegen_opaque_val).getSgLayoutAsInt();
                      if (sgLayout.has_value())
                        return computeProduct(*sgLayout);
                      return 0;
                    }], [{}]>,
    InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
                    "std::optional<SmallVector<int64_t>>",
                    "getSgLayoutAsInt">,
    InterfaceMethod<"Get the SgData field of the attribute as integer array",
                    "std::optional<SmallVector<int64_t>>",
                    "getSgDataAsInt">,
    InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
                    "xegpu::DistributeLayoutAttr",
                    "dropSgLayoutAndData">,
    InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
                      indices based on the effective subgroup layout.}],
                    "FailureOr<SmallVector<Value>>",
@@ -206,7 +220,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
  ];
}

-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
  let summary = [{
    Describes the data distribution to subgroups and work-items for a tensor
    specified by the tensor descriptor.
@@ -328,12 +342,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
  ];

  let extraClassDeclaration = [{
-    bool isWgLayout() {
+    bool isForWorkgroup() {
      return getSgLayout() != nullptr;
    }

-    bool isSgLayout() {
-      return !isWgLayout();
+    bool isForSubgroup() {
+      return !isForWorkgroup();
    }

    int64_t getRank() {
@@ -393,7 +407,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
}


-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
  let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];

  let description = [{
@@ -420,7 +434,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
  }];

  let parameters = (ins
-    "xegpu::LayoutTrait": $parent,
+    "xegpu::DistributeLayoutAttr": $parent,
    "DenseI64ArrayAttr": $dims
  );

@@ -438,16 +452,16 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
      return parent.getOrder();
    }

-    bool isWgLayout() const {
+    bool isForWorkgroup() const {
      SliceAttr attr = flatten();
      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      return parent.isWgLayout();
+      return parent.isForWorkgroup();
    }

-    bool isSgLayout() const {
+    bool isForSubgroup() const {
      SliceAttr attr = flatten();
      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      return parent.isSgLayout();
+      return parent.isForSubgroup();
    }

    /// Returns the SgLayout of the attribute, computed by applying
@@ -474,6 +488,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
      return std::nullopt;
    }

+    SliceAttr dropSgLayoutAndData() {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      parent = parent.dropSgLayoutAndData();
+      return SliceAttr::get(getContext(), parent, attr.getDims());
+    }
+
+    SliceAttr dropInstData() {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      parent = parent.dropInstData();
+      return SliceAttr::get(getContext(), parent, attr.getDims());
+    }
+
    /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
    /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
    /// it will coalese two slice operations and return a simplified SliceAttr
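For orientation, here is a minimal sketch (not part of the patch) of how transform code can be written against the new DistributeLayoutAttr interface instead of the concrete LayoutAttr or SliceAttr classes. The helper name inferSgTileShape is hypothetical; it simply mirrors the sg_data / sg_layout logic that getSgShapeAndCount uses later in this commit.

    // Illustrative only: querying the DistributeLayoutAttr interface.
    #include "mlir/Dialect/Utils/IndexingUtils.h"
    #include "mlir/Dialect/XeGPU/IR/XeGPU.h"

    using namespace mlir;

    // Returns the per-subgroup tile shape implied by a workgroup-level layout,
    // falling back to the full shape when no workgroup layout is attached.
    static SmallVector<int64_t>
    inferSgTileShape(xegpu::DistributeLayoutAttr layout,
                     ArrayRef<int64_t> wgShape) {
      SmallVector<int64_t> tile(wgShape);
      if (!layout || !layout.isForWorkgroup())
        return tile;
      if (auto sgData = layout.getSgDataAsInt())
        return *sgData; // explicit sg_data wins.
      if (auto sgLayout = layout.getSgLayoutAsInt())
        if (auto ratio = computeShapeRatio(wgShape, *sgLayout))
          return *ratio; // otherwise derive the tile as wgShape / sg_layout.
      return tile;
    }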
@@ -232,6 +232,14 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
      return static_cast<unsigned>(MemorySpace::Global);
    }

+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDataShape() {
+      return getTensorDescShape();
+    }
+
  }];
}

@@ -262,6 +270,23 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
    xegpu::TensorDescType getTensorDescType() {
      return getTensorDesc().getType();
    }

+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
+    }
+
+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDataShape() {
+      return getTensorDescType().getShape();
+    }
+
  }];

  let assemblyFormat = [{
@@ -343,6 +368,24 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
    xegpu::TensorDescType getTensorDescType() {
      return getTensorDesc().getType();
    }

+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
+    }
+
+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDataShape() {
+      return getTensorDescType().getShape();
+    }
+
  }];

  let assemblyFormat = [{
@@ -417,6 +460,23 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
    xegpu::TensorDescType getTensorDescType() {
      return getTensorDesc().getType();
    }

+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
+    }
+
+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDataShape() {
+      return getTensorDescType().getShape();
+    }
+
  }];

  let assemblyFormat = [{
@@ -640,6 +700,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
    xegpu::TensorDescType getTensorDescType() {
      return dyn_cast<xegpu::TensorDescType>(getSourceType());
    }

  }];

  let assemblyFormat = [{
@@ -1150,7 +1211,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
  let arguments = (ins XeGPU_MemDesc:$mem_desc,
      Variadic<Index>: $offsets,
      DenseI64ArrayAttr: $const_offsets,
-      OptionalAttr<LayoutTrait>:$layout
+      OptionalAttr<DistributeLayoutAttr>:$layout
  );
  let results = (outs XeGPU_ValueType:$res);
  let assemblyFormat = [{
@@ -1175,12 +1236,16 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,

  let builders = [
    OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
-      "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+      "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr": $layout)>,
  ];
  let extraClassDeclaration = [{
    SmallVector<OpFoldResult> getMixedOffsets() {
      return getMixedValues(getConstOffsets(), getOffsets(), getContext());
    }

+    ArrayRef<int64_t> getDataShape() {
+      return getRes().getType().getShape();
+    }
  }];

  let hasVerifier = 1;
@@ -1194,7 +1259,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
      XeGPU_MemDesc:$mem_desc,
      Variadic<Index>: $offsets,
      DenseI64ArrayAttr: $const_offsets,
-      OptionalAttr<LayoutTrait>:$layout
+      OptionalAttr<DistributeLayoutAttr>:$layout
  );
  let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
      prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,12 +1278,17 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
  }];
  let builders = [
    OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
-      "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+      "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr": $layout)>,
  ];
  let extraClassDeclaration = [{
    SmallVector<OpFoldResult> getMixedOffsets() {
      return getMixedValues(getConstOffsets(), getOffsets(), getContext());
    }

+    ArrayRef<int64_t> getDataShape() {
+      return getData().getType().getShape();
+    }
+
  }];

  let hasVerifier = 1;
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_

+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {

class VectorType;
@@ -128,6 +129,20 @@ void doSCFStructuralTypeConversionWithTensorType(Operation *op,
/// if no GPU module parent or XeVM target attribute exists.
std::optional<std::string> getChipStr(Operation *op);

+/// Generates element-wise addition ops of two arrays with automatic alignment.
+/// When the input arrays have different sizes, the shorter array is
+/// right-aligned with the longer array, and the unmatched leading elements from
+/// the longer array are preserved unchanged. This is commonly used for offset
+/// computation where higher-dimensional offsets need to be added to
+/// lower-dimensional adjustments.
+///
+/// Example:
+///    lhs = [l1, l2, l3], rhs = [r1, r2]
+///    Result: [l1, l2+r1, l3+r2]
+SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
+                                              ArrayRef<OpFoldResult> lhs,
+                                              ArrayRef<OpFoldResult> rhs);
+
} // namespace xegpu

} // namespace mlir
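As a side note (not part of the patch), the right-alignment rule of the new addWithRightAligned helper can be modeled on plain integers; the real helper operates on OpFoldResult values and emits index.add ops instead.

    // Illustrative only: a scalar model of addWithRightAligned's semantics.
    #include <cstdint>
    #include <vector>

    static std::vector<int64_t> addRightAligned(std::vector<int64_t> lhs,
                                                std::vector<int64_t> rhs) {
      // Ensure `a` is the longer array; the shorter one aligns with its tail.
      auto &a = lhs.size() >= rhs.size() ? lhs : rhs;
      auto &b = lhs.size() >= rhs.size() ? rhs : lhs;
      // Unmatched leading elements of the longer array are kept unchanged.
      std::vector<int64_t> out(a.begin(), a.end() - b.size());
      for (size_t i = 0; i < b.size(); ++i)
        out.push_back(a[a.size() - b.size() + i] + b[i]);
      return out;
    }
    // addRightAligned({10, 20, 30}, {1, 2}) == {10, 21, 32}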
@@ -271,7 +271,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
                                  Value linearId) {
  // delinearizeSubgroupId is only available for
  // workgroup-level layout attribute
-  if (!isWgLayout())
+  if (!isForWorkgroup())
    return failure();

  // TODO: handle order attribute
@@ -290,12 +290,13 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
  return affine::delinearizeIndex(builder, loc, linearId, dims);
}

-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by LayoutAttr.
+/// Implements DistributeLayoutAttr::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                       ArrayRef<int64_t> shape) {
-  if (!isWgLayout())
+  if (!isForWorkgroup())
    return failure();

  SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
@@ -322,7 +323,7 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
//===----------------------------------------------------------------------===//
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
-                  xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+                  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
  if (!parent || !dims)
    return emitError() << "expected parent layout and dims attribute";

@@ -340,7 +341,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
}

SliceAttr SliceAttr::flatten() const {
-  xegpu::LayoutTrait parent = getParent();
+  xegpu::DistributeLayoutAttr parent = getParent();
  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});

  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
@@ -375,13 +376,14 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
  return parent.delinearizeSubgroupId(builder, loc, linearId);
}

-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by SliceAttr.
+/// Implements DistributeLayoutAttr::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                      ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
-  if (!isWgLayout())
+  if (!isForWorkgroup())
    return failure();

  SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
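For intuition (not part of the patch), the offset computation that LayoutAttr::getOffsets materializes as index ops can be sketched on scalars. The sketch assumes a 2-D layout and row-major delinearization of the linear subgroup id; the real implementation uses affine::delinearizeIndex and handles arbitrary ranks, orders, and round-robin distribution.

    // Illustrative only: a scalar model of per-subgroup start offsets.
    #include <array>
    #include <cstdint>

    static std::array<int64_t, 2> sgStartOffsets(int64_t sgId,
                                                 std::array<int64_t, 2> sgLayout,
                                                 std::array<int64_t, 2> sgData,
                                                 std::array<int64_t, 2> wgShape) {
      // Delinearize the linear subgroup id into (row, col) within sg_layout.
      int64_t row = sgId / sgLayout[1];
      int64_t col = sgId % sgLayout[1];
      // Each subgroup owns an sg_data-sized tile; the remu wrap lets data be
      // shared when sg_layout * sg_data exceeds the workgroup-level shape.
      return {(row * sgData[0]) % wgShape[0], (col * sgData[1]) % wgShape[1]};
    }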
@@ -938,8 +938,8 @@ LogicalResult ConvertLayoutOp::verify() {

  // both input and target layouts should be WgLayout or SgLayout at the same
  // time.
-  if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) &&
-      (!srcLayout.isSgLayout() || !resLayout.isSgLayout()))
+  if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
+      (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
    return emitOpError("expected input layout and target layout be WgLayout or "
                       "SgLayout at the same time.");

@@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                         TypedValue<MemDescType> memDesc,
                         llvm::ArrayRef<OpFoldResult> offsets,
-                         LayoutTrait layout) {
+                         DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() {
void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
                          TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
-                          LayoutTrait layout) {
+                          DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -141,7 +141,7 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
    value = (Value)operandOrResult;

  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
-  if (layout && layout.isSgLayout()) {
+  if (layout && layout.isForSubgroup()) {
    if (auto inst_data = layout.getInstData())
      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());

@@ -205,12 +205,12 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
  bool hasWgLayoutOperands =
      llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
-        return layout && layout.isWgLayout();
+        return layout && layout.isForWorkgroup();
      });
  bool hasWgLayoutResults =
      llvm::any_of(op->getOpResults(), [](OpResult result) {
        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
-        return layout && layout.isWgLayout();
+        return layout && layout.isForWorkgroup();
      });
  if (hasWgLayoutOperands || hasWgLayoutResults) {
    LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
@@ -272,7 +272,7 @@ void XeGPUBlockingPass::runOnOperation() {

    auto layout =
        llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
-    if (layout && layout.isWgLayout())
+    if (layout && layout.isForWorkgroup())
      return failure();

    int count;
@@ -289,7 +289,7 @@ void XeGPUBlockingPass::runOnOperation() {
    ArrayRef<int64_t> shape = type.getShape();

    xegpu::LayoutAttr layout = type.getLayoutAttr();
-    if (layout && layout.isWgLayout())
+    if (layout && layout.isForWorkgroup())
      return failure();

    int count;
@@ -34,38 +34,29 @@ using namespace mlir;

namespace {

-// Check if there is sg id range attached to the scf.if op.
-static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
-                                 int64_t &endOfRange) {
-  Operation *parent = op->getParentOp();
+// Find the outermost scf::IfOp with xegpu.sg_id_range.
+// Retrieve the RangeAttr if it is specified.
+static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
+  Operation *parent = op->getParentOfType<scf::IfOp>();
  while (parent) {
-    if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
-      if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
-              ifOp->getAttr("sg_id_range"))) {
-        startOfRange = attr.getStart().getInt();
-        endOfRange = attr.getEnd().getInt();
-        break;
-      }
-    }
-    parent = parent->getParentOp();
+    if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
+            parent->getAttr("sg_id_range")))
+      return attr;
+    parent = parent->getParentOfType<scf::IfOp>();
  }
-  // Return false if startOfRange is 0
-  return (startOfRange > 0 && endOfRange > startOfRange);
+  return {};
}

static std::pair<SmallVector<int64_t>, int>
-getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+getSgShapeAndCount(ArrayRef<int64_t> shape,
+                   xegpu::DistributeLayoutAttr layout) {
  int count = 1;
  SmallVector<int64_t> sgShape(shape);

-  if (layout && layout.isWgLayout()) {
-    DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
-    auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
-      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    else
-      sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+  if (layout && layout.isForWorkgroup()) {
+    SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt().value();
+    if (auto maybeSgData = layout.getSgDataAsInt())
+      sgShape = *maybeSgData;
+    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
+      sgShape = *maybeDerivedSgData;
    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
    // Clamp distUnit to the original shape to handle cases where data is
    // shared among subgroups, which may cause distUnit to exceed the original
@@ -77,6 +68,67 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
  return std::make_pair(sgShape, count);
}

+/// Utility helper for deriving a list of offsets for each sub-TensorDescs
+/// or sub-MemDescs to be accessed by current subgroup (sgId) based on the
+/// associated distribute layout attribute, the shape, subgroup id and the
+/// original offsets of the op
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+        xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
+               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
+  Location loc = op.getLoc();
+  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
+  // not applicable to ops without offsets operands.
+  if (origOffsets.empty())
+    return failure();
+
+  // not applicable to ops without workgroup layout attributes
+  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+  if (!layout || !layout.isForWorkgroup())
+    return failure();
+
+  Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+  // verify and adjust the sgId if the range specifier is present
+  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
+  if (sgIdRange) {
+    int64_t startOfRange = sgIdRange.getStart().getInt();
+    int64_t endOfRange = sgIdRange.getEnd().getInt();
+    // verify the RangeAttr against the layout attribute
+    if (layout.getNumSubgroups() != endOfRange - startOfRange)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    // adjust the sgId if necessary
+    if (startOfRange > 0) {
+      Value startOfRangeVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+      sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
+    }
+  }
+
+  // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
+  // descriptors to be accessed, based on the layout information.
+  ArrayRef<int64_t> wgShape = op.getDataShape();
+  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  if (failed(maybeDescOffsets))
+    return failure();
+
+  // Compute the final global offsets for each accessed sub-tensor
+  // or sub-memory descriptor.
+  for (const auto &sgOffsets : *maybeDescOffsets) {
+    SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
+        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
+    offsetsList.push_back(std::move(newOffsets));
+  }
+
+  // callback(offsetsList);
+  return success();
+}
+
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
@@ -128,79 +180,30 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

-    // Ensure that the op has explicit offsets specified (either dynamic or
-    // constant).
-    if (op.getMixedOffsets().empty())
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

-    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
-    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-    if (!layout)
-      return failure();
-    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
-    // sgLayout must be present for workgroup-level distribution.
-    SmallVector<int64_t> sgLayout;
-    if (auto sgLayoutAttr = layout.getSgLayout())
-      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    else
-      return rewriter.notifyMatchFailure(
-          op, "sgLayout attribute is required in layout");
-
-    // Get the subgroup ID
-    Value linearSgId =
-        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-    int64_t startOfRange = -1, endOfRange = -1;
-    bool sgIdRangeSpecified =
-        isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
-    if (sgIdRangeSpecified) {
-      int64_t sgCount = endOfRange - startOfRange;
-      if (computeProduct(sgLayout) != sgCount)
-        return rewriter.notifyMatchFailure(
-            op, "sg_layout size must match the sg_id_range");
-      // Subtract startOfRange from the original subgroup id to get
-      // the adjusted sg id
-      Value startOfRangeVal =
-          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
-      linearSgId =
-          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
-    }
-
-    auto maybeTdescOffsets =
-        layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-    if (failed(maybeTdescOffsets))
-      return failure();
-
+    Type elemTy = tdescTy.getElementType();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-    xegpu::TensorDescType newTdescTy =
+    auto newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

-    SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
-
-    for (auto tdescOffsets : *maybeTdescOffsets) {
-      SmallVector<OpFoldResult> sgOffsets;
-      size_t rank = tdescOffsets.size();
-      for (size_t i = 0; i < rank; i++) {
-        size_t idx = origOffsets.size() - rank + i;
-        Value add = rewriter.createOrFold<index::AddOp>(
-            loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
-        sgOffsets.push_back(add);
-      }
-
+    SmallVector<Value> newOps;
+    for (auto offsets : offsetsList) {
      auto newOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
+          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
          op.getMixedSizes(), op.getMixedStrides());
-      newCreateNdOps.push_back(newOp);
+
+      newOps.push_back(newOp);
    }
-    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    rewriter.replaceOpWithMultiple(op, {newOps});

    return success();
  }
};
@@ -223,7 +226,7 @@ struct WgToSgCreateNdOpNoOffset
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-    if (!layout || !layout.isWgLayout())
+    if (!layout || !layout.isForWorkgroup())
      return failure();

    Type elemTy = tdescTy.getElementType();
@@ -254,12 +257,10 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> newLoadOps;
-
-    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
+    if (!op.getMixedOffsets().empty())
      return failure();

+    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
@@ -282,9 +283,7 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
-
-    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
+    if (!op.getMixedOffsets().empty())
      return failure();

    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
@@ -296,100 +295,6 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  }
};

-// Utility function to compute global offsets for subgroup operations.
-// Returns a vector of new offsets for each subgroup, given the original op's
-// offsets and subgroup relative offsets.
-static SmallVector<SmallVector<OpFoldResult>>
-computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
-               ArrayRef<OpFoldResult> origOffsets,
-               ConversionPatternRewriter &rewriter) {
-  SmallVector<SmallVector<OpFoldResult>> finalOffsets;
-  Location loc = op->getLoc();
-  for (const auto &sgOffsets : sgOffsetsList) {
-    SmallVector<OpFoldResult> newOffsets;
-    size_t rank = sgOffsets.size();
-    for (size_t i = 0; i < rank; i++) {
-      size_t idx = origOffsets.size() - rank + i;
-      Value add = rewriter.createOrFold<index::AddOp>(
-          loc, sgOffsets[i],
-          getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
-      newOffsets.push_back(add);
-    }
-    finalOffsets.push_back(std::move(newOffsets));
-  }
-  return finalOffsets;
-}
-
-// Utility function to get sgShape, sgOffsetList for a given
-// op.
-template <typename OpTy, typename AdaptorTy>
-LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
-                           ConversionPatternRewriter &rewriter,
-                           SmallVector<int64_t> &sgShape,
-                           SmallVector<SmallVector<Value>> &sgOffsetList) {
-  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
-    return failure();
-
-  Location loc = op.getLoc();
-  Value tdesc = op.getTensorDesc();
-  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
-  if (!tdescTy)
-    return failure();
-  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-  if (!layout)
-    return failure();
-
-  SmallVector<int64_t> sgLayout;
-  auto sgLayoutAttr = layout.getSgLayout();
-  if (!sgLayoutAttr)
-    return rewriter.notifyMatchFailure(
-        op, "sgLayout attribute is required in layout");
-  sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-
-  ArrayRef<int64_t> wgShape = tdescTy.getShape();
-  int count;
-  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
-
-  // Get the subgroup ID
-  Value linearSgId =
-      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-  int64_t startOfRange = -1, endOfRange = -1;
-  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
-  if (sgIdRangeSpecified) {
-    int64_t sgCount = endOfRange - startOfRange;
-    if (computeProduct(sgLayout) != sgCount)
-      return rewriter.notifyMatchFailure(
-          op, "sg_layout size must match the sg_id_range");
-    Value startOfRangeVal =
-        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
-    linearSgId =
-        rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
-  }
-
-  auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-  if (failed(sgOffsets))
-    return failure();
-
-  sgOffsetList = *sgOffsets;
-  return success();
-}
-
-template <typename OpTy>
-SmallVector<OpFoldResult> getOffsets(OpTy op,
-                                     ConversionPatternRewriter &rewriter) {
-  SmallVector<OpFoldResult> origOffsets;
-  if (auto constOffsets = op.getConstOffsetsAttr()) {
-    for (auto attr : constOffsets.asArrayRef())
-      origOffsets.push_back(rewriter.getIndexAttr(attr));
-  }
-  for (auto v : op.getOffsets())
-    origOffsets.push_back(v);
-  return origOffsets;
-}
-
// This pattern transforms the LoadNdOp with explicit offsets to load
// subgroup data.
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
@@ -398,33 +303,24 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

-    SmallVector<int64_t> sgShape;
-    SmallVector<SmallVector<Value>> sgOffsetList;
-
-    // Do the distribution from workgroup to subgroup and get subgroup offsets
-    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

-    // Get the original workgroup offsets
-    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
-
-    // Calculate the final offsets for each subgroup
-    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
-
-    SmallVector<Value> newLoadOps;
-    for (auto [offsets, tdesc] :
-         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
-      VectorType newResTy = VectorType::get(
-          sgShape,
-          dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
-      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
-          op.getLoc(), newResTy, tdesc, offsets,
-          /*packed=*/nullptr,
-          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
-          op.getL3HintAttr());
-      newLoadOps.push_back(newLoadOp);
+    SmallVector<Value> newOps;
+    for (auto [tdesc, offsets] :
+         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
+      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+      VectorType newResTy =
+          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
+      auto newOp = xegpu::LoadNdOp::create(
+          rewriter, op.getLoc(), newResTy, tdesc, offsets,
+          /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
    }
-    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    rewriter.replaceOpWithMultiple(op, {newOps});

    return success();
  }
};
@@ -437,27 +333,18 @@ struct WgToSgStoreNdOpWithOffset
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

-    SmallVector<int64_t> sgShape;
-    SmallVector<SmallVector<Value>> sgOffsetList;
-
-    // Do the distribution from workgroup to subgroup and get subgroup offsets
-    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

-    // Get the original workgroup offsets
-    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
-
-    // Calculate the final offsets for each subgroup
-    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
-
-    for (auto [offsets, tdesc, value] :
-         llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
-      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
+    for (auto [v, tdesc, offsets] :
+         llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, tdesc, offsets,
                                        op.getL1HintAttr(), op.getL2HintAttr(),
                                        op.getL3HintAttr());
    }
    rewriter.eraseOp(op);

    return success();
  }
};
@@ -470,27 +357,18 @@ struct WgToSgPrefetchNdOpWithOffset
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

-    SmallVector<int64_t> sgShape;
-    SmallVector<SmallVector<Value>> sgOffsetList;
-
-    // Do the distribution from workgroup to subgroup and get subgroup offsets
-    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

-    // Get the original workgroup offsets
-    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
-
-    // Calculate the final offsets for each subgroup
-    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
-
-    for (auto [offsets, tdesc] :
-         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+    for (auto [tdesc, offsets] :
+         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      rewriter.create<xegpu::PrefetchNdOp>(
          op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
          op.getL3HintAttr());
    }
    rewriter.eraseOp(op);

    return success();
  }
};
@@ -736,7 +614,8 @@ struct WgToSgConvertLayoutOp
    xegpu::LayoutAttr input = op.getInputLayout();
    xegpu::LayoutAttr target = op.getTargetLayout();

-    if (!input || !target || !input.isWgLayout() || !target.isWgLayout())
+    if (!input || !target || !input.isForWorkgroup() ||
+        !target.isForWorkgroup())
      return rewriter.notifyMatchFailure(
          op, "Input and target layouts must have subgroup layout");

@@ -884,6 +763,56 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  }
};

+struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
+  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
+      return failure();
+
+    ArrayRef<int64_t> wgShape = op.getDataShape();
+    VectorType valueTy = op.getRes().getType();
+    Type elemTy = valueTy.getElementType();
+
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResTy = VectorType::get(sgShape, elemTy);
+    SmallVector<Value> newOps;
+    for (auto offsets : offsetsList) {
+      auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
+          op.getLoc(), newResTy, op.getMemDesc(), offsets,
+          layout.dropSgLayoutAndData());
+      newOps.push_back(newOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newOps});
+
+    return success();
+  }
+};
+
+struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
+  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
+      return failure();
+
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
+      rewriter.create<xegpu::StoreMatrixOp>(op.getLoc(), v, op.getMemDesc(),
+                                            offsets,
+                                            layout.dropSgLayoutAndData());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
} // namespace

namespace mlir {
@@ -895,7 +824,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-           WgToSgArithConstantOp>(patterns.getContext());
+           WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
+      patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -985,8 +915,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
    return xegpu::TensorDescType();
  };

-  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
-    return !layout || !layout.isWgLayout();
+  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
+    return !layout || !layout.isForWorkgroup();
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
@@ -1002,9 +932,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
        return isLegal(layout);
      });

-  target.addDynamicallyLegalOp<vector::BroadcastOp>(
-      [=](vector::BroadcastOp op) -> bool {
-        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
+      [=](xegpu::LoadMatrixOp op) -> bool {
+        return isLegal(op.getLayoutAttr());
      });
+
+  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
+      [=](xegpu::StoreMatrixOp op) -> bool {
+        return isLegal(op.getLayoutAttr());
+      });

  target.addDynamicallyLegalOp<arith::ConstantOp>(
@@ -1015,6 +950,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
        return isLegal(xegpu::getLayoutAttr(op.getResult()));
      });

+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
      [=](xegpu::ConvertLayoutOp op) -> bool {
        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
@@ -12,6 +12,7 @@

#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -40,7 +41,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
  auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
  // It only works for subgroup level layout, which only has lane_layout
  // and lane_data, and is to distribute a SIMD code into SIMT code.
-  if (!layout || !layout.isSgLayout())
+  if (!layout || !layout.isForSubgroup())
    return failure();

  SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
@@ -424,3 +425,31 @@ std::optional<std::string> xegpu::getChipStr(Operation *op) {

  return std::nullopt;
}
+
+/// Generates element-wise addition ops of two arrays with automatic alignment.
+/// When the input arrays have different sizes, the shorter array is
+/// right-aligned with the longer array, and the unmatched leading elements from
+/// the longer array are preserved unchanged. This is commonly used for offset
+/// computation where higher-dimensional offsets need to be added to
+/// lower-dimensional adjustments.
+///
+/// Example:
+///    lhs = [l1, l2, l3], rhs = [r1, r2]
+///    Result: [l1, l2+r1, l3+r2]
+SmallVector<OpFoldResult>
+xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
+                           ArrayRef<OpFoldResult> lhs,
+                           ArrayRef<OpFoldResult> rhs) {
+  // ensure a is longer than b
+  ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
+  ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
+  SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
+  a = a.slice(a.size() - b.size());
+  for (auto [l, r] : llvm::zip(a, b)) {
+    auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
+    auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
+    results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+  }
+  return results;
+  return {};
+}
@@ -263,4 +263,62 @@ gpu.module @test_distribution {
    } {sg_id_range = #xegpu.range<[3, 19]>}
    gpu.return
  }

+  // CHECK-LABEL: distribute_load_matrix
+  // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
+  gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
+    //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[c2:%.+]] = arith.constant 2 : index
+    //CHECK: [[c4:%.+]] = arith.constant 4 : index
+    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
+    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
+    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
+    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index
+    //CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index
+    //CHECK: [[c64:%.+]] = arith.constant 64 : index
+    //CHECK: [[off_y:%.+]] = index.remu [[l_off_y_0]], [[c64]]
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[off_x:%.+]] = index.remu [[l_off_x_0]], [[c128]]
+    //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
+    %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
+    gpu.return
+  }
+
+  //CHECK-LABEL: distribute_store_matrix
+  //CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
+  gpu.func @distribute_store_matrix(%arg0 : memref<32768xi8, 3>) {
+    //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32>
+    //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[c2:%.+]] = arith.constant 2 : index
+    //CHECK: [[c4:%.+]] = arith.constant 4 : index
+    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
+    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
+    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]]
+    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
+    //CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index
+    //CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index
+    //CHECK: [[c64:%.+]] = arith.constant 64 : index
+    //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]]
+    //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} dense<1.0> : vector<64x128xf32>
+    %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
+    gpu.return
+  }
}
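As an aside (not part of the patch), the per-subgroup offsets that the CHECK lines above verify can be reproduced on plain integers for sg_layout = [2, 4], sg_data = [32, 32] and a 64x128 workgroup-level matrix, assuming row-major delinearization of the subgroup id. Each of the eight subgroups ends up at [0,0], [0,32], [0,64], [0,96], [32,0], [32,32], [32,64], [32,96].

    // Illustrative only: prints the offsets each subgroup loads/stores.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t sgLayout[2] = {2, 4};   // sg_layout
      const int64_t sgData[2] = {32, 32};   // sg_data
      const int64_t wgShape[2] = {64, 128}; // workgroup-level matrix shape
      for (int64_t sgId = 0; sgId < sgLayout[0] * sgLayout[1]; ++sgId) {
        int64_t idY = sgId / sgLayout[1], idX = sgId % sgLayout[1];
        int64_t offY = (idY * sgData[0]) % wgShape[0]; // index.mul + index.remu
        int64_t offX = (idX * sgData[1]) % wgShape[1];
        std::printf("sg %lld -> [%lld, %lld]\n", (long long)sgId,
                    (long long)offY, (long long)offX);
      }
      return 0;
    }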
@@ -82,7 +82,7 @@ struct TestXeGPUUnrollingPatterns

          if (auto layout = tdescTy.getLayoutAttr()) {
            auto inst_data = layout.getInstData();
-            if (inst_data && layout.isSgLayout())
+            if (inst_data && layout.isForSubgroup())
              return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
                                          inst_data.asArrayRef().end());
          }
@@ -156,8 +156,8 @@ struct TestXeGPUUnrollingPatterns
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

// Test pattern for distributing vector::StepOp from workgroup to subgroup.
-// Validates LayoutTrait interfaces for offset computation abstraction between
-// LayoutAttr and SliceAttr.
+// Validates DistributeLayoutAttr interfaces for offset computation
+// abstraction between LayoutAttr and SliceAttr.
class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;

@@ -239,7 +239,7 @@ struct TestXeGPULayoutInterface

    ConversionTarget target(*ctx);
    auto isLegal = [&](xegpu::SliceAttr layout) -> bool {
-      return !layout || !layout.isWgLayout();
+      return !layout || !layout.isForWorkgroup();
    };

    target.addDynamicallyLegalOp<vector::StepOp>(